SHA888 committed
Commit a408d71 · verified · 1 Parent(s): 82b093c

Create Training Space: LoRA NER trainer (Gradio UI) (#1)


- Create Training Space: LoRA NER trainer (Gradio UI) (67820d34a1e9c40abbbdf4179a2d0d5a04ee290b)

Files changed (3)
  1. README.md +21 -6
  2. app.py +301 -0
  3. requirements.txt +9 -0
README.md CHANGED
@@ -1,12 +1,27 @@
 ---
-title: Med Vllm Train
-emoji: 👀
-colorFrom: red
-colorTo: red
+title: Med vLLM Train (LoRA NER)
+emoji: 🧪
+colorFrom: purple
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.46.1
+sdk_version: "4.44.0"
 app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Med vLLM Train (LoRA NER)
+
+This Space fine-tunes a token-classification model with LoRA (PEFT) on a public dataset. It pushes checkpoints to the umbrella model repo under `checkpoints/` as a Pull Request so results persist.
+
+- Default base model: `dmis-lab/biobert-base-cased-v1.2`
+- Default dataset (robust demo): `conll2003`
+- Target repo for checkpoints: `Junaidi-AI/med-vllm`
+
+You can change the base model and dataset in the UI. Medical datasets (e.g., `bc5cdr`, `ncbi_disease`) may require extra preprocessing, which can be added later.
+
+## Run locally
+
+```bash
+pip install -r requirements.txt
+python app.py
+```
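As a concrete sketch of the README's preprocessing caveat: before pointing the trainer at a medical corpus, check whether it already exposes the `tokens`/`ner_tags` schema the app expects. The column names and label set below are assumptions about the current Hub copies, not something this commit verifies:

```python
# Hedged sketch: probe a medical NER dataset for the CoNLL-style schema the
# trainer expects. `ncbi_disease` is assumed to expose tokens/ner_tags; corpora
# shipped with character-span annotations (e.g. some bc5cdr variants) would
# instead need a custom map() that converts spans to per-token BIO tags.
from datasets import load_dataset

ds = load_dataset("ncbi_disease")
print(ds["train"].column_names)  # expected: ['id', 'tokens', 'ner_tags']
print(ds["train"].features["ner_tags"].feature.names)  # e.g. ['O', 'B-Disease', 'I-Disease']
```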
app.py ADDED
@@ -0,0 +1,301 @@
+import os
+import time
+import threading
+from typing import Optional, Dict, Any
+
+import gradio as gr
+
+from huggingface_hub import HfApi
+
+
+DEFAULT_BASE_MODEL = "dmis-lab/biobert-base-cased-v1.2"
+DEFAULT_DATASET = "conll2003"  # fallback; medical sets may require custom preprocessing
+TARGET_REPO = os.getenv("MEDVLLM_TARGET_REPO", "Junaidi-AI/med-vllm")
+
+
+def _train_ner_lora(
+    base_model: str,
+    dataset_name: str,
+    output_dir: str,
+    num_train_epochs: int = 1,
+    per_device_train_batch_size: int = 8,
+    learning_rate: float = 2e-5,
+    lora_r: int = 8,
+    lora_alpha: int = 16,
+    lora_dropout: float = 0.1,
+    log_cb=None,
+) -> Dict[str, Any]:
+    """
+    Minimal LoRA token-classification trainer.
+    Uses conll2003 by default to be robust in Spaces. Extend to medical datasets later.
+    """
+    from datasets import load_dataset
+    from transformers import (
+        AutoTokenizer,
+        AutoModelForTokenClassification,
+        DataCollatorForTokenClassification,
+        TrainingArguments,
+        Trainer,
+    )
+    from transformers.trainer_utils import set_seed
+    from seqeval.metrics import f1_score, accuracy_score, precision_score, recall_score
+    from peft import LoraConfig, get_peft_model, TaskType
+
+    def log(msg: str):
+        if log_cb:
+            log_cb(msg)
+        else:
+            print(msg)
+
+    set_seed(42)
+
+    log(f"Loading dataset: {dataset_name}")
+    ds = load_dataset(dataset_name)
+
+    if "train" not in ds:
+        raise RuntimeError("Dataset must have a train split")
+
+    # Expecting CoNLL-style fields
+    features = ds["train"].features
+    token_col = "tokens" if "tokens" in features else None
+    tag_col = "ner_tags" if "ner_tags" in features else None
+    if not token_col or not tag_col:
+        raise RuntimeError(
+            "Dataset does not expose 'tokens' and 'ner_tags'. For medical datasets, add custom preprocessing."
+        )
+
+    label_list = ds["train"].features[tag_col].feature.names
+    id2label = {i: l for i, l in enumerate(label_list)}
+    label2id = {l: i for i, l in enumerate(label_list)}
+
+    log(f"Loading tokenizer/model: {base_model}")
+    tokenizer = AutoTokenizer.from_pretrained(base_model)
+    base = AutoModelForTokenClassification.from_pretrained(
+        base_model, num_labels=len(label_list), id2label=id2label, label2id=label2id
+    )
+
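+    # target_modules is left unset below; PEFT then falls back to its built-in
+    # per-architecture defaults (for BERT-family encoders such as BioBERT, the
+    # attention "query" and "value" projections receive LoRA adapters).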
+    peft_config = LoraConfig(
+        task_type=TaskType.TOKEN_CLS,
+        inference_mode=False,
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+    )
+    model = get_peft_model(base, peft_config)
+
+    # Tokenize once per batch and align word-level tags to subword tokens
+    def tokenize_align(batch):
+        enc = tokenizer(
+            batch[token_col], is_split_into_words=True, truncation=True, padding=False
+        )
+        new_labels = []
+        for i, tags in enumerate(batch[tag_col]):
+            word_ids = enc.word_ids(batch_index=i)
+            lab = []
+            prev_wid = None
+            for wid in word_ids:
+                if wid is None:
+                    # Special tokens get -100 so the loss ignores them
+                    lab.append(-100)
+                elif wid != prev_wid:
+                    # Only label the first subword of each word
+                    lab.append(tags[wid])
+                else:
+                    lab.append(-100)
+                prev_wid = wid
+            new_labels.append(lab)
+        enc["labels"] = new_labels
+        return enc
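+    # Illustration (hypothetical): tokens ["Aspirin", "intake"] with tags [1, 0]
+    # may tokenize to [CLS, As, ##pir, ##in, intake, SEP], giving labels
+    # [-100, 1, -100, -100, 0, -100]; padding is left to the data collator.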
+
+    log("Tokenizing dataset...")
+    tokenized = ds.map(
+        tokenize_align, batched=True, remove_columns=ds["train"].column_names
+    )
+
+    data_collator = DataCollatorForTokenClassification(tokenizer)
+
+    def compute_metrics(p):
+        preds, labels = p
+        preds = preds.argmax(-1)
+        # seqeval scores at the entity level, so rebuild tag-name sequences and
+        # drop positions masked with -100
+        true_predictions = []
+        true_labels = []
+        for pred, lab in zip(preds, labels):
+            curr_pred = []
+            curr_lab = []
+            for p_i, l_i in zip(pred, lab):
+                if l_i != -100:
+                    curr_pred.append(id2label[int(p_i)])
+                    curr_lab.append(id2label[int(l_i)])
+            true_predictions.append(curr_pred)
+            true_labels.append(curr_lab)
+        return {
+            "f1": f1_score(true_labels, true_predictions),
+            "precision": precision_score(true_labels, true_predictions),
+            "recall": recall_score(true_labels, true_predictions),
+            "accuracy": accuracy_score(true_labels, true_predictions),
+        }
+
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        per_device_train_batch_size=per_device_train_batch_size,
+        per_device_eval_batch_size=per_device_train_batch_size,
+        learning_rate=learning_rate,
+        num_train_epochs=num_train_epochs,
+        eval_strategy="epoch",  # formerly evaluation_strategy (renamed in transformers 4.41)
+        save_strategy="epoch",
+        logging_steps=10,
+        report_to=[],
+        fp16=False,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized["train"],
+        eval_dataset=tokenized.get("validation") or tokenized.get("dev") or tokenized["test"],
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        compute_metrics=compute_metrics,
+    )
+
+    log("Starting training...")
+    trainer.train()
+
+    log("Saving adapter...")
+    model.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
+
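+    # Note: the upload below needs a write token (e.g. the HF_TOKEN secret in a
+    # Space); huggingface_hub picks it up from the environment automatically.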
+    # Push to the umbrella repo under checkpoints/
+    api = HfApi()
+    run_name = os.path.basename(output_dir.rstrip("/"))
+    path_in_repo = f"checkpoints/ner-{run_name}"
+    log(f"Pushing to {TARGET_REPO}:{path_in_repo}")
+    commit = api.upload_folder(
+        repo_id=TARGET_REPO,
+        repo_type="model",
+        folder_path=output_dir,
+        path_in_repo=path_in_repo,
+        commit_message=f"Add NER LoRA checkpoint ({run_name})",
+        create_pr=True,
+    )
+    log(f"Pushed: {commit}")
+    return {"commit": str(commit), "path_in_repo": path_in_repo}
+
+
+class TrainerThread:
+    def __init__(self):
+        self.thread: Optional[threading.Thread] = None
+        self.logs = ""
+        self.result: Optional[Dict[str, Any]] = None
+        self.error: Optional[str] = None
+
+    def _log(self, msg: str):
+        self.logs += msg + "\n"
+
+    def start(self, **kwargs):
+        if self.thread and self.thread.is_alive():
+            raise gr.Error("Training is already running")
+
+        def target():
+            try:
+                self._log("Initializing training...")
+                res = _train_ner_lora(log_cb=self._log, **kwargs)
+                self.result = res
+                self._log("Training complete")
+            except Exception as e:
+                self.error = str(e)
+                self._log(f"ERROR: {e}")
+
+        self.logs = ""
+        self.result = None
+        self.error = None
+        self.thread = threading.Thread(target=target, daemon=True)
+        self.thread.start()
+
+    def status(self):
+        running = self.thread.is_alive() if self.thread else False
+        return running, self.logs, self.result, self.error
+
+
+TRAINER = TrainerThread()
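+
+# Design note: training runs on a daemon thread and the UI polls TRAINER.status()
+# via the "Refresh Status" button, so the click that starts training returns
+# immediately instead of blocking the Gradio request for the whole run.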
+
+
+def build_ui():
+    with gr.Blocks(title="Med vLLM Train (LoRA NER)") as demo:
+        gr.Markdown(
+            f"""
+            # Med vLLM Train (LoRA NER)
+            This Space fine-tunes a token-classification model with LoRA.
+
+            - Base model default: `{DEFAULT_BASE_MODEL}`
+            - Dataset default: `{DEFAULT_DATASET}` (robust demo). Medical sets like `bc5cdr`/`ncbi_disease` may require custom preprocessing.
+            - Checkpoints will be pushed to `{TARGET_REPO}` under `checkpoints/` as a PR.
+            """
+        )
+        with gr.Row():
+            base_model = gr.Textbox(value=DEFAULT_BASE_MODEL, label="Base model")
+            dataset_name = gr.Textbox(value=DEFAULT_DATASET, label="Dataset (token classification)")
+        with gr.Row():
+            epochs = gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Epochs")
+            batch = gr.Slider(minimum=4, maximum=16, step=2, value=8, label="Batch size")
+            lr = gr.Textbox(value="2e-5", label="Learning rate")
+        with gr.Row():
+            lora_r = gr.Slider(minimum=4, maximum=32, step=2, value=8, label="LoRA r")
+            lora_alpha = gr.Slider(minimum=8, maximum=64, step=8, value=16, label="LoRA alpha")
+            lora_dropout = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, value=0.1, label="LoRA dropout")
+        with gr.Row():
+            run_name = gr.Textbox(value=f"run-{int(time.time())}", label="Run name (folder)")
+        with gr.Row():
+            start_btn = gr.Button("Start Training")
+            status_btn = gr.Button("Refresh Status")
+        logs = gr.Textbox(label="Logs", lines=18)
+        result = gr.Textbox(label="Result / Commit info")
+
+        def on_start(bm, ds, ep, bs, lr_s, r, alpha, drop, rn):
+            try:
+                out_dir = os.path.join("outputs", rn)
+                os.makedirs(out_dir, exist_ok=True)
+                TRAINER.start(
+                    base_model=bm,
+                    dataset_name=ds,
+                    output_dir=out_dir,
+                    num_train_epochs=int(ep),
+                    per_device_train_batch_size=int(bs),
+                    learning_rate=float(lr_s),
+                    lora_r=int(r),
+                    lora_alpha=int(alpha),
+                    lora_dropout=float(drop),
+                )
+                return "Started"
+            except Exception as e:
+                return f"ERROR starting: {e}"
+
+        def on_status():
+            running, l, res, err = TRAINER.status()
+            info = "Running" if running else ("Error" if err else "Idle/Done")
+            res_s = str(res) if res else ""
+            return f"[{info}]\n" + l, res_s
+
+        start_btn.click(
+            on_start,
+            inputs=[base_model, dataset_name, epochs, batch, lr, lora_r, lora_alpha, lora_dropout, run_name],
+            outputs=[logs],
+        )
+        status_btn.click(on_status, outputs=[logs, result])
+
+    return demo
+
+
+if __name__ == "__main__":
+    ui = build_ui()
+    ui.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
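For debugging outside the UI, the training entry point can be invoked directly. A minimal sketch under stated assumptions (the output path is illustrative, and the final `upload_folder` PR step only succeeds with a write token such as `HF_TOKEN` set):

```python
# Hypothetical local smoke test for app.py's trainer; not part of the commit.
# Assumes requirements.txt is installed in the current environment.
from app import _train_ner_lora

info = _train_ner_lora(
    base_model="dmis-lab/biobert-base-cased-v1.2",
    dataset_name="conll2003",
    output_dir="outputs/smoke-test",
    num_train_epochs=1,
    log_cb=print,  # stream progress to stdout instead of the Gradio log box
)
print(info["path_in_repo"])  # e.g. checkpoints/ner-smoke-test
```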
requirements.txt ADDED
@@ -0,0 +1,9 @@
+gradio>=4.0.0
+transformers>=4.41.0
+datasets>=2.18.0
+evaluate>=0.4.1
+seqeval>=1.2.2
+peft>=0.11.0
+accelerate>=0.28.0
+huggingface_hub>=0.19
+pyyaml>=6.0
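Note: on Spaces the Gradio version is taken from the README's `sdk_version` pin ("4.44.0" here), so the `gradio>=4.0.0` floor in `requirements.txt` mainly governs local runs, which may resolve a newer Gradio than the deployed Space.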