Spaces:
Running
Running
File size: 6,740 Bytes
c59578d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | """Model loading + sliding-window NER inference.
Kept free of any Streamlit imports so it can be unit-tested / reused.
The Streamlit pages wrap `load_model()` in `st.cache_resource`.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import config
@dataclass
class LoadedModel:
    """Everything needed to run inference with one loaded NER model."""

    # Tokenizer that is called on the raw text at inference time.
    tokenizer: object
    # Token-classification model whose logits are arg-maxed into labels.
    model: object
    # Maps a predicted class id to its label string.
    id2label: dict
    # Human-readable description of where the weights came from.
    source: str
    # True => random head, predictions are meaningless.
    is_fallback: bool
    # Device string the inputs are moved to before the forward pass.
    device: str = "cpu"
def _ensure_label_scheme(model):
"""If the loaded model lacks our entity labels, overwrite its id2label."""
cfg_labels = set(getattr(model.config, "id2label", {}).values())
if "B-SKILL" not in cfg_labels:
model.config.id2label = dict(config.ID2LABEL)
model.config.label2id = dict(config.LABEL2ID)
return {int(k): v for k, v in model.config.id2label.items()}
# Sentinel default for `load_model(ref=...)`. A distinct string is used rather
# than None because `ref=None` already means "go straight to the fallback".
_USE_CONFIG = "__use_config__"
def _load_local(path: str, device: str) -> LoadedModel:
    """Load a tokenizer and token-classification model from a local folder.

    The label scheme is normalised with ``_ensure_label_scheme``; the model is
    then moved to ``device`` and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    ner_model = AutoModelForTokenClassification.from_pretrained(path)
    label_map = _ensure_label_scheme(ner_model)
    ner_model.to(device).eval()
    return LoadedModel(
        tokenizer=tokenizer,
        model=ner_model,
        id2label=label_map,
        source=f"Local folder: {path}",
        is_fallback=False,
        device=device,
    )
def _load_hub(model_id: str, device: str) -> LoadedModel:
    """Load a tokenizer and token-classification model by Hub repo id.

    Any exception raised by ``from_pretrained`` propagates to the caller.
    On success the label scheme is normalised with ``_ensure_label_scheme``
    and the model is moved to ``device`` in eval mode.
    """
    hub_tokenizer = AutoTokenizer.from_pretrained(model_id)
    hub_model = AutoModelForTokenClassification.from_pretrained(model_id)
    label_map = _ensure_label_scheme(hub_model)
    hub_model.to(device).eval()
    description = f"Hugging Face Hub: {model_id}"
    return LoadedModel(
        hub_tokenizer,
        hub_model,
        label_map,
        source=description,
        is_fallback=False,
        device=device,
    )
def _load_fallback(device: str) -> LoadedModel:
    """Load the demo fallback: ``config.FALLBACK_MODEL`` with an untrained head.

    The classification head is sized and labelled from the project's label
    config, so the pipeline runs end to end, but the result is flagged with
    ``is_fallback=True`` because its predictions are meaningless.
    """
    base_name = config.FALLBACK_MODEL
    id2label = dict(config.ID2LABEL)
    tokenizer = AutoTokenizer.from_pretrained(base_name, add_prefix_space=True)
    base_model = AutoModelForTokenClassification.from_pretrained(
        base_name,
        num_labels=len(config.LABELS),
        id2label=id2label,
        label2id=dict(config.LABEL2ID),
    )
    base_model.to(device).eval()
    return LoadedModel(
        tokenizer,
        base_model,
        dict(config.ID2LABEL),
        source=f"Fallback base model: {config.FALLBACK_MODEL} (untrained head)",
        is_fallback=True,
        device=device,
    )
def load_model(ref: str | None = _USE_CONFIG) -> LoadedModel:
    """Load a NER model from a ref.

    - ``ref=_USE_CONFIG`` (default): resolve per config.py priority
      (local MODEL_PATH -> MODEL_ID -> fallback). Keeps old callers working.
    - ``ref`` is a local directory: load that exported folder.
    - ``ref`` is any other non-empty string: treat as a Hugging Face Hub repo id.
    - ``ref=None`` (or empty): go straight to the demo fallback model.

    A local folder that's missing, or a Hub repo that can't be loaded
    (404 / private / offline), degrades gracefully to the demo fallback.
    The Hub failure is logged as a warning with its traceback so the reason
    for the fallback is not lost.

    Returns:
        A ``LoadedModel``; ``is_fallback`` is True when the demo model was used.
    """
    import logging  # function-scope import keeps this change self-contained

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if ref == _USE_CONFIG:
        # Configured local export wins over the configured Hub id.
        if config.MODEL_PATH and os.path.isdir(config.MODEL_PATH):
            return _load_local(config.MODEL_PATH, device)
        if not config.MODEL_ID:
            return _load_fallback(device)
        ref = config.MODEL_ID

    if not ref:
        return _load_fallback(device)
    if os.path.isdir(ref):
        return _load_local(ref, device)

    # Anything that is not an existing directory is tried as a Hub repo id.
    try:
        return _load_hub(ref, device)
    except Exception:  # noqa: BLE001 - missing/private/offline repo -> demo
        logging.getLogger(__name__).warning(
            "Could not load model %r; falling back to the demo model.",
            ref,
            exc_info=True,
        )
        return _load_fallback(device)
@torch.no_grad()
def predict(text: str, lm: LoadedModel):
    """Tag `text` with the loaded model, windowing long inputs.

    The tokenizer is asked for overflowing windows (``max_length=
    config.MAX_LENGTH``, ``stride=config.STRIDE``) and every window is
    classified in a single forward pass.

    Returns (tokens, entities):
        tokens   = [{"text", "label", "type", "start", "end"}] one per sub-word
        entities = [{"text", "type", "start", "end"}] merged BIO spans
    Both lists are empty when `text` is None, empty or whitespace-only.
    """
    text = text or ""
    if not text.strip():
        return [], []

    encoded = lm.tokenizer(
        text,
        max_length=config.MAX_LENGTH,
        truncation=True,
        stride=config.STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,
        return_tensors="pt",
    )
    offset_map = encoded["offset_mapping"]
    mask = encoded["attention_mask"]
    logits = lm.model(
        input_ids=encoded["input_ids"].to(lm.device),
        attention_mask=mask.to(lm.device),
    ).logits
    pred_ids = logits.argmax(-1).cpu()

    # Collapse the overlap between windows: tokens are keyed by their start
    # offset in `text`, and the first window to produce an offset wins.
    by_start: dict[int, tuple] = {}
    windows = zip(offset_map.tolist(), mask.tolist(), pred_ids.tolist())
    for win_offsets, win_mask, win_preds in windows:
        for (start, end), attended, pred_id in zip(win_offsets, win_mask, win_preds):
            is_special = start == 0 and end == 0
            if is_special or attended == 0:
                continue  # special token or padding
            by_start.setdefault(start, (start, end, pred_id))

    tokens = []
    for start in sorted(by_start):
        _, end, pred_id = by_start[start]
        label = lm.id2label.get(pred_id, "O")
        _, dash, suffix = label.partition("-")
        tokens.append({
            "text": text[start:end],
            "label": label,
            "type": suffix if dash else None,  # "O" has no entity type
            "start": start,
            "end": end,
        })

    return tokens, _merge_bio(tokens, text)
def _merge_bio(tokens, text):
"""Merge consecutive B-/I- tokens of the same type into entity spans."""
entities = []
cur = None
for t in tokens:
label = t["label"]
if "-" not in label: # "O"
if cur:
entities.append(cur)
cur = None
continue
prefix, etype = label.split("-", 1)
if prefix == "B" or cur is None or cur["type"] != etype:
if cur:
entities.append(cur)
cur = {"type": etype, "start": t["start"], "end": t["end"]}
else: # I- continuing the same type
cur["end"] = t["end"]
if cur:
entities.append(cur)
for e in entities:
e["text"] = text[e["start"]:e["end"]].strip()
return [e for e in entities if e["text"]]
def group_entities(entities):
    """Bucket entity texts by type, dropping case-insensitive repeats.

    The result has one key per entry in ``config.ENTITY_TYPES`` (possibly with
    an empty list). Entities of any other type are ignored. Within a type the
    first spelling encountered is kept and later ones that differ only in case
    are skipped.
    """
    buckets = {}
    lowered_seen = {}
    for etype in config.ENTITY_TYPES:
        buckets[etype] = []
        lowered_seen[etype] = set()

    for ent in entities:
        etype = ent["type"]
        if etype in buckets:
            surface = ent["text"]
            folded = surface.lower()
            if folded not in lowered_seen[etype]:
                lowered_seen[etype].add(folded)
                buckets[etype].append(surface)
    return buckets
|