SHA888 committed on
Commit 2ba4e0c · verified · 1 Parent(s): 82b093c

Enhance: generic token/tag columns, metrics in PR description, publish med-vllm-* variant

Files changed (1)
  1. app.py +344 -0
app.py ADDED
@@ -0,0 +1,344 @@
import os
import time
import threading
from typing import Optional, Dict, Any

import gradio as gr

from huggingface_hub import HfApi, create_repo


DEFAULT_BASE_MODEL = "dmis-lab/biobert-base-cased-v1.2"
DEFAULT_DATASET = "conll2003"  # fallback; medical sets may require custom preprocessing
TARGET_REPO = os.getenv("MEDVLLM_TARGET_REPO", "Junaidi-AI/med-vllm")


def _train_ner_lora(
    base_model: str,
    dataset_name: str,
    output_dir: str,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 8,
    learning_rate: float = 2e-5,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    log_cb=None,
) -> Dict[str, Any]:
    """
    Minimal LoRA token-classification trainer.
    Uses conll2003 by default to be robust in Spaces; extend to medical datasets later.
    """
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForTokenClassification,
        DataCollatorForTokenClassification,
        TrainingArguments,
        Trainer,
    )
    from transformers.trainer_utils import set_seed
    from seqeval.metrics import f1_score, accuracy_score, precision_score, recall_score
    from peft import LoraConfig, get_peft_model, TaskType
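    # Training dependencies are imported lazily so the UI can load before they
    # are needed; assumes datasets, transformers, peft, and seqeval are listed
    # in the Space's requirements.txt.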

    def log(msg: str):
        if log_cb:
            log_cb(msg)
        else:
            print(msg)

    set_seed(42)

    log(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name)

    if "train" not in ds:
        raise RuntimeError("Dataset must have a train split")

    # Detect token and label columns across common schemas
    features = ds["train"].features
    token_candidates = ["tokens", "words"]
    tag_candidates = ["ner_tags", "tags", "labels", "ner_tags_general"]
    token_col = next((c for c in token_candidates if c in features), None)
    tag_col = next((c for c in tag_candidates if c in features), None)
    if not token_col or not tag_col:
        raise RuntimeError(
            "Dataset must provide token and tag columns. Looked for tokens/words and ner_tags/tags/labels."
        )
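    # e.g. conll2003 resolves to ("tokens", "ner_tags"); a dataset exposing
    # ("words", "labels") instead would be picked up by the same probe.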

    label_list = ds["train"].features[tag_col].feature.names
    id2label = {i: l for i, l in enumerate(label_list)}
    label2id = {l: i for i, l in enumerate(label_list)}

    log(f"Loading tokenizer/model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForTokenClassification.from_pretrained(
        base_model, num_labels=len(label_list), id2label=id2label, label2id=label2id
    )

    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model = get_peft_model(base, peft_config)
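    # No target_modules are given, so PEFT falls back to its per-architecture
    # defaults (for BERT-family models, the attention query/value projections);
    # with TaskType.TOKEN_CLS the classification head remains trainable.
    model.print_trainable_parameters()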

    # Tokenize with alignment between word-level tags and subword tokens
    def tokenize_align(batch):
        enc = tokenizer(
            batch[token_col], is_split_into_words=True, truncation=True, padding=False
        )
        # Build aligned labels per example: keep the tag on the first subword of
        # each word; mask special tokens and subword continuations with -100
        new_labels = []
        for i, tags in enumerate(batch[tag_col]):
            word_ids = enc.word_ids(batch_index=i)
            lab = []
            prev_wid = None
            for wid in word_ids:
                if wid is None:
                    lab.append(-100)
                elif wid != prev_wid:
                    lab.append(tags[wid])
                else:
                    lab.append(-100)
                prev_wid = wid
            new_labels.append(lab)
        # Leave padding to the collator, which pads labels with -100 as well
        enc["labels"] = new_labels
        return enc
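    # Illustrative alignment (subword split is made up): tokens ["ibuprofen", "helps"]
    # with tags [B-CHEM, O] might encode as [CLS] ibu ##pro ##fen helps [SEP],
    # giving labels [-100, B-CHEM, -100, -100, O, -100].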

    log("Tokenizing dataset...")
    tokenized = ds.map(tokenize_align, batched=True)

    data_collator = DataCollatorForTokenClassification(tokenizer)

    metrics_holder: Dict[str, float] = {}

    def compute_metrics(p):
        preds, labels = p
        preds = preds.argmax(-1)
        # Drop masked (-100) positions and map ids back to tag strings
        true_predictions = []
        true_labels = []
        for pred, lab in zip(preds, labels):
            curr_pred = []
            curr_lab = []
            for p_i, l_i in zip(pred, lab):
                if l_i != -100:
                    curr_pred.append(id2label[int(p_i)])
                    curr_lab.append(id2label[int(l_i)])
            true_predictions.append(curr_pred)
            true_labels.append(curr_lab)
        out = {
            "f1": f1_score(true_labels, true_predictions),
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "accuracy": accuracy_score(true_labels, true_predictions),
        }
        metrics_holder.update(out)
        return out
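    # seqeval scores at the entity level (exact span and type under the
    # dataset's IOB-style scheme), so "f1" here is entity F1, not token F1.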

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_train_batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        evaluation_strategy="epoch",  # renamed to eval_strategy in recent transformers releases
        save_strategy="epoch",
        logging_steps=10,
        report_to=[],
        fp16=False,  # stay CPU-safe on Spaces without a GPU
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized.get("validation") or tokenized.get("dev") or tokenized["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    log("Starting training...")
    trainer.train()

    log("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Compose commit description with hyperparameters and final metrics
    desc_lines = [
        f"base_model: {base_model}",
        f"dataset: {dataset_name}",
        f"epochs: {num_train_epochs}",
        f"batch_size: {per_device_train_batch_size}",
        f"learning_rate: {learning_rate}",
        f"lora_r: {lora_r}",
        f"lora_alpha: {lora_alpha}",
        f"lora_dropout: {lora_dropout}",
        "",
        "metrics:",
        *(f"- {k}: {v:.4f}" for k, v in metrics_holder.items()),
    ]
    commit_description = "\n".join(desc_lines)
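    # Illustrative rendering in the PR description (values made up):
    #   base_model: dmis-lab/biobert-base-cased-v1.2
    #   dataset: conll2003
    #   metrics:
    #   - f1: 0.8800
    #   - precision: 0.8700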

    # Push to the umbrella repo under checkpoints/
    api = HfApi()
    run_name = os.path.basename(output_dir.rstrip("/"))
    path_in_repo = f"checkpoints/ner-{run_name}"
    log(f"Pushing to {TARGET_REPO}:{path_in_repo}")
    commit = api.upload_folder(
        repo_id=TARGET_REPO,
        repo_type="model",
        folder_path=output_dir,
        path_in_repo=path_in_repo,
        commit_message=f"Add NER LoRA checkpoint ({run_name})",
        commit_description=commit_description,
        create_pr=True,
    )
    log(f"Pushed: {commit}")

    # Also publish to a dedicated med-vllm-* variant repo
    try:
        base_short = base_model.split("/")[-1].replace(" ", "-").lower()
        ds_short = dataset_name.split("/")[-1].replace(" ", "-").lower()
        # e.g. Junaidi-AI/med-vllm-ner-conll2003-biobert-base-cased-v1.2-lora-v1 for the defaults
        variant_name = f"Junaidi-AI/med-vllm-ner-{ds_short}-{base_short}-lora-v1"
        log(f"Ensuring repo exists: {variant_name}")
        create_repo(repo_id=variant_name, repo_type="model", exist_ok=True, private=False)
        commit2 = api.upload_folder(
            repo_id=variant_name,
            repo_type="model",
            folder_path=output_dir,
            path_in_repo=".",
            commit_message=f"Initial LoRA checkpoint from {base_model} on {dataset_name}",
            commit_description=commit_description,
            create_pr=False,
        )
        log(f"Variant published: {commit2}")
    except Exception as e:
        log(f"Warning: failed to publish variant repo: {e}")

    return {"commit": str(commit), "path_in_repo": path_in_repo, "metrics": metrics_holder}


class TrainerThread:
    """Runs a single training job on a daemon thread so the Gradio UI stays responsive."""

    def __init__(self):
        self.thread: Optional[threading.Thread] = None
        self.logs = ""
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None

    def _log(self, msg: str):
        self.logs += msg + "\n"

    def start(self, **kwargs):
        if self.thread and self.thread.is_alive():
            raise gr.Error("Training is already running")

        def target():
            try:
                self._log("Initializing training...")
                res = _train_ner_lora(log_cb=self._log, **kwargs)
                self.result = res
                self._log("Training complete")
            except Exception as e:
                self.error = str(e)
                self._log(f"ERROR: {e}")

        self.logs = ""
        self.result = None
        self.error = None
        self.thread = threading.Thread(target=target, daemon=True)
        self.thread.start()

    def status(self):
        running = self.thread.is_alive() if self.thread else False
        return running, self.logs, self.result, self.error


TRAINER = TrainerThread()
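# TRAINER is a module-level singleton: one run at a time, shared across every
# browser session of the Space.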


def build_ui():
    with gr.Blocks(title="Med vLLM Train (LoRA NER)") as demo:
        gr.Markdown(
            f"""
# Med vLLM Train (LoRA NER)
This Space fine-tunes a token-classification model with LoRA.

- Base model default: `{DEFAULT_BASE_MODEL}`
- Dataset default: `{DEFAULT_DATASET}` (a robust demo set). Medical sets such as `bc5cdr`/`ncbi_disease` may require custom preprocessing.
- Checkpoints are pushed to `{TARGET_REPO}` under `checkpoints/` as a PR.
"""
        )
        with gr.Row():
            base_model = gr.Textbox(value=DEFAULT_BASE_MODEL, label="Base model")
            dataset_name = gr.Textbox(value=DEFAULT_DATASET, label="Dataset (token classification)")
        with gr.Row():
            epochs = gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Epochs")
            batch = gr.Slider(minimum=4, maximum=16, step=2, value=8, label="Batch size")
            lr = gr.Textbox(value="2e-5", label="Learning rate")
        with gr.Row():
            lora_r = gr.Slider(minimum=4, maximum=32, step=2, value=8, label="LoRA r")
            lora_alpha = gr.Slider(minimum=8, maximum=64, step=8, value=16, label="LoRA alpha")
            lora_dropout = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, value=0.1, label="LoRA dropout")
        with gr.Row():
            run_name = gr.Textbox(value=f"run-{int(time.time())}", label="Run name (folder)")
        with gr.Row():
            start_btn = gr.Button("Start Training")
            status_btn = gr.Button("Refresh Status")
        logs = gr.Textbox(label="Logs", lines=18)
        result = gr.Textbox(label="Result / Commit info")

        def on_start(bm, ds, ep, bs, lr_s, r, alpha, drop, rn):
            try:
                out_dir = os.path.join("outputs", rn)
                os.makedirs(out_dir, exist_ok=True)
                TRAINER.start(
                    base_model=bm,
                    dataset_name=ds,
                    output_dir=out_dir,
                    num_train_epochs=int(ep),
                    per_device_train_batch_size=int(bs),
                    learning_rate=float(lr_s),
                    lora_r=int(r),
                    lora_alpha=int(alpha),
                    lora_dropout=float(drop),
                )
                return "Started"
            except Exception as e:
                return f"ERROR starting: {e}"

        def on_status():
            running, l, res, err = TRAINER.status()
            info = "Running" if running else ("Error" if err else "Idle/Done")
            res_s = str(res) if res else ""
            return f"[{info}]\n" + l, res_s

        start_btn.click(
            on_start,
            inputs=[base_model, dataset_name, epochs, batch, lr, lora_r, lora_alpha, lora_dropout, run_name],
            outputs=[logs],
        )
        status_btn.click(on_status, outputs=[logs, result])

    return demo


if __name__ == "__main__":
    ui = build_ui()
    ui.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
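
For reference, a minimal sketch of consuming a checkpoint this Space publishes. The repo name shown is the one the default settings would produce, and conll2003 has 9 NER tags; since the adapter was saved with TaskType.TOKEN_CLS, PEFT restores the fine-tuned classification head along with the LoRA weights:

    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from peft import PeftModel

    repo = "Junaidi-AI/med-vllm-ner-conll2003-biobert-base-cased-v1.2-lora-v1"
    base = AutoModelForTokenClassification.from_pretrained(
        "dmis-lab/biobert-base-cased-v1.2", num_labels=9
    )
    model = PeftModel.from_pretrained(base, repo)  # attaches adapter + classifier head
    tokenizer = AutoTokenizer.from_pretrained(repo)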