Spaces:
Sleeping
Sleeping
"""
EvidenceNER model definition.

Architecture: distilbert-base-uncased with a token-classification head.
Task:         Named Entity Recognition on redacted complaint text.
Entity types: ORG | AMOUNT | DATE | REF_ID | ACCOUNT | PERSON
Input:        Redacted complaint text (str); tokenised with truncation, so
              anything beyond the first 512 tokens is not processed.
Output:       List of Entity(text, label, start, end, confidence).
BIO scheme:   "O" plus a B-/I- prefix for each of the 6 entity types
              → 13 labels total.

Note: PERSON entities that survive Presidio redaction are role references
(e.g. "customer care executive"), not personal names.
"""
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| from transformers import AutoModelForTokenClassification, AutoTokenizer | |
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------

# Entity types recognised by the model (BIO prefixes are added below).
NER_LABELS = ["ORG", "AMOUNT", "DATE", "REF_ID", "ACCOUNT", "PERSON"]

# Tag vocabulary: "O" at index 0, then a (B-, I-) pair per entity type in
# NER_LABELS order, giving 1 + 2 * 6 = 13 tags.
BIO_LABELS: list[str] = [
    "O",
    *(
        tag
        for entity_type in NER_LABELS
        for tag in (f"B-{entity_type}", f"I-{entity_type}")
    ),
]

# Bidirectional tag <-> index lookups; indices follow BIO_LABELS positions.
LABEL2ID: dict[str, int] = dict(zip(BIO_LABELS, range(len(BIO_LABELS))))
ID2LABEL: dict[int, str] = dict(enumerate(BIO_LABELS))

NUM_LABELS = len(BIO_LABELS)  # 13
# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------

@dataclass
class Entity:
    """A single recognised entity span.

    The ``@dataclass`` decorator is required: ``EvidenceNER`` constructs
    instances with keyword arguments (``Entity(text=..., label=..., ...)``),
    which a plain annotated class does not accept.

    Attributes:
        text: The covered substring of the input, i.e. ``input[start:end]``.
        label: Entity type without BIO prefix (e.g. "ORG", "AMOUNT").
        start: Character offset where the span begins (inclusive).
        end: Character offset where the span ends (exclusive).
        confidence: As produced by EvidenceNER, the mean of the per-subword
            max softmax probabilities over the tokens in the span.
    """

    text: str
    label: str
    start: int
    end: int
    confidence: float
# ---------------------------------------------------------------------------
# EvidenceNER
# ---------------------------------------------------------------------------

class EvidenceNER:
    """
    DistilBERT token classifier for complaint evidence extraction.

    Loads a fine-tuned checkpoint produced by train.py. Uses the tokenizer's
    offset_mapping to convert subword-level BIO predictions back to
    character-level spans without any secondary tokenisation step.
    Predicted label ids are decoded with the module-level ID2LABEL table.
    """

    BASE_MODEL = "distilbert-base-uncased"
    # Maximum tokens per forward pass; longer inputs are truncated (a warning
    # is logged when that actually drops text).
    MAX_LENGTH = 512

    def __init__(self, model_dir: str) -> None:
        """Load a fine-tuned NER checkpoint from *model_dir*.

        The model is switched to eval mode and moved to the first available
        device: CUDA, then Apple MPS, then CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
        self.model.eval()

        # Decoding uses the module-level ID2LABEL table, not the checkpoint's
        # own config. A checkpoint trained with a different label set would
        # silently mis-decode (or KeyError later), so surface it at load time.
        model_num_labels = self.model.config.num_labels
        if model_num_labels != NUM_LABELS:
            logger.warning(
                "Checkpoint %s has %d labels but EvidenceNER decodes %d "
                "(%s); predictions may be mislabelled",
                model_dir, model_num_labels, NUM_LABELS, BIO_LABELS,
            )

        self._device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model.to(self._device)
        logger.info("EvidenceNER loaded from %s on %s", model_dir, self._device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(self, text: str) -> list[Entity]:
        """
        Extract entity spans from *text* and return a list of Entity objects.

        Spans are character-level (start/end index into the original string).
        Returns [] for empty or whitespace-only input. Only the first
        MAX_LENGTH tokens are classified; if truncation drops any
        non-whitespace text, a warning is logged.
        """
        if not text or not text.strip():
            return []

        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_offsets_mapping=True,
        )
        # offset_mapping is not a model input — pop before forward pass
        offset_mapping: list[tuple[int, int]] = (
            encoding.pop("offset_mapping")[0].tolist()
        )
        self._warn_if_truncated(text, offset_mapping)

        model_inputs = {k: v.to(self._device) for k, v in encoding.items()}
        with torch.no_grad():
            logits = self.model(**model_inputs).logits[0]  # (seq_len, num_labels)

        probs = torch.softmax(logits, dim=-1).cpu()
        pred_ids: list[int] = probs.argmax(dim=-1).tolist()
        conf_scores: list[float] = probs.max(dim=-1).values.tolist()

        return self._aggregate_spans(text, offset_mapping, pred_ids, conf_scores)

    def _warn_if_truncated(
        self,
        text: str,
        offset_mapping: list[tuple[int, int]],
    ) -> None:
        """Log a warning when tokenizer truncation cut off part of *text*."""
        # Truncation can only have happened if the sequence hit the cap.
        if len(offset_mapping) < self.MAX_LENGTH:
            return
        # Furthest character covered by any token; special tokens are (0, 0).
        last_end = max((end for _, end in offset_mapping), default=0)
        if text[last_end:].strip():
            logger.warning(
                "EvidenceNER input truncated to %d tokens; text after "
                "character %d of %d was not searched for entities",
                self.MAX_LENGTH, last_end, len(text),
            )

    # ------------------------------------------------------------------
    # Span aggregation
    # ------------------------------------------------------------------

    def _aggregate_spans(
        self,
        text: str,
        offset_mapping: list[tuple[int, int]],
        pred_ids: list[int],
        conf_scores: list[float],
    ) -> list[Entity]:
        """
        Convert per-subtoken BIO predictions into character-level Entity spans.

        Special tokens ([CLS], [SEP]) have offset (0, 0) — i.e. start == end —
        and close any open span. A B- tag always starts a new span; an I- tag
        extends the open span only if its type matches, otherwise it is
        treated as O (broken sequence) and closes the open span. Entity
        confidence is the mean of the token confidences inside the span.
        """
        entities: list[Entity] = []
        current: Optional[dict] = None
        current_confs: list[float] = []

        def _flush() -> None:
            """Emit the open span (if any) as an Entity, then reset state."""
            nonlocal current, current_confs
            if current is not None:
                entities.append(Entity(
                    text=current["text"],
                    label=current["label"],
                    start=current["start"],
                    end=current["end"],
                    confidence=sum(current_confs) / len(current_confs),
                ))
            current = None
            current_confs = []

        for (start, end), label_id, conf in zip(offset_mapping, pred_ids, conf_scores):
            # Special tokens have zero-length offset spans
            if start == end:
                _flush()
                continue

            label = ID2LABEL[label_id]
            if label.startswith("B-"):
                _flush()
                current = {
                    "text": text[start:end],
                    "label": label[2:],
                    "start": start,
                    "end": end,
                }
                current_confs = [conf]
            elif (
                label.startswith("I-")
                and current is not None
                and label[2:] == current["label"]
            ):
                # Extend the current span (including any whitespace between subwords)
                current["end"] = end
                current["text"] = text[current["start"]:end]
                current_confs.append(conf)
            else:  # O or mismatched I- → close current span
                _flush()

        # Flush any span still open at end of sequence
        _flush()
        return entities