Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import threading | |
| from typing import Optional, Dict, Any | |
| import gradio as gr | |
| from huggingface_hub import HfApi, create_repo | |
# Model id pre-filled as the base checkpoint for fine-tuning.
DEFAULT_BASE_MODEL = "dmis-lab/biobert-base-cased-v1.2"
# Dataset id pre-filled for training.
DEFAULT_DATASET = "conll2003"  # fallback; medical sets may require custom preprocessing
# Repo that receives trained checkpoints; can be overridden through the
# MEDVLLM_TARGET_REPO environment variable.
TARGET_REPO = os.getenv("MEDVLLM_TARGET_REPO", "Junaidi-AI/med-vllm")
def _train_ner_lora(
    base_model: str,
    dataset_name: str,
    output_dir: str,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 8,
    learning_rate: float = 2e-5,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    log_cb=None,
) -> Dict[str, Any]:
    """
    Minimal LoRA token-classification trainer.

    Uses conll2003 by default to be robust in Spaces. Extend to medical datasets later.

    The function loads the dataset, wraps the base model with a LoRA adapter,
    trains it, saves adapter and tokenizer to ``output_dir``, uploads that
    folder to ``TARGET_REPO`` (as a pull request, under ``checkpoints/``) and
    then makes a best-effort upload to a dedicated variant repo.

    Args:
        base_model: Model id passed to ``from_pretrained``.
        dataset_name: Dataset id passed to ``load_dataset``.
        output_dir: Local folder for checkpoints; its basename is the run name.
        num_train_epochs: Number of training epochs.
        per_device_train_batch_size: Batch size, used for both train and eval.
        learning_rate: Learning rate given to ``TrainingArguments``.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout.
        log_cb: Optional callable receiving each progress message; messages
            are printed when it is not provided.

    Returns:
        A dict with ``commit`` (string form of the upload result),
        ``path_in_repo`` and ``metrics`` (the values from the most recent
        evaluation, empty if none ran).

    Raises:
        RuntimeError: If the dataset has no train split, no recognised
            token/tag columns, a tag column without label names, or no
            evaluation split.
    """
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForTokenClassification,
        DataCollatorForTokenClassification,
        TrainingArguments,
        Trainer,
    )
    from transformers.trainer_utils import set_seed
    from seqeval.metrics import f1_score, accuracy_score, precision_score, recall_score
    from peft import LoraConfig, get_peft_model, TaskType

    def log(msg: str):
        if log_cb:
            log_cb(msg)
        else:
            print(msg)

    set_seed(42)
    log(f"Loading dataset: {dataset_name}")
    ds = load_dataset(dataset_name)
    if "train" not in ds:
        raise RuntimeError("Dataset must have a train split")

    # Detect token and label columns across common schemas
    features = ds["train"].features
    token_candidates = ["tokens", "words"]
    tag_candidates = ["ner_tags", "tags", "labels", "ner_tags_general"]
    token_col = next((c for c in token_candidates if c in features), None)
    tag_col = next((c for c in tag_candidates if c in features), None)
    if not token_col or not tag_col:
        raise RuntimeError(
            "Dataset must provide token and tag columns. Looked for tokens/words and ner_tags/tags/labels."
        )
    try:
        label_list = features[tag_col].feature.names
    except AttributeError as exc:
        # Surface a readable message instead of a bare AttributeError.
        raise RuntimeError(
            f"Tag column '{tag_col}' does not expose label names; custom preprocessing is required."
        ) from exc
    id2label = {i: l for i, l in enumerate(label_list)}
    label2id = {l: i for i, l in enumerate(label_list)}

    log(f"Loading tokenizer/model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForTokenClassification.from_pretrained(
        base_model, num_labels=len(label_list), id2label=id2label, label2id=label2id
    )
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    model = get_peft_model(base, peft_config)

    # Tokenize with alignment
    def tokenize_align(batch):
        # Tokenize the batch once. Padding is left to the data collator so
        # sequences are only padded to the width of each training batch.
        enc = tokenizer(
            batch[token_col], is_split_into_words=True, truncation=True, padding=False
        )
        aligned_labels = []
        for idx, tags in enumerate(batch[tag_col]):
            lab = []
            prev_wid = None
            for wid in enc.word_ids(batch_index=idx):
                if wid is None:
                    # Positions without a source word are ignored by the loss.
                    lab.append(-100)
                elif wid != prev_wid:
                    # Only label first subword
                    lab.append(tags[wid])
                    prev_wid = wid
                else:
                    lab.append(-100)
            aligned_labels.append(lab)
        enc["labels"] = aligned_labels
        return enc

    log("Tokenizing dataset...")
    tokenized = ds.map(tokenize_align, batched=True)
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Filled by compute_metrics so the latest evaluation can be reported.
    metrics_holder: Dict[str, float] = {}

    def compute_metrics(p):
        preds, labels = p
        preds = preds.argmax(-1)
        true_predictions = []
        true_labels = []
        for pred, lab in zip(preds, labels):
            curr_pred = []
            curr_lab = []
            for p_i, l_i in zip(pred, lab):
                # Skip positions masked out with -100 during alignment.
                if l_i != -100:
                    curr_pred.append(id2label[int(p_i)])
                    curr_lab.append(id2label[int(l_i)])
            true_predictions.append(curr_pred)
            true_labels.append(curr_lab)
        out = {
            "f1": f1_score(true_labels, true_predictions),
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "accuracy": accuracy_score(true_labels, true_predictions),
        }
        metrics_holder.update(out)
        return out

    # Same fallback order as before (validation, dev, test), but a missing
    # split now produces a readable error instead of KeyError('test').
    eval_dataset = tokenized.get("validation") or tokenized.get("dev") or tokenized.get("test")
    if eval_dataset is None:
        raise RuntimeError("Dataset must have a validation, dev or test split for evaluation")

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_train_batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        report_to=[],
        fp16=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    log("Starting training...")
    trainer.train()

    log("Saving adapter...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Compose commit description with metrics
    desc_lines = [
        f"base_model: {base_model}",
        f"dataset: {dataset_name}",
        f"epochs: {num_train_epochs}",
        f"batch_size: {per_device_train_batch_size}",
        f"learning_rate: {learning_rate}",
        f"lora_r: {lora_r}",
        f"lora_alpha: {lora_alpha}",
        f"lora_dropout: {lora_dropout}",
        "",
        "metrics:",
        *(f"- {k}: {v:.4f}" for k, v in metrics_holder.items()),
    ]
    commit_description = "\n".join(desc_lines)

    # Push to the umbrella repo under checkpoints/
    api = HfApi()
    run_name = os.path.basename(output_dir.rstrip("/"))
    path_in_repo = f"checkpoints/ner-{run_name}"
    log(f"Pushing to {TARGET_REPO}:{path_in_repo}")
    commit = api.upload_folder(
        repo_id=TARGET_REPO,
        repo_type="model",
        folder_path=output_dir,
        path_in_repo=path_in_repo,
        commit_message=f"Add NER LoRA checkpoint ({run_name})",
        commit_description=commit_description,
        create_pr=True,
    )
    log(f"Pushed: {commit}")

    # Also publish to a dedicated med-vllm-* variant repo
    try:
        base_short = base_model.split("/")[-1].replace(" ", "-").lower()
        ds_short = dataset_name.split("/")[-1].replace(" ", "-").lower()
        variant_name = f"Junaidi-AI/med-vllm-ner-{ds_short}-{base_short}-lora-v1"
        log(f"Ensuring repo exists: {variant_name}")
        try:
            create_repo(repo_id=variant_name, repo_type="model", exist_ok=True, private=False)
        except Exception as e:
            # Still best-effort (the upload below is attempted regardless),
            # but the reason is now visible in the logs.
            log(f"Warning: could not ensure variant repo exists: {e}")
        commit2 = api.upload_folder(
            repo_id=variant_name,
            repo_type="model",
            folder_path=output_dir,
            path_in_repo=".",
            commit_message=f"Initial LoRA checkpoint from {base_model} on {dataset_name}",
            commit_description=commit_description,
            create_pr=False,
        )
        log(f"Variant published: {commit2}")
    except Exception as e:
        # The variant repo is optional; the main upload already succeeded.
        log(f"Warning: failed to publish variant repo: {e}")

    return {"commit": str(commit), "path_in_repo": path_in_repo, "metrics": metrics_holder}
class TrainerThread:
    """Runs one training job at a time on a background thread.

    The instance keeps the accumulated log text, the final result dict and
    the error message (if any) of the most recently started job.
    """

    def __init__(self):
        self.thread: Optional[threading.Thread] = None
        self.logs = ""
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None

    def _log(self, msg: str):
        """Append one line to the log buffer."""
        self.logs = f"{self.logs}{msg}\n"

    def start(self, **kwargs):
        """Launch training in a daemon thread, forwarding ``kwargs`` to the trainer.

        Raises:
            gr.Error: If a previously started job is still running.
        """
        current = self.thread
        if current is not None and current.is_alive():
            raise gr.Error("Training is already running")

        def _worker():
            try:
                self._log("Initializing training...")
                self.result = _train_ner_lora(log_cb=self._log, **kwargs)
                self._log("Training complete")
            except Exception as exc:
                # Keep the failure for status() instead of losing it with the thread.
                self.error = str(exc)
                self._log(f"ERROR: {exc}")

        # Clear state left over from the previous job before starting a new one.
        self.logs, self.result, self.error = "", None, None
        self.thread = threading.Thread(target=_worker, daemon=True)
        self.thread.start()

    def status(self):
        """Return ``(running, logs, result, error)`` for the current job."""
        is_running = self.thread is not None and self.thread.is_alive()
        return is_running, self.logs, self.result, self.error
# Single module-level trainer instance.
TRAINER = TrainerThread()
def build_ui():
    """Build the Gradio Blocks app for configuring and monitoring a training run.

    The page offers inputs for the base model, dataset, hyper-parameters and
    run name, a button that starts training through ``TRAINER`` and a button
    that refreshes the log and result boxes.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks(title="Med vLLM Train (LoRA NER)") as demo:
        # Text is kept flush-left so no source indentation ends up in the Markdown.
        gr.Markdown(
            f"""
# Med vLLM Train (LoRA NER)
This Space fine-tunes a token-classification model with LoRA.
- Base model default: `{DEFAULT_BASE_MODEL}`
- Dataset default: `{DEFAULT_DATASET}` (robust demo). Medical sets like `bc5cdr`/`ncbi_disease` may require custom preprocessing.
- Checkpoints will be pushed to `{TARGET_REPO}` under `checkpoints/` as a PR.
"""
        )
        with gr.Row():
            base_model = gr.Textbox(value=DEFAULT_BASE_MODEL, label="Base model")
            dataset_name = gr.Textbox(value=DEFAULT_DATASET, label="Dataset (token classification)")
        with gr.Row():
            epochs = gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Epochs")
            batch = gr.Slider(minimum=4, maximum=16, step=2, value=8, label="Batch size")
            lr = gr.Textbox(value="2e-5", label="Learning rate")
        with gr.Row():
            lora_r = gr.Slider(minimum=4, maximum=32, step=2, value=8, label="LoRA r")
            lora_alpha = gr.Slider(minimum=8, maximum=64, step=8, value=16, label="LoRA alpha")
            lora_dropout = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, value=0.1, label="LoRA dropout")
        with gr.Row():
            run_name = gr.Textbox(value=f"run-{int(time.time())}", label="Run name (folder)")
        with gr.Row():
            start_btn = gr.Button("Start Training")
            status_btn = gr.Button("Refresh Status")
        logs = gr.Textbox(label="Logs", lines=18)
        result = gr.Textbox(label="Result / Commit info")

        def on_start(bm, ds, ep, bs, lr_s, r, alpha, drop, rn):
            """Validate the form values and start a background training job."""
            try:
                # The run name comes from the web form and becomes a folder
                # under outputs/, so it must be a plain name, not a path.
                run = (rn or "").strip()
                if not run:
                    run = f"run-{int(time.time())}"
                if run in (".", "..") or "/" in run or "\\" in run:
                    raise ValueError(
                        "run name must be a single folder name without path separators"
                    )
                out_dir = os.path.join("outputs", run)
                os.makedirs(out_dir, exist_ok=True)
                TRAINER.start(
                    base_model=bm,
                    dataset_name=ds,
                    output_dir=out_dir,
                    num_train_epochs=int(ep),
                    per_device_train_batch_size=int(bs),
                    learning_rate=float(lr_s),
                    lora_r=int(r),
                    lora_alpha=int(alpha),
                    lora_dropout=float(drop),
                )
                return "Started"
            except Exception as e:
                # Report bad input or an already-running job in the log box.
                return f"ERROR starting: {e}"

        def on_status():
            """Return the current log text (prefixed with a state tag) and result."""
            running, l, res, err = TRAINER.status()
            info = "Running" if running else ("Error" if err else "Idle/Done")
            res_s = str(res) if res else ""
            return f"[{info}]\n" + l, res_s

        start_btn.click(
            on_start,
            inputs=[base_model, dataset_name, epochs, batch, lr, lora_r, lora_alpha, lora_dropout, run_name],
            outputs=[logs],
        )
        status_btn.click(on_status, outputs=[logs, result])
    return demo
# Script entry point: build the UI and serve it on all interfaces.
if __name__ == "__main__":
    ui = build_ui()
    # Port is taken from the PORT environment variable, defaulting to 7860.
    port = int(os.getenv("PORT", 7860))
    ui.launch(server_port=port, server_name="0.0.0.0")