Spaces:
Running
Running
| """Model loading + sliding-window NER inference. | |
| Kept free of any Streamlit imports so it can be unit-tested / reused. | |
| The Streamlit pages wrap `load_model()` in `st.cache_resource`. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass, field | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| import config | |
@dataclass
class LoadedModel:
    """Bundle of everything needed to run inference with one loaded model.

    Declared as a dataclass so it can be built positionally / by keyword
    (as the loader helpers do) and so ``device`` gets its ``"cpu"`` default.
    """

    tokenizer: object
    model: object
    id2label: dict
    source: str  # human-readable description of where weights came from
    is_fallback: bool  # True => random head, predictions are meaningless
    device: str = field(default="cpu")
| def _ensure_label_scheme(model): | |
| """If the loaded model lacks our entity labels, overwrite its id2label.""" | |
| cfg_labels = set(getattr(model.config, "id2label", {}).values()) | |
| if "B-SKILL" not in cfg_labels: | |
| model.config.id2label = dict(config.ID2LABEL) | |
| model.config.label2id = dict(config.LABEL2ID) | |
| return {int(k): v for k, v in model.config.id2label.items()} | |
# Sentinel default for load_model(): distinguishes "argument omitted, resolve
# the model from config.py" from an explicit ``None`` (go straight to fallback).
_USE_CONFIG = "__use_config__"
def _load_local(path: str, device: str) -> LoadedModel:
    """Load tokenizer and token-classification model from the folder `path`.

    The model is moved to `device`, switched to eval mode, and its label
    scheme is checked via `_ensure_label_scheme`.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    net = AutoModelForTokenClassification.from_pretrained(path)
    label_map = _ensure_label_scheme(net)
    on_device = net.to(device)
    on_device.eval()
    return LoadedModel(
        tokenizer,
        net,
        label_map,
        source=f"Local folder: {path}",
        is_fallback=False,
        device=device,
    )
def _load_hub(model_id: str, device: str) -> LoadedModel:
    """Load tokenizer and token-classification model for the repo id `model_id`.

    The model is moved to `device`, switched to eval mode, and its label
    scheme is checked via `_ensure_label_scheme`. Loading errors propagate
    to the caller.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    net = AutoModelForTokenClassification.from_pretrained(model_id)
    label_map = _ensure_label_scheme(net)
    on_device = net.to(device)
    on_device.eval()
    return LoadedModel(
        tokenizer,
        net,
        label_map,
        source=f"Hugging Face Hub: {model_id}",
        is_fallback=False,
        device=device,
    )
def _load_fallback(device: str) -> LoadedModel:
    """Load the demo fallback: the configured base model with a fresh head.

    The classification head is sized and labelled from the project label
    config; the result is flagged ``is_fallback=True``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        config.FALLBACK_MODEL, add_prefix_space=True
    )
    head_kwargs = {
        "num_labels": len(config.LABELS),
        "id2label": dict(config.ID2LABEL),
        "label2id": dict(config.LABEL2ID),
    }
    net = AutoModelForTokenClassification.from_pretrained(
        config.FALLBACK_MODEL, **head_kwargs
    )
    on_device = net.to(device)
    on_device.eval()
    description = f"Fallback base model: {config.FALLBACK_MODEL} (untrained head)"
    return LoadedModel(
        tokenizer,
        net,
        dict(config.ID2LABEL),
        source=description,
        is_fallback=True,
        device=device,
    )
def load_model(ref: str | None = _USE_CONFIG) -> LoadedModel:
    """Load a NER model from a ref.

    - ``ref=_USE_CONFIG`` (default): resolve per config.py priority
      (local MODEL_PATH -> MODEL_ID -> fallback). Keeps old callers working.
    - ``ref`` is a local directory: load that exported folder.
    - ``ref`` is any other non-empty string: treat as a Hugging Face Hub repo id.
    - ``ref=None``: go straight to the demo fallback model.

    A local folder that's missing, or a Hub repo that can't be loaded
    (404 / private / offline), degrades gracefully to the demo fallback.
    A failed Hub load is logged as a warning (with traceback) before falling
    back, so the reason is not lost.

    Returns:
        A ``LoadedModel`` placed on CUDA when available, otherwise on CPU.
    """
    import logging  # function-scope import; only needed for the fallback warning

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if ref == _USE_CONFIG:
        if config.MODEL_PATH and os.path.isdir(config.MODEL_PATH):
            return _load_local(config.MODEL_PATH, device)
        if config.MODEL_ID:
            ref = config.MODEL_ID
        else:
            return _load_fallback(device)

    if not ref:
        return _load_fallback(device)
    if os.path.isdir(ref):
        return _load_local(ref, device)

    try:
        return _load_hub(ref, device)
    except Exception:  # noqa: BLE001 - missing/private/offline repo -> demo
        # Best-effort by design: keep the app usable, but record why we fell back.
        logging.getLogger(__name__).warning(
            "Could not load model %r; using the fallback model instead.",
            ref,
            exc_info=True,
        )
        return _load_fallback(device)
def predict(text: str, lm: LoadedModel):
    """Run sliding-window token classification over `text`.

    The text is tokenized into overflowing windows (``config.MAX_LENGTH`` as
    the tokenizer's ``max_length``, ``config.STRIDE`` as its ``stride``), all
    windows are sent through the model in a single batch, and tokens that
    show up in more than one window are kept once, keyed by their start
    character offset (the first window that yields an offset wins).

    Args:
        text: Input text. ``None``, empty, or whitespace-only input returns
            ``([], [])`` without touching the model.
        lm: Loaded bundle; its ``tokenizer``, ``model``, ``device`` and
            ``id2label`` are used.

    Returns (tokens, entities):
        tokens = [{"text", "label", "type", "start", "end"}] one per sub-word
        entities = [{"text", "type", "start", "end"}] merged BIO spans
    """
    text = text or ""
    if not text.strip():
        return [], []

    enc = lm.tokenizer(
        text,
        max_length=config.MAX_LENGTH,
        truncation=True,
        stride=config.STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,
        return_tensors="pt",
    )
    attention_mask = enc["attention_mask"]
    input_ids = enc["input_ids"].to(lm.device)

    # Inference only: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        logits = lm.model(
            input_ids=input_ids,
            attention_mask=attention_mask.to(lm.device),
        ).logits

    # Convert once to plain Python lists; indexing tensors element by element
    # inside the loop below would be far slower for the same result.
    pred_ids = logits.argmax(-1).cpu().tolist()
    offsets = enc["offset_mapping"].tolist()
    mask = attention_mask.tolist()

    # Deduplicate overlapping sliding-window tokens by their global char offset.
    seen: dict[int, tuple] = {}
    for win_offsets, win_mask, win_preds in zip(offsets, mask, pred_ids):
        for (start, end), keep, pid in zip(win_offsets, win_mask, win_preds):
            if not keep or (start == 0 and end == 0):
                continue  # special token or padding
            seen.setdefault(start, (end, pid))  # first occurrence wins

    tokens = []
    for start in sorted(seen):
        end, pid = seen[start]
        label = lm.id2label.get(pid, "O")
        etype = label.split("-", 1)[1] if "-" in label else None
        tokens.append({"text": text[start:end], "label": label, "type": etype,
                       "start": start, "end": end})

    entities = _merge_bio(tokens, text)
    return tokens, entities
| def _merge_bio(tokens, text): | |
| """Merge consecutive B-/I- tokens of the same type into entity spans.""" | |
| entities = [] | |
| cur = None | |
| for t in tokens: | |
| label = t["label"] | |
| if "-" not in label: # "O" | |
| if cur: | |
| entities.append(cur) | |
| cur = None | |
| continue | |
| prefix, etype = label.split("-", 1) | |
| if prefix == "B" or cur is None or cur["type"] != etype: | |
| if cur: | |
| entities.append(cur) | |
| cur = {"type": etype, "start": t["start"], "end": t["end"]} | |
| else: # I- continuing the same type | |
| cur["end"] = t["end"] | |
| if cur: | |
| entities.append(cur) | |
| for e in entities: | |
| e["text"] = text[e["start"]:e["end"]].strip() | |
| return [e for e in entities if e["text"]] | |
def group_entities(entities):
    """Bucket entity texts by type, keeping the first spelling of each.

    Every type in ``config.ENTITY_TYPES`` gets a (possibly empty) list.
    Entities of other types are ignored, and texts that differ only by case
    count as duplicates within a type.
    """
    buckets = {etype: [] for etype in config.ENTITY_TYPES}
    lowered_seen = {etype: set() for etype in config.ENTITY_TYPES}

    for ent in entities:
        etype = ent["type"]
        if etype in buckets:
            folded = ent["text"].lower()
            if folded not in lowered_seen[etype]:
                lowered_seen[etype].add(folded)
                buckets[etype].append(ent["text"])

    return buckets