# Zeqhx's picture
# Deploy CV parser dashboard with dataset 2 model
# c59578d verified
"""Model loading + sliding-window NER inference.
Kept free of any Streamlit imports so it can be unit-tested / reused.
The Streamlit pages wrap `load_model()` in `st.cache_resource`.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import config
@dataclass
class LoadedModel:
    """Bundle of a tokenizer, a token-classification model and its metadata.

    Instances are produced by the ``_load_*`` helpers and consumed by
    ``predict``.
    """

    # Tokenizer object used to encode the input text.
    tokenizer: object
    # Token-classification model, already moved to `device`.
    model: object
    # Mapping from integer class id to label string.
    id2label: dict
    # Human-readable description of where the weights came from.
    source: str
    # True => random (untrained) head, so predictions are meaningless.
    is_fallback: bool
    # Device name the model lives on.
    device: str = "cpu"
def _ensure_label_scheme(model):
    """Make sure `model.config` carries our entity labels and return them.

    When the model's current ``id2label`` values do not include ``"B-SKILL"``,
    both ``id2label`` and ``label2id`` on ``model.config`` are replaced with
    copies of the mappings from ``config``. The returned dict is the
    (possibly replaced) ``id2label`` with every key coerced to ``int``.
    """
    current = getattr(model.config, "id2label", {})
    known_labels = set(current.values())

    if "B-SKILL" not in known_labels:
        # The checkpoint uses some other label set: install ours.
        model.config.id2label = dict(config.ID2LABEL)
        model.config.label2id = dict(config.LABEL2ID)

    normalized = {}
    for idx, name in model.config.id2label.items():
        # Keys may not be ints (e.g. strings); normalize them for lookups.
        normalized[int(idx)] = name
    return normalized
# Sentinel default for `load_model(ref=...)`: tells it to resolve the model
# from config.py, as opposed to `None`, which asks for the fallback model.
_USE_CONFIG = "__use_config__"
def _load_local(path: str, device: str) -> LoadedModel:
    """Load tokenizer and model from the local folder `path` onto `device`.

    The model is switched to eval mode and its label mapping is normalized
    via `_ensure_label_scheme`. The result is never marked as a fallback.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    net = AutoModelForTokenClassification.from_pretrained(path)
    labels = _ensure_label_scheme(net)
    net.to(device).eval()
    return LoadedModel(
        tokenizer=tokenizer,
        model=net,
        id2label=labels,
        source=f"Local folder: {path}",
        is_fallback=False,
        device=device,
    )
def _load_hub(model_id: str, device: str) -> LoadedModel:
    """Load tokenizer and model for the Hub repo `model_id` onto `device`.

    The model is switched to eval mode and its label mapping is normalized
    via `_ensure_label_scheme`. Any exception raised by `from_pretrained`
    propagates to the caller.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    net = AutoModelForTokenClassification.from_pretrained(model_id)
    labels = _ensure_label_scheme(net)
    net.to(device).eval()
    return LoadedModel(
        tokenizer=tokenizer,
        model=net,
        id2label=labels,
        source=f"Hugging Face Hub: {model_id}",
        is_fallback=False,
        device=device,
    )
def _load_fallback(device: str) -> LoadedModel:
    """Load the demo fallback: `config.FALLBACK_MODEL` with an untrained head.

    The classification head is sized and labelled from `config`, so the
    output label set matches ours, but the result is flagged
    ``is_fallback=True`` because that head has not been trained.
    """
    base = config.FALLBACK_MODEL
    label_map = dict(config.ID2LABEL)

    tokenizer = AutoTokenizer.from_pretrained(base, add_prefix_space=True)
    net = AutoModelForTokenClassification.from_pretrained(
        base,
        num_labels=len(config.LABELS),
        id2label=dict(config.ID2LABEL),
        label2id=dict(config.LABEL2ID),
    )
    net.to(device).eval()

    return LoadedModel(
        tokenizer=tokenizer,
        model=net,
        id2label=label_map,
        source=f"Fallback base model: {config.FALLBACK_MODEL} (untrained head)",
        is_fallback=True,
        device=device,
    )
def load_model(ref: str | None = _USE_CONFIG) -> LoadedModel:
    """Load a NER model from a ref.

    - ``ref=_USE_CONFIG`` (default): resolve per config.py priority
      (local MODEL_PATH -> MODEL_ID -> fallback). Keeps old callers working.
    - ``ref`` is a local directory: load that exported folder.
    - ``ref`` is any other non-empty string: treat as a Hugging Face Hub
      repo id.
    - ``ref=None`` (or empty): go straight to the demo fallback model.

    A local folder that's missing, or a Hub repo that can't be loaded
    (404 / private / offline), degrades gracefully to the demo fallback.
    A failed Hub load is logged as a warning (with traceback) so the
    downgrade is visible in logs and not only through ``is_fallback``.

    Returns:
        A ``LoadedModel`` on CUDA when available, otherwise on CPU.
    """
    # Function-scope import so this block does not depend on the file header.
    import logging

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if ref == _USE_CONFIG:
        if config.MODEL_PATH and os.path.isdir(config.MODEL_PATH):
            return _load_local(config.MODEL_PATH, device)
        # No usable local export: continue with the configured Hub id. An
        # empty/None id is handled by the `not ref` check just below.
        ref = config.MODEL_ID

    if not ref:
        return _load_fallback(device)
    if os.path.isdir(ref):
        return _load_local(ref, device)

    try:
        return _load_hub(ref, device)
    except Exception:  # noqa: BLE001 - missing/private/offline repo -> demo
        # Deliberate best-effort: keep the app usable, but leave a trace of
        # why the requested model was not used.
        logging.getLogger(__name__).warning(
            "Could not load model %r; using the untrained fallback model.",
            ref,
            exc_info=True,
        )
        return _load_fallback(device)
@torch.no_grad()
def predict(text: str, lm: LoadedModel):
    """Run sliding-window token classification over `text`.

    The tokenizer is asked for overflowing windows (``config.MAX_LENGTH`` /
    ``config.STRIDE``), all windows go through the model as one batch, and
    tokens that show up in several windows are de-duplicated by their start
    character offset: the first window containing a token wins.

    Args:
        text: Raw input text. ``None``, empty or whitespace-only input
            returns ``([], [])`` without touching the model.
        lm: Loaded tokenizer/model bundle. The tokenizer must be able to
            return offset mappings.

    Returns:
        (tokens, entities):
          tokens   = [{"text", "label", "type", "start", "end"}] one per sub-word
          entities = [{"text", "type", "start", "end"}] merged BIO spans
    """
    text = text or ""
    if not text.strip():
        return [], []

    enc = lm.tokenizer(
        text,
        max_length=config.MAX_LENGTH,
        truncation=True,
        stride=config.STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,
        return_tensors="pt",
    )
    attention_mask = enc["attention_mask"]
    logits = lm.model(
        input_ids=enc["input_ids"].to(lm.device),
        attention_mask=attention_mask.to(lm.device),
    ).logits

    # Convert each tensor to nested Python lists once; indexing tensors
    # element by element inside the loops below is needlessly slow.
    offset_rows = enc["offset_mapping"].tolist()
    mask_rows = attention_mask.tolist()
    pred_rows = logits.argmax(-1).cpu().tolist()

    # Deduplicate overlapping sliding-window tokens by their global char offset.
    seen: dict[int, tuple[int, int]] = {}
    for window_offsets, window_mask, window_preds in zip(
        offset_rows, mask_rows, pred_rows
    ):
        for (start, end), keep, pred_id in zip(
            window_offsets, window_mask, window_preds
        ):
            # Skip padding and any zero-width span. This covers the (0, 0)
            # offsets of special tokens, and also keeps an empty span from
            # claiming a start offset that a real token shares.
            if not keep or end <= start:
                continue
            # First window wins for tokens repeated in the overlap.
            seen.setdefault(start, (end, pred_id))

    tokens = []
    for start in sorted(seen):
        end, pred_id = seen[start]
        label = lm.id2label.get(pred_id, "O")  # unknown ids count as outside
        etype = label.split("-", 1)[1] if "-" in label else None
        tokens.append({"text": text[start:end], "label": label, "type": etype,
                       "start": start, "end": end})

    entities = _merge_bio(tokens, text)
    return tokens, entities
def _merge_bio(tokens, text):
    """Collapse runs of B-/I- tagged tokens into entity spans.

    A span starts at a ``B-`` token, at the first tagged token after an
    outside token, or when the entity type changes; an ``I-`` token of the
    same type extends the open span. Each span's ``"text"`` is the stripped
    slice of `text` between its start and end; spans whose text is empty
    after stripping are dropped.
    """
    spans = []
    open_span = None

    def flush():
        # Close the span in progress, if there is one.
        nonlocal open_span
        if open_span:
            spans.append(open_span)
        open_span = None

    for tok in tokens:
        prefix, hyphen, etype = tok["label"].partition("-")
        if not hyphen:  # no "B-"/"I-" prefix, i.e. "O"
            flush()
            continue

        extends = (
            prefix != "B"
            and open_span is not None
            and open_span["type"] == etype
        )
        if extends:
            open_span["end"] = tok["end"]
        else:
            flush()
            open_span = {"type": etype, "start": tok["start"], "end": tok["end"]}
    flush()

    merged = []
    for span in spans:
        span["text"] = text[span["start"]:span["end"]].strip()
        if span["text"]:
            merged.append(span)
    return merged
def group_entities(entities):
    """Bucket merged entities by type, keeping one entry per distinct text.

    The result has one key per entry in ``config.ENTITY_TYPES``; entities of
    any other type are ignored. Within a type, texts are compared
    case-insensitively (via ``str.lower``) and the first-seen spelling is
    kept, in first-seen order.
    """
    # Per type: lowercased text -> first-seen original text. Dict insertion
    # order gives us the first-seen ordering for free.
    buckets = {etype: {} for etype in config.ENTITY_TYPES}

    for entity in entities:
        bucket = buckets.get(entity["type"])
        if bucket is None:
            continue  # a type we do not report
        bucket.setdefault(entity["text"].lower(), entity["text"])

    return {etype: list(bucket.values()) for etype, bucket in buckets.items()}