guide / src /ner /model.py
saravanakum1
add privacy layer tests with openspec and LangSmith integration
1f54b5c
Raw
History Blame Contribute Delete
6.77 kB
"""
EvidenceNER model definition.
Architecture: distilbert-base-uncased with token classification head.
Task: Named Entity Recognition on redacted complaint text.
Entity types: ORG | AMOUNT | DATE | REF_ID | ACCOUNT | PERSON
Input: Redacted complaint text (str, max 512 tokens after tokenisation).
Output: List of Entity(text, label, start, end, confidence).
BIO scheme: O + B-/I- prefix for each of the 6 entity types → 13 labels total.
Note: PERSON entities surviving Presidio redaction are role-references
(e.g. "customer care executive"), not personal names.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Label constants — shared by model.py, train.py, predict.py
# ---------------------------------------------------------------------------
# The six entity types the tagger can emit (BIO prefix not included).
NER_LABELS = ["ORG", "AMOUNT", "DATE", "REF_ID", "ACCOUNT", "PERSON"]
# Full tag set: "O" sits at index 0, then every entity type (in NER_LABELS
# order) contributes its "B-" tag followed by its "I-" tag → 1 + 2 * 6 = 13.
BIO_LABELS: list[str] = [
    "O",
    *(f"{prefix}-{entity_type}" for entity_type in NER_LABELS for prefix in "BI"),
]
# Tag <-> integer id lookups; an id is simply the tag's position in BIO_LABELS.
LABEL2ID: dict[str, int] = dict(zip(BIO_LABELS, range(len(BIO_LABELS))))
ID2LABEL: dict[int, str] = dict(enumerate(BIO_LABELS))
NUM_LABELS = len(BIO_LABELS)  # 13
# ---------------------------------------------------------------------------
# Public output type
# ---------------------------------------------------------------------------
@dataclass
class Entity:
    """
    One recognised entity span, as returned by EvidenceNER.extract.

    Field order is part of the constructor signature
    (text, label, start, end, confidence).
    """

    # Exact substring of the analysed text covered by the span.
    text: str
    # Entity type with the BIO prefix removed, e.g. "ORG".
    label: str
    # Character offset of the first character of the span.
    start: int
    # Character offset one past the last character (slice-style, exclusive).
    end: int
    # Mean of the winning-class probabilities of the tokens in the span.
    confidence: float
# ---------------------------------------------------------------------------
# EvidenceNER
# ---------------------------------------------------------------------------
class EvidenceNER:
    """
    DistilBERT token classifier for complaint evidence extraction.

    Loads a fine-tuned checkpoint produced by train.py. Uses the tokenizer's
    offset_mapping to convert subword-level BIO predictions back to
    character-level spans without any secondary tokenisation step.
    """

    BASE_MODEL = "distilbert-base-uncased"
    # Upper bound on tokens (special tokens included) fed to the model per call.
    MAX_TOKENS = 512

    def __init__(self, model_dir: str) -> None:
        """
        Load a fine-tuned NER checkpoint from *model_dir*.

        The model is switched to eval mode and moved to the device picked by
        _select_device(). A warning is logged when the checkpoint's output
        size differs from the number of BIO labels this module defines.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
        self.model.eval()
        self._device = self._select_device()
        self.model.to(self._device)

        # Predicted ids are decoded with the module-level ID2LABEL table, so a
        # head of a different size would be decoded against the wrong labels.
        head_size = getattr(self.model.config, "num_labels", None)
        if head_size is not None and head_size != len(ID2LABEL):
            logger.warning(
                "Checkpoint %s has %s output labels but %s BIO labels are "
                "defined; predicted ids may not map to the intended entity types",
                model_dir,
                head_size,
                len(ID2LABEL),
            )
        logger.info("EvidenceNER loaded from %s on %s", model_dir, self._device)

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred inference device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Not every PyTorch build exposes torch.backends.mps; a missing
        # attribute simply means MPS cannot be used.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def extract(self, text: str) -> list[Entity]:
        """
        Extract entity spans from *text* and return a list of Entity objects.

        Spans are character-level (start/end index into the original string).
        Returns [] for empty or whitespace-only input. Input longer than
        MAX_TOKENS tokens is truncated; a warning (without the text itself)
        is logged when that happens, because entities in the cut-off tail
        cannot be reported.
        """
        if not text or not text.strip():
            return []

        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_TOKENS,
            return_offsets_mapping=True,
        )
        # offset_mapping is not a model input — pop before the forward pass.
        offset_mapping: list[tuple[int, int]] = [
            (start, end)
            for start, end in encoding.pop("offset_mapping")[0].tolist()
        ]

        # A full-length sequence whose offsets stop short of the last
        # non-blank character means the tail of the text was truncated away.
        covered = max((end for _, end in offset_mapping), default=0)
        if len(offset_mapping) >= self.MAX_TOKENS and covered < len(text.rstrip()):
            logger.warning(
                "Input truncated to %d tokens; characters from offset %d "
                "onwards were not analysed",
                self.MAX_TOKENS,
                covered,
            )

        model_inputs = {k: v.to(self._device) for k, v in encoding.items()}
        with torch.no_grad():
            logits = self.model(**model_inputs).logits[0]  # (seq_len, num_labels)

        probs = torch.softmax(logits, dim=-1).cpu()
        # A single max() yields both the winning probability and its index.
        best_probs, best_ids = probs.max(dim=-1)
        pred_ids: list[int] = best_ids.tolist()
        conf_scores: list[float] = best_probs.tolist()
        return self._aggregate_spans(text, offset_mapping, pred_ids, conf_scores)

    # ------------------------------------------------------------------
    # Span aggregation
    # ------------------------------------------------------------------
    def _aggregate_spans(
        self,
        text: str,
        offset_mapping: list[tuple[int, int]],
        pred_ids: list[int],
        conf_scores: list[float],
    ) -> list[Entity]:
        """
        Convert per-subtoken BIO predictions into character-level Entity spans.

        The three per-token sequences are walked in lockstep. Special tokens
        ([CLS], [SEP]) have offset (0, 0) — i.e. start == end — and only close
        any open span; their predicted id is never decoded. A B- tag opens a
        new span, an I- tag of the same type extends the open span, and
        anything else (O, or an I- tag that does not continue the current
        entity type) closes the open span and is itself discarded.

        Each Entity's text is the slice of *text* from the first token's start
        to the last token's end (so whitespace between tokens is included),
        and its confidence is the mean of the span's per-token scores.
        """
        entities: list[Entity] = []
        # State of the span being built; span_label is None when none is open.
        span_label: Optional[str] = None
        span_start = 0
        span_end = 0
        span_confs: list[float] = []

        def _emit(label: str, start: int, end: int, confs: list[float]) -> None:
            # All state arrives as arguments, so nothing here depends on
            # variables the loop below rebinds.
            entities.append(Entity(
                text=text[start:end],
                label=label,
                start=start,
                end=end,
                confidence=sum(confs) / len(confs),
            ))

        for (start, end), label_id, conf in zip(offset_mapping, pred_ids, conf_scores):
            # Zero-length offsets mark special tokens: behave like "O" without
            # looking up their (meaningless) predicted label.
            tag = "O" if start == end else ID2LABEL[label_id]

            if (
                tag.startswith("I-")
                and span_label is not None
                and tag[2:] == span_label
            ):
                # Continuation: stretch the span to this token's end.
                span_end = end
                span_confs.append(conf)
                continue

            # Every other tag terminates whatever span is open.
            if span_label is not None:
                _emit(span_label, span_start, span_end, span_confs)
                span_label = None

            if tag.startswith("B-"):
                span_label = tag[2:]
                span_start = start
                span_end = end
                span_confs = [conf]

        # Flush any span still open at the end of the sequence.
        if span_label is not None:
            _emit(span_label, span_start, span_end, span_confs)
        return entities