Spaces:
Sleeping
Sleeping
# app.py ── Biomedical NER demo (full vs. LoRA/CRF)
# --------------------------------------------------
| from __future__ import annotations | |
| import html, logging, warnings | |
| from functools import lru_cache | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| from peft import PeftModel | |
| from huggingface_hub import hf_hub_download # β missing import added | |
# ─────────── silence library warnings ───────────
# Drop UserWarning messages whose originating module matches "peft".
warnings.filterwarnings(action="ignore", category=UserWarning, module="peft")
# Only ERROR and above get through from the transformers model-loading logger.
logging.getLogger(name="transformers.modeling_utils").setLevel(level=logging.ERROR)
# ─────────── constants ───────────
# Base encoder that every LoRA adapter is applied to.
BASE = "dmis-lab/biobert-base-cased-v1.2"
# Single Hub repo; each variant lives in its own sub-folder.
REPO = "vishalvaka/biobert-finetuned-ner"

# UI label -> (sub-folder inside REPO, loading mode understood by load_model).
VARIANTS: dict[str, tuple[str, str]] = {
    "Full fine-tune": ("full", "full"),
    "LoRA-r32": ("lora-r32", "lora"),
    "LoRA-r32-fast": ("lora-r32-fast", "lora"),
    "LoRA-r16-CRF": ("lora-r16-crf", "lora"),
    "LoRA-r16-CRF-long": ("lora-r16-crf-long", "lora"),
}

# Tag set; a label's position in this list is its class id.
LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in enumerate(LABELS)}
# ─────────── model loader (cached) ───────────
@lru_cache(maxsize=None)
def load_model(folder: str, mode: str):
    """Return ``(tokenizer, model)`` for one variant, loading it at most once.

    Results are memoised per ``(folder, mode)`` so switching variants in the
    UI does not reload weights on every request. The cache is unbounded, but
    callers in this file only pass the entries of ``VARIANTS``.

    Args:
        folder: Sub-folder of ``REPO`` that holds the variant's files.
        mode: ``"full"`` loads a complete checkpoint from the sub-folder;
            any other value loads ``BASE`` and applies the LoRA adapter
            stored in the sub-folder.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    if mode == "full":
        model = AutoModelForTokenClassification.from_pretrained(
            REPO, subfolder=folder
        )
        tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
        # Ensure human-readable label maps regardless of what the checkpoint stored.
        model.config.id2label, model.config.label2id = id2label, label2id
        return tok, model.eval()

    # ---------- LoRA (with or without CRF) ----------
    base = AutoModelForTokenClassification.from_pretrained(
        BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id
    )
    model = PeftModel.from_pretrained(base, REPO, subfolder=folder)
    tok = AutoTokenizer.from_pretrained(BASE)

    # Only the CRF variants ship an extra head file; loading it is best-effort,
    # so any failure is downgraded to a warning and the adapter is used as-is.
    if "crf" in folder:
        try:
            head_path = hf_hub_download(
                REPO,
                "non_encoder_head.pth",
                subfolder=folder,
                repo_type="model",
            )
            # NOTE(review): torch.load unpickles the file; only point REPO at a
            # trusted repository (consider weights_only=True where supported).
            extra = torch.load(head_path, map_location="cpu")
            # strict=False: keys that do not match the model are skipped silently.
            model.load_state_dict(extra, strict=False)
        except Exception as e:
            warnings.warn(f"[{folder}] couldn’t load CRF head: {e}")
    return tok, model.eval()
# ─────────── helper to build HTML output ───────────
def build_html(tokens: list[str], labels: list[str]) -> str:
    """Render tagged WordPiece tokens as HTML with highlighted entity spans.

    Consecutive tokens are merged into one segment when they carry the same
    entity type and the token is either a ``##`` WordPiece continuation or has
    an ``I-`` label. Inside a segment, ``##`` pieces are glued to the previous
    text and whole words are separated by a single space, so a multi-word
    entity keeps its word boundaries.

    Args:
        tokens: WordPiece tokens (special tokens already removed).
        labels: One BIO label per token, e.g. ``"O"`` or ``"B-Disease"``.
            Extra items in the longer of the two lists are ignored.

    Returns:
        HTML-escaped text in which every entity segment is wrapped in
        ``<span class="<entity type>">``. Segments are separated by a space,
        except before a segment that starts with a ``##`` piece (it continues
        the previous word). Returns ``""`` for empty input.
    """
    # (entity_tag, text, glued) -- glued means "no space before this segment".
    segments: list[tuple[str | None, str, bool]] = []
    cur_tag: str | None = None
    buf = ""
    buf_glued = False

    for tok, lab in zip(tokens, labels):
        is_piece = tok.startswith("##")
        tag = None if lab == "O" else lab.split("-")[-1]  # Chemical / Disease
        text = tok[2:] if is_piece else tok  # drop the ## marker

        if tag == cur_tag and (is_piece or lab.startswith("I-")):
            # Same segment: glue sub-word pieces, space-separate whole words.
            buf += text if (is_piece or not buf) else " " + text
            continue

        if buf:
            segments.append((cur_tag, buf, buf_glued))
        buf, cur_tag, buf_glued = text, tag, is_piece

    if buf:
        segments.append((cur_tag, buf, buf_glued))

    rendered: list[str] = []
    for tag, chunk, glued in segments:
        if rendered and not glued:
            rendered.append(" ")
        chunk = html.escape(chunk)
        if tag is None:
            rendered.append(chunk)
        else:
            # The tag comes from model labels; escape it so it cannot break
            # out of the attribute value.
            rendered.append(f'<span class="{html.escape(tag)}">{chunk}</span>')
    return "".join(rendered)
# ─────────── Gradio inference fn ───────────
def ner(text: str, variant: str):
    """Tag ``text`` with the selected model variant and return highlighted HTML.

    Args:
        text: Raw input from the textbox. Empty, whitespace-only or ``None``
            input yields an empty string without loading a model.
        variant: A key of ``VARIANTS`` chosen in the UI.

    Returns:
        The HTML string produced by ``build_html`` for the predicted labels.

    Raises:
        KeyError: If ``variant`` is not a key of ``VARIANTS`` (non-empty text only).
    """
    if not text or not text.strip():
        return ""

    folder, mode = VARIANTS[variant]
    tok, model = load_model(folder, mode)

    # The encoder has a fixed number of position embeddings; without truncation
    # a long paste would overrun them and fail inside the forward pass.
    max_len = getattr(model.config, "max_position_embeddings", 512)
    enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len)

    with torch.no_grad():
        logits = model(**enc).logits.squeeze(0)

    # Slice off the first and last positions (the CLS/SEP special tokens).
    ids = logits.argmax(dim=-1).tolist()[1:-1]
    tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
    labels = [model.config.id2label[i] for i in ids]
    return build_html(tokens, labels)
# ─────────── UI definition ───────────
# Highlight colours for the <span class="..."> elements emitted by build_html.
CSS = """
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
"""

# Two inputs (free text + variant selector) mapped onto ner(); output is raw HTML.
demo = gr.Interface(
    fn=ner,
    inputs=[
        gr.Textbox(lines=7, label="Paste biomedical text"),
        gr.Radio(list(VARIANTS.keys()), value="LoRA-r32", label="Model variant"),
    ],
    outputs=gr.HTML(label="Tagged output"),
    examples=[
        ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn's disease patients."],
    ],
    css=CSS,
    theme=gr.themes.Soft(),  # quiet built-in theme, no 404
    cache_examples=False,
    # The dash in the title was mis-encoded in the source; restored as an em dash.
    title="Biomedical NER — Full vs. LoRA / CRF",
    description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
    "Full checkpoints for CRF models, compact adapters for LoRA runs.",
)

if __name__ == "__main__":
    demo.launch()