# app.py — Biomedical NER demo (full vs. LoRA/CRF)
# --------------------------------------------------
from __future__ import annotations
import html, logging, warnings
from functools import lru_cache
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download # ← missing import added
# ─────────── silence library warnings ───────────
warnings.filterwarnings("ignore", category=UserWarning, module="peft")
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# ─────────── constants ───────────
BASE = "dmis-lab/biobert-base-cased-v1.2"
REPO = "vishalvaka/biobert-finetuned-ner" # one repo; variants in sub-folders
VARIANTS: dict[str, tuple[str, str]] = {
"Full fine-tune" : ("full", "full"),
"LoRA-r32" : ("lora-r32", "lora"),
"LoRA-r32-fast" : ("lora-r32-fast", "lora"),
"LoRA-r16-CRF" : ("lora-r16-crf", "lora"),
"LoRA-r16-CRF-long" : ("lora-r16-crf-long", "lora"),
}
LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
id2label = {i: lab for i, lab in enumerate(LABELS)}
label2id = {lab: i for i, lab in id2label.items()}
# ─────────── model loader (cached) ───────────
@lru_cache(maxsize=None)
def load_model(folder: str, mode: str):
    """
    Return a (tokenizer, model) pair for one variant, cached per (folder, mode).

    Parameters
    ----------
    folder : sub-folder inside REPO that holds the checkpoint / adapter.
    mode   : "full" → load a complete fine-tuned checkpoint from the sub-folder;
             "lora" → load BASE weights plus a LoRA adapter from the sub-folder.

    The model is returned in eval() mode; the cache keeps at most one copy
    of each variant alive for the life of the process.
    """
    if mode == "full":
        model = AutoModelForTokenClassification.from_pretrained(
            REPO, subfolder=folder
        )
        tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
        # ensure human-readable label maps regardless of what the checkpoint stored
        model.config.id2label, model.config.label2id = id2label, label2id
        return tok, model.eval()

    # ---------- LoRA (with or without CRF) ----------
    base = AutoModelForTokenClassification.from_pretrained(
        BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id
    )
    model = PeftModel.from_pretrained(base, REPO, subfolder=folder)
    tok = AutoTokenizer.from_pretrained(BASE)

    # attach CRF/classifier weights **iff the file exists**
    try:
        if "crf" in folder:  # only the CRF variants ship an extra head
            head_path = hf_hub_download(
                REPO,
                "non_encoder_head.pth",
                subfolder=folder,
                repo_type="model",
            )
            # weights_only=True: refuse arbitrary pickled objects in the
            # downloaded file — we only expect a plain state_dict
            # (requires torch >= 1.13).
            extra = torch.load(head_path, map_location="cpu", weights_only=True)
            model.load_state_dict(extra, strict=False)
    except Exception as e:
        # best-effort: a missing/incompatible head degrades output quality
        # but should not take the whole demo down
        warnings.warn(f"[{folder}] couldn't load CRF head: {e}")
    return tok, model.eval()
# ─────────── helper to build HTML output ───────────
def build_html(tokens: list[str], labels: list[str]) -> str:
    """
    Render (token, label) pairs as HTML with entity <span> highlights.

    WordPieces ("##…" tokens) are glued to the previous token with no gap;
    consecutive I- tokens of the same entity are merged into one span but
    separated by a space, since they are distinct words.  (The original
    concatenated whole words with no space, e.g. "lung cancer" → "lungcancer".)
    All text is HTML-escaped before being emitted.
    """
    segments: list[tuple[str | None, str]] = []  # (entity_tag or None, text)
    cur_tag, buf = None, ""
    for tok, lab in zip(tokens, labels):
        tag = None if lab == "O" else lab.split("-")[-1]  # "Chemical" / "Disease"
        text = tok[2:] if tok.startswith("##") else tok   # strip "##" prefix
        continuation = tok.startswith("##") or lab.startswith("I-")
        if tag == cur_tag and continuation:
            # WordPieces join seamlessly; whole-word I- tokens need a space.
            buf += text if tok.startswith("##") else " " + text
        else:
            if buf:
                segments.append((cur_tag, buf))
            buf, cur_tag = text, tag
    if buf:
        segments.append((cur_tag, buf))

    html_out, first = "", True
    for tag, chunk in segments:
        spacer = "" if first else " "
        first = False
        chunk = html.escape(chunk)
        html_out += spacer + (
            chunk if tag is None else f'<span class="{tag}">{chunk}</span>'
        )
    return html_out
# ─────────── Gradio inference fn ───────────
def ner(text: str, variant: str):
    """
    Run one model variant over *text* and return entity-highlighted HTML.

    Parameters
    ----------
    text    : raw input string from the textbox.
    variant : key into VARIANTS selecting which checkpoint/adapter to use.
    """
    folder, mode = VARIANTS[variant]
    tok, model = load_model(folder, mode)
    # truncation=True: BERT-style encoders fail on inputs beyond their
    # positional limit (512 tokens), so clip long pastes instead of crashing.
    enc = tok(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**enc).logits.squeeze(0)
    # [1:-1] drops the [CLS]/[SEP] special tokens added by the tokenizer.
    ids = logits.argmax(dim=-1).tolist()[1:-1]
    tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
    labels = [model.config.id2label[i] for i in ids]
    return build_html(tokens, labels)
# ─────────── UI definition ───────────
# Inline CSS for the entity <span> highlights emitted by build_html().
CSS = """
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
"""

demo = gr.Interface(
    fn=ner,
    inputs=[
        gr.Textbox(lines=7, label="Paste biomedical text"),
        gr.Radio(list(VARIANTS.keys()), value="LoRA-r32", label="Model variant"),
    ],
    outputs=gr.HTML(label="Tagged output"),
    examples=[
        ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn's disease patients."],
    ],
    css=CSS,
    theme=gr.themes.Soft(),  # quiet built-in theme, no 404
    cache_examples=False,    # examples require model downloads; don't pre-run them
    # mojibake fixed: em-dash in the user-facing title
    title="Biomedical NER — Full vs. LoRA / CRF",
    description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
                "Full checkpoints for CRF models, compact adapters for LoRA runs.",
)

if __name__ == "__main__":
    demo.launch()