"""Model loading + sliding-window NER inference. Kept free of any Streamlit imports so it can be unit-tested / reused. The Streamlit pages wrap `load_model()` in `st.cache_resource`. """ from __future__ import annotations import os from dataclasses import dataclass, field import torch from transformers import AutoTokenizer, AutoModelForTokenClassification import config @dataclass class LoadedModel: tokenizer: object model: object id2label: dict source: str # human-readable description of where weights came from is_fallback: bool # True => random head, predictions are meaningless device: str = field(default="cpu") def _ensure_label_scheme(model): """If the loaded model lacks our entity labels, overwrite its id2label.""" cfg_labels = set(getattr(model.config, "id2label", {}).values()) if "B-SKILL" not in cfg_labels: model.config.id2label = dict(config.ID2LABEL) model.config.label2id = dict(config.LABEL2ID) return {int(k): v for k, v in model.config.id2label.items()} _USE_CONFIG = "__use_config__" def _load_local(path: str, device: str) -> LoadedModel: tok = AutoTokenizer.from_pretrained(path) model = AutoModelForTokenClassification.from_pretrained(path) id2label = _ensure_label_scheme(model) model.to(device).eval() return LoadedModel(tok, model, id2label, source=f"Local folder: {path}", is_fallback=False, device=device) def _load_hub(model_id: str, device: str) -> LoadedModel: tok = AutoTokenizer.from_pretrained(model_id) model = AutoModelForTokenClassification.from_pretrained(model_id) id2label = _ensure_label_scheme(model) model.to(device).eval() return LoadedModel(tok, model, id2label, source=f"Hugging Face Hub: {model_id}", is_fallback=False, device=device) def _load_fallback(device: str) -> LoadedModel: tok = AutoTokenizer.from_pretrained(config.FALLBACK_MODEL, add_prefix_space=True) model = AutoModelForTokenClassification.from_pretrained( config.FALLBACK_MODEL, num_labels=len(config.LABELS), id2label=dict(config.ID2LABEL), label2id=dict(config.LABEL2ID), ) model.to(device).eval() return LoadedModel(tok, model, dict(config.ID2LABEL), source=f"Fallback base model: {config.FALLBACK_MODEL} (untrained head)", is_fallback=True, device=device) def load_model(ref: str | None = _USE_CONFIG) -> LoadedModel: """Load a NER model from a ref. - ``ref=_USE_CONFIG`` (default): resolve per config.py priority (local MODEL_PATH -> MODEL_ID -> fallback). Keeps old callers working. - ``ref`` is a local directory: load that exported folder. - ``ref`` is any other non-empty string: treat as a Hugging Face Hub repo id. - ``ref=None``: go straight to the demo fallback model. A local folder that's missing, or a Hub repo that can't be loaded (404 / private / offline), degrades gracefully to the demo fallback. """ device = "cuda" if torch.cuda.is_available() else "cpu" if ref == _USE_CONFIG: if config.MODEL_PATH and os.path.isdir(config.MODEL_PATH): return _load_local(config.MODEL_PATH, device) if config.MODEL_ID: ref = config.MODEL_ID else: return _load_fallback(device) if not ref: return _load_fallback(device) if os.path.isdir(ref): return _load_local(ref, device) try: return _load_hub(ref, device) except Exception: # noqa: BLE001 - missing/private/offline repo -> demo return _load_fallback(device) @torch.no_grad() def predict(text: str, lm: LoadedModel): """Run sliding-window token classification over `text`. Returns (tokens, entities): tokens = [{"text", "label", "type", "start", "end"}] one per sub-word entities = [{"text", "type", "start", "end"}] merged BIO spans """ text = text or "" if not text.strip(): return [], [] enc = lm.tokenizer( text, max_length=config.MAX_LENGTH, truncation=True, stride=config.STRIDE, return_overflowing_tokens=True, return_offsets_mapping=True, padding=True, return_tensors="pt", ) offsets = enc["offset_mapping"] attn = enc["attention_mask"] input_ids = enc["input_ids"].to(lm.device) attn_dev = attn.to(lm.device) logits = lm.model(input_ids=input_ids, attention_mask=attn_dev).logits probs = logits.softmax(-1) # keep the distribution, not just argmax conf_all = probs.max(-1).values.cpu() # per-token confidence of the chosen label preds = logits.argmax(-1).cpu() # Deduplicate overlapping sliding-window tokens by their global char offset. seen: dict[int, tuple] = {} n_windows, seq_len = preds.shape for w in range(n_windows): for i in range(seq_len): s, e = offsets[w][i].tolist() if (s == 0 and e == 0) or attn[w][i] == 0: continue # special token or padding if s in seen: continue seen[s] = (s, e, int(preds[w][i]), float(conf_all[w][i])) tokens = [] for s in sorted(seen): _, e, pid, conf = seen[s] label = lm.id2label.get(pid, "O") etype = label.split("-", 1)[1] if "-" in label else None tokens.append({"text": text[s:e], "label": label, "type": etype, "start": s, "end": e, "conf": conf}) entities = _merge_bio(tokens, text) return tokens, entities def _merge_bio(tokens, text): """Merge consecutive B-/I- tokens of the same type into entity spans.""" entities = [] cur = None for t in tokens: label = t["label"] if "-" not in label: # "O" if cur: entities.append(cur) cur = None continue prefix, etype = label.split("-", 1) if prefix == "B" or cur is None or cur["type"] != etype: if cur: entities.append(cur) cur = {"type": etype, "start": t["start"], "end": t["end"], "confs": [t.get("conf", 1.0)]} else: # I- continuing the same type cur["end"] = t["end"] cur["confs"].append(t.get("conf", 1.0)) if cur: entities.append(cur) for e in entities: e["text"] = text[e["start"]:e["end"]].strip() confs = e.pop("confs", []) or [1.0] e["conf"] = sum(confs) / len(confs) # mean confidence over member tokens return [e for e in entities if e["text"]] def group_entities(entities): """Group merged entities by type, de-duplicating case-insensitively.""" grouped = {t: [] for t in config.ENTITY_TYPES} seen = {t: set() for t in config.ENTITY_TYPES} for e in entities: t = e["type"] if t not in grouped: continue key = e["text"].lower() if key in seen[t]: continue seen[t].add(key) grouped[t].append(e["text"]) return grouped