"""
EvidenceNER model definition.

Architecture: distilbert-base-uncased with token classification head.
Task: Named Entity Recognition on redacted complaint text.
Entity types: ORG | AMOUNT | DATE | REF_ID | ACCOUNT | PERSON

Input:  Redacted complaint text (str). At most 512 tokens after tokenisation
        are analysed; longer inputs are truncated and a warning is logged.
Output: List of Entity(text, label, start, end, confidence).

BIO scheme: O + B-/I- prefix for each of the 6 entity types → 13 labels total.

Note: PERSON entities surviving Presidio redaction are role-references
(e.g. "customer care executive"), not personal names.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------

NER_LABELS = ["ORG", "AMOUNT", "DATE", "REF_ID", "ACCOUNT", "PERSON"]

# O first, then B-/I- pairs in the same order as NER_LABELS → 13 labels
BIO_LABELS: list[str] = ["O"] + [
    f"{bio}-{label}" for label in NER_LABELS for bio in ("B", "I")
]
LABEL2ID: dict[str, int] = {label: i for i, label in enumerate(BIO_LABELS)}
ID2LABEL: dict[int, str] = {i: label for label, i in LABEL2ID.items()}
NUM_LABELS = len(BIO_LABELS)  # 13


# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------

@dataclass
class Entity:
    """A single recognised entity span.

    ``start``/``end`` are character offsets into the string passed to
    ``EvidenceNER.extract`` such that ``source[start:end] == text``.
    """

    text: str
    label: str         # entity type without BIO prefix, e.g. "AMOUNT"
    start: int
    end: int
    confidence: float  # mean of the max softmax probabilities that set the span's labels


# ---------------------------------------------------------------------------
# EvidenceNER
# ---------------------------------------------------------------------------

class EvidenceNER:
    """
    DistilBERT token classifier for complaint evidence extraction.

    Loads a fine-tuned checkpoint produced by train.py. Uses the tokenizer's
    offset_mapping to convert subword-level BIO predictions back to
    character-level spans without any secondary tokenisation step.

    Labels are decided per *word*, from the word's first subword. Continuation
    subwords only widen the span. A word is therefore never cut in half or
    split across two entities by an inconsistent prediction on one of its
    ``##`` pieces.
    """

    BASE_MODEL = "distilbert-base-uncased"
    # Maximum sequence length fed to the model, including [CLS]/[SEP].
    # Text beyond this point is truncated (see _warn_if_truncated).
    MAX_LENGTH = 512

    def __init__(self, model_dir: str) -> None:
        """Load a fine-tuned NER checkpoint from *model_dir*."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
        self.model.eval()
        # Prefer CUDA, then Apple MPS, then CPU.
        self._device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model.to(self._device)
        logger.info("EvidenceNER loaded from %s on %s", model_dir, self._device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(self, text: str) -> list[Entity]:
        """
        Extract entity spans from *text* and return a list of Entity objects.

        Spans are character-level (start/end index into the original string).
        Returns [] for empty or whitespace-only input. Only the first
        ``MAX_LENGTH`` tokens are analysed; a warning is logged when text is
        dropped by truncation.
        """
        if not text or not text.strip():
            return []

        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_offsets_mapping=True,
        )
        # Word index of every token (None for special tokens). This lets the
        # aggregation treat continuation subwords as part of their word.
        word_ids: list[Optional[int]] = encoding.word_ids(0)
        # offset_mapping is not a model input — pop before forward pass
        offset_mapping: list[tuple[int, int]] = (
            encoding.pop("offset_mapping")[0].tolist()
        )
        self._warn_if_truncated(
            text, encoding["input_ids"].shape[1], offset_mapping
        )

        model_inputs = {k: v.to(self._device) for k, v in encoding.items()}

        with torch.no_grad():
            logits = self.model(**model_inputs).logits[0]  # (seq_len, num_labels)

        probs = torch.softmax(logits, dim=-1).cpu()
        pred_ids: list[int] = probs.argmax(dim=-1).tolist()
        conf_scores: list[float] = probs.max(dim=-1).values.tolist()

        return self._aggregate_spans(
            text, offset_mapping, pred_ids, conf_scores, word_ids
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _warn_if_truncated(
        self,
        text: str,
        seq_len: int,
        offset_mapping: list[tuple[int, int]],
    ) -> None:
        """Log a warning if tokenizer truncation dropped non-blank text."""
        if seq_len < self.MAX_LENGTH:
            return  # sequence did not hit the limit, nothing was cut
        last_end = max((end for _, end in offset_mapping), default=0)
        if text[last_end:].strip():
            logger.warning(
                "EvidenceNER: input truncated to %d tokens; "
                "%d trailing characters were not analysed",
                self.MAX_LENGTH,
                len(text) - last_end,
            )

    # ------------------------------------------------------------------
    # Span aggregation
    # ------------------------------------------------------------------

    def _aggregate_spans(
        self,
        text: str,
        offset_mapping: list[tuple[int, int]],
        pred_ids: list[int],
        conf_scores: list[float],
        word_ids: Optional[list[Optional[int]]] = None,
    ) -> list[Entity]:
        """
        Convert per-subtoken BIO predictions into character-level Entity spans.

        Special tokens ([CLS], [SEP]) have offset (0, 0) — i.e. start == end —
        and close any open span.

        When *word_ids* is given (one entry per token, as returned by a fast
        tokenizer's ``word_ids()``), only a word's first subword is
        interpreted. Later subwords of the same word merely extend the open
        span, and do not contribute to its confidence. When *word_ids* is
        None, every subword is interpreted on its own (legacy behaviour).

        An I- tag that does not continue the current entity type is treated
        as O (broken sequence).
        """
        entities: list[Entity] = []
        current: Optional[dict] = None  # {"label", "start", "end"} of open span
        current_confs: list[float] = []

        def _close() -> None:
            """Emit the open span (if any) and reset the span state."""
            nonlocal current, current_confs
            if current is not None:
                entities.append(Entity(
                    text=text[current["start"]:current["end"]],
                    label=current["label"],
                    start=current["start"],
                    end=current["end"],
                    confidence=sum(current_confs) / len(current_confs),
                ))
            current = None
            current_confs = []

        words = word_ids if word_ids is not None else [None] * len(offset_mapping)
        prev_word: Optional[int] = None

        for (start, end), label_id, conf, word in zip(
            offset_mapping, pred_ids, conf_scores, words
        ):
            # Special tokens have zero-length offset spans
            if start == end:
                _close()
                prev_word = None
                continue

            # Continuation subword of the previous word: ignore its label.
            # An open span here necessarily contains the word's first subword.
            if word is not None and word == prev_word:
                if current is not None:
                    current["end"] = end
                continue
            prev_word = word

            prefix, _, entity_type = ID2LABEL[label_id].partition("-")

            if prefix == "B":
                _close()
                current = {"label": entity_type, "start": start, "end": end}
                current_confs = [conf]
            elif (
                prefix == "I"
                and current is not None
                and entity_type == current["label"]
            ):
                # Extend the current span (including any whitespace in between)
                current["end"] = end
                current_confs.append(conf)
            else:
                # O or mismatched I- → close current span
                _close()

        # Flush any span still open at end of sequence
        _close()
        return entities