File size: 5,741 Bytes
aaf7fac
 
 
 
 
 
 
 
e5e9118
 
aaf7fac
 
 
 
 
e5e9118
aaf7fac
e5e9118
aaf7fac
 
 
 
 
 
 
 
e5e9118
 
aaf7fac
 
 
 
 
 
 
e5e9118
aaf7fac
 
 
 
e5e9118
 
 
 
 
 
aaf7fac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e9118
 
aaf7fac
 
 
 
 
e5e9118
aaf7fac
 
 
 
 
 
 
e5e9118
aaf7fac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e9118
aaf7fac
e5e9118
 
 
 
 
 
 
 
 
 
 
 
 
45af699
e5e9118
 
aaf7fac
45af699
aaf7fac
 
 
e5e9118
 
 
45af699
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# app.py  ── Biomedical NER demo (full vs. LoRA/CRF)
# --------------------------------------------------
from __future__ import annotations
import html, logging, warnings
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download      # ← missing import added

# ─────────── silence library warnings ───────────
# Hide peft UserWarnings and anything below ERROR from transformers.modeling_utils.
warnings.filterwarnings("ignore", category=UserWarning, module="peft")
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# ─────────── constants ───────────
# Encoder that the LoRA adapters are loaded on top of (see load_model).
BASE = "dmis-lab/biobert-base-cased-v1.2"
# Single Hub repo; each variant lives in its own sub-folder.
REPO = "vishalvaka/biobert-finetuned-ner"

# UI label -> (sub-folder inside REPO, loading mode: "full" or "lora").
# Insertion order is the order shown in the radio selector.
VARIANTS: dict[str, tuple[str, str]] = {
    "Full fine-tune": ("full", "full"),
    "LoRA-r32": ("lora-r32", "lora"),
    "LoRA-r32-fast": ("lora-r32-fast", "lora"),
    "LoRA-r16-CRF": ("lora-r16-crf", "lora"),
    "LoRA-r16-CRF-long": ("lora-r16-crf-long", "lora"),
}

# BIO tag set; list position == class index of the classification head.
LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in enumerate(LABELS)}

# ─────────── model loader (cached) ───────────
@lru_cache(maxsize=None)
def load_model(folder: str, mode: str):
    """
    Return ``(tokenizer, model)`` for one variant, cached per ``(folder, mode)``.

    - mode == "full"  -> load the full checkpoint and its tokenizer from
      ``REPO/<folder>``, then overwrite the config label maps with ``LABELS``.
    - any other mode  -> load ``BASE`` with a ``len(LABELS)``-way token
      classification head, apply the LoRA adapter from ``REPO/<folder>`` and
      use the ``BASE`` tokenizer. For folders whose name contains "crf" it
      additionally tries to load ``non_encoder_head.pth`` from the same
      sub-folder. This is best effort: failures and ignored tensors are
      reported with ``warnings.warn`` and the adapter-only model is returned.

    The returned model is in eval mode. The cache is unbounded; in this file
    it is only called with the entries of ``VARIANTS``.
    """
    if mode == "full":
        model = AutoModelForTokenClassification.from_pretrained(
            REPO, subfolder=folder
        )
        tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
        # ensure human-readable label maps regardless of what the checkpoint shipped
        model.config.id2label, model.config.label2id = id2label, label2id
        return tok, model.eval()

    # ---------- LoRA (with or without extra head file) ----------
    base = AutoModelForTokenClassification.from_pretrained(
        BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id
    )
    model = PeftModel.from_pretrained(base, REPO, subfolder=folder)
    tok = AutoTokenizer.from_pretrained(BASE)

    # NOTE(review): no CRF layer is constructed here. Any CRF-specific tensors
    # in the head file will not match a model parameter and are reported below.
    if "crf" in folder:                       # only these sub-folders ship the file
        try:
            head_path = hf_hub_download(
                REPO, "non_encoder_head.pth", subfolder=folder, repo_type="model"
            )
            # NOTE(security): torch.load unpickles the file. This is acceptable
            # only because REPO is a fixed, trusted repository; prefer
            # weights_only=True where the installed torch supports it.
            extra = torch.load(head_path, map_location="cpu")
            result = model.load_state_dict(extra, strict=False)
        except Exception as e:  # deliberate best effort: keep adapter-only model
            warnings.warn(f"[{folder}] couldn’t load CRF head: {e}")
        else:
            # strict=False silently skips keys the model does not have; surface
            # that instead of pretending the head was attached.
            ignored = list(result.unexpected_keys)
            if ignored:
                warnings.warn(
                    f"[{folder}] {len(ignored)} head tensor(s) matched no model "
                    f"parameter and were ignored: {ignored[:5]}"
                )

    return tok, model.eval()

# ─────────── helper to build HTML output ───────────
def build_html(tokens: list[str], labels: list[str]) -> str:
    """
    Render WordPiece tokens and their BIO labels as an HTML fragment.

    Merging rules:
    - A "##" piece is glued onto the preceding piece with no space, so a word
      is never split. The word keeps the label of its first piece.
    - A non-"##" token labelled ``I-<type>`` that directly follows a segment
      of the same ``<type>`` extends that segment, separated by one space
      (multi-word entities such as "Crohn disease").
    - Every other token starts a new segment. Segments are joined by spaces.

    Entity segments are wrapped in ``<span class="<type>">``. All text is
    HTML-escaped. Original whitespace is not recovered: tokens are re-joined
    with single spaces, so punctuation appears spaced out.

    Args:
        tokens: WordPiece tokens with special tokens already removed.
        labels: one BIO label per token ("O", "B-<type>", "I-<type>").
            Extra items in the longer list are ignored (``zip``).

    Returns:
        The HTML string ("" for empty input).
    """
    segments: list[tuple[str | None, str]] = []   # (entity type or None, text)
    cur_tag: str | None = None
    buf = ""

    for tok, lab in zip(tokens, labels):
        is_piece = tok.startswith("##")
        text = tok[2:] if is_piece else tok            # drop ##
        if is_piece and buf:
            # sub-word continuation: stays in the current segment, no space
            buf += text
            continue

        tag = None if lab == "O" else lab.split("-", 1)[-1]   # Chemical / Disease
        if buf and tag is not None and tag == cur_tag and lab.startswith("I-"):
            buf += " " + text                          # next word of the same entity
            continue

        if buf:
            segments.append((cur_tag, buf))
        buf, cur_tag = text, tag
    if buf:
        segments.append((cur_tag, buf))

    return " ".join(
        html.escape(chunk) if tag is None
        else f'<span class="{html.escape(tag)}">{html.escape(chunk)}</span>'
        for tag, chunk in segments
    )

# ─────────── Gradio inference fn ───────────
def ner(text: str, variant: str):
    """
    Tag *text* with the model chosen by *variant* and return highlighted HTML.

    Args:
        text: raw input from the Gradio textbox (``None`` is treated as "").
        variant: a key of ``VARIANTS``.

    Returns:
        The HTML fragment produced by ``build_html``.

    Inputs longer than the encoder can embed are truncated, so the tail is
    left untagged instead of raising inside the model.
    """
    folder, mode = VARIANTS[variant]
    tok, model = load_model(folder, mode)

    # Position embeddings cap the sequence length; tokenizers without a
    # configured limit report a huge model_max_length, so take the minimum.
    max_len = min(
        getattr(model.config, "max_position_embeddings", 512), tok.model_max_length
    )
    enc = tok(text or "", return_tensors="pt", truncation=True, max_length=max_len)
    with torch.no_grad():
        logits = model(**enc).logits.squeeze(0)          # drop the batch dim (size 1)

    # NOTE(review): plain per-token argmax; no CRF/Viterbi decoding happens
    # here, even for the "*-CRF" variants.
    pred_ids = logits.argmax(dim=-1).tolist()[1:-1]      # drop CLS/SEP
    tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
    labels = [model.config.id2label[i] for i in pred_ids]

    return build_html(tokens, labels)

# ─────────── UI definition ───────────
# Highlight colours for the <span class="..."> elements emitted by build_html.
CSS = """
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
span.Disease  {background:#ffdddd; padding:2px 4px; border-radius:4px}
"""

demo = gr.Interface(
    fn=ner,
    inputs=[
        gr.Textbox(lines=7, label="Paste biomedical text"),
        gr.Radio(list(VARIANTS.keys()), value="LoRA-r32", label="Model variant"),
    ],
    outputs=gr.HTML(label="Tagged output"),
    examples=[
        # one value per input component: (text, variant)
        ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn's disease patients.",
         "LoRA-r32"],
    ],
    css=CSS,
    theme=gr.themes.Soft(),    # quiet built-in theme, no 404
    cache_examples=False,
    title="Biomedical NER — Full vs. LoRA / CRF",
    description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
                "The full fine-tune loads a complete checkpoint; the LoRA variants "
                "(including the CRF runs) load compact adapters on top of BioBERT.",
)

if __name__ == "__main__":
    demo.launch()