# bio-ner-demo / app.py
# vishalvaka's picture
# made some fixes for CRF models
# aaf7fac
# raw
# history blame
# 5.72 kB
# app.py ── Biomedical NER demo (full vs. LoRA/CRF)
# --------------------------------------------------
from __future__ import annotations
import html, logging, warnings
from functools import lru_cache
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download # ← missing import added
# ─────────── silence library warnings ───────────
# Only ERROR-and-above records from transformers' model-loading logger get through.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# Drop UserWarning messages raised from modules whose name starts with "peft".
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="peft",
)
# ─────────── constants ───────────
# Backbone checkpoint that the adapter variants are applied on top of.
BASE = "dmis-lab/biobert-base-cased-v1.2"
# One hub repo holds every variant, each in its own sub-folder.
REPO = "vishalvaka/biobert-finetuned-ner"

# UI label -> (sub-folder inside REPO, loading mode passed to load_model).
VARIANTS: dict[str, tuple[str, str]] = {
    "Full fine-tune": ("full", "full"),
    "LoRA-r32": ("lora-r32", "lora"),
    "LoRA-r32-fast": ("lora-r32-fast", "lora"),
    "LoRA-r16-CRF": ("lora-r16-crf", "lora"),
    "LoRA-r16-CRF-long": ("lora-r16-crf-long", "lora"),
}

# BIO tag set; a label's position in this list is its class id.
LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in enumerate(LABELS)}
# ─────────── model loader (cached) ───────────
@lru_cache(maxsize=None)
def load_model(folder: str, mode: str):
    """Build the (tokenizer, model) pair for one variant; memoised per argument pair.

    Args:
        folder: Sub-folder of ``REPO`` that holds the variant's files.
        mode: ``"full"`` loads the complete checkpoint and its tokenizer from
            the sub-folder. Any other value takes the adapter path: ``BASE``
            is instantiated with the NER label set, the LoRA adapter from the
            sub-folder is applied on top, and the tokenizer comes from ``BASE``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    if mode == "full":
        full_model = AutoModelForTokenClassification.from_pretrained(
            REPO,
            subfolder=folder,
        )
        full_tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
        # Overwrite the checkpoint's label maps with the human-readable ones.
        full_model.config.id2label = id2label
        full_model.config.label2id = label2id
        return full_tokenizer, full_model.eval()

    # ---------- LoRA (with or without CRF) ----------
    backbone = AutoModelForTokenClassification.from_pretrained(
        BASE,
        num_labels=len(LABELS),
        id2label=id2label,
        label2id=label2id,
    )
    adapted = PeftModel.from_pretrained(backbone, REPO, subfolder=folder)
    tokenizer = AutoTokenizer.from_pretrained(BASE)

    # Best effort: the extra head weights are only fetched for folders whose
    # name contains "crf"; any failure is downgraded to a warning so the
    # adapter model is still returned.
    try:
        if "crf" in folder:
            head_file = hf_hub_download(
                repo_id=REPO,
                filename="non_encoder_head.pth",
                subfolder=folder,
                repo_type="model",
            )
            head_state = torch.load(head_file, map_location="cpu")
            # strict=False: keys without a counterpart in the model are skipped.
            adapted.load_state_dict(head_state, strict=False)
    except Exception as e:
        warnings.warn(f"[{folder}] couldn’t load CRF head: {e}")

    return tokenizer, adapted.eval()
# ─────────── helper to build HTML output ───────────
def build_html(tokens: list[str], labels: list[str]) -> str:
    """Render tagged tokens as HTML with one ``<span>`` per entity mention.

    WordPiece continuations (``##`` prefix) are glued onto the preceding
    piece, and consecutive tokens that continue the same entity are merged
    into a single segment. Whole-word tokens merged into an entity are joined
    with a space so multi-word mentions stay readable ("heart attack", not
    "heartattack").

    Args:
        tokens: WordPiece tokens; continuation pieces start with ``##``.
        labels: BIO labels aligned one-to-one with ``tokens`` (``"O"`` or
            ``"B-"``/``"I-"`` followed by the entity type).

    Returns:
        HTML-escaped text. Entity segments are wrapped in
        ``<span class="<entity type>">``; segments are separated by a single
        space. Empty input yields an empty string.
    """
    segments: list[tuple[str | None, str]] = []  # (entity_tag, text)
    cur_tag: str | None = None
    buf = ""
    for tok, lab in zip(tokens, labels):
        tag = None if lab == "O" else lab.split("-")[-1]  # Chemical / Disease
        is_piece = tok.startswith("##")
        text = tok[2:] if is_piece else tok  # drop the ## marker
        if tag == cur_tag and (is_piece or lab.startswith("I-")):
            # A WordPiece continues the current word; a whole-word I- token
            # is a new word of the same mention and needs a separating space.
            buf += text if is_piece or not buf else f" {text}"
        else:
            if buf:
                segments.append((cur_tag, buf))
            buf, cur_tag = text, tag
    if buf:
        segments.append((cur_tag, buf))

    rendered = []
    for tag, chunk in segments:
        safe = html.escape(chunk)
        rendered.append(safe if tag is None else f'<span class="{tag}">{safe}</span>')
    return " ".join(rendered)
# ─────────── Gradio inference fn ───────────
def ner(text: str, variant: str):
    """Tag ``text`` with the selected model variant and return highlighted HTML.

    Args:
        text: Raw text pasted by the user.
        variant: A key of ``VARIANTS`` naming the model to run.

    Returns:
        The HTML string produced by ``build_html`` for the predicted labels;
        an empty string when ``text`` is empty or whitespace only.

    Raises:
        KeyError: If ``variant`` is not a key of ``VARIANTS``.
    """
    folder, mode = VARIANTS[variant]
    if not text or not text.strip():
        # Nothing to tag: same result as running the model on zero tokens,
        # without loading a model first.
        return ""

    tok, model = load_model(folder, mode)
    # Truncate to the encoder's position limit; without this, text longer
    # than the limit makes the forward pass fail instead of returning output.
    max_len = getattr(model.config, "max_position_embeddings", 512)
    enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len)
    with torch.no_grad():
        logits = model(**enc).logits.squeeze(0)
    ids = logits.argmax(dim=-1).tolist()[1:-1]  # drop CLS/SEP
    tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
    labels = [model.config.id2label[i] for i in ids]
    return build_html(tokens, labels)
# ─────────── UI definition ───────────
# Styles for the <span class="Chemical"> / <span class="Disease"> elements
# that build_html emits.
CSS = """
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
"""

# Two inputs (free text + variant selector) feed ner(); its HTML is shown as-is.
demo = gr.Interface(
    fn=ner,
    inputs=[
        gr.Textbox(lines=7, label="Paste biomedical text"),
        gr.Radio(list(VARIANTS.keys()), value="LoRA-r32", label="Model variant"),
    ],
    outputs=gr.HTML(label="Tagged output"),
    examples=[
        ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn’s disease patients."],
    ],
    css=CSS,
    theme=gr.themes.Soft(),  # quiet built-in theme, no 404
    # The title previously contained a mis-encoded em dash ("β€”").
    title="Biomedical NER — Full vs. LoRA / CRF",
    description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
    "Full checkpoints for CRF models, compact adapters for LoRA runs.",
)

if __name__ == "__main__":
    demo.launch()