"""Evaluation harness — NoteGuard's 'reliable' pillar. Because the dataset's PII lives in structured tables, every note has ground-truth identifiers. We measure two things Presidio alone never reports: 1. Detection quality : per-entity precision / recall / F1 against known PII. 2. Residual leakage : after sanitisation, how many KNOWN identifiers still appear in the output text. This is the headline number — an honest, measurable re-identification risk. Caveat we state openly: precision is measured against *structured* PII only. A note may contain PII not present in the tables (e.g. a clinician's name); a correct detection of it counts here as a false positive, so reported precision is a conservative lower bound. Recall and leakage are unaffected. """ from __future__ import annotations from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime from .data import NoteRecord from .detect import Detector from .recognisers import Span from .transform import REDACTION, PseudonymVault, apply_transform _DATE_FORMATS = ["%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%d/%m/%y", "%d-%m-%y", "%d %b %Y", "%d %B %Y"] def _date_variants(value: str) -> list[str]: for fmt in _DATE_FORMATS: try: dt = datetime.strptime(value.strip(), fmt) return list({dt.strftime(f) for f in _DATE_FORMATS}) except ValueError: continue return [value] def value_variants(value: str, entity_type: str) -> list[str]: """Surface forms of a known PII value as it might appear in free text.""" value = value.strip() if not value: return [] if entity_type == "PERSON": parts = value.split() out = [value] if len(parts) > 1: out.append(parts[-1]) # surname alone out.append(parts[0]) # forename alone return out if entity_type == "UK_NHS": digits = "".join(ch for ch in value if ch.isdigit()) out = {value, digits} if len(digits) == 10: out.add(f"{digits[:3]} {digits[3:6]} {digits[6:]}") out.add(f"{digits[:3]}-{digits[3:6]}-{digits[6:]}") return list(out) if entity_type == "DATE_TIME": return _date_variants(value) return [value] def _find_all(haystack: str, needle: str) -> list[tuple[int, int]]: """Case-insensitive, word-boundary-aware occurrences of needle in haystack.""" if not needle: return [] hl, nl = haystack.lower(), needle.lower() spots: list[tuple[int, int]] = [] start = 0 while True: i = hl.find(nl, start) if i == -1: break left_ok = i == 0 or not (hl[i - 1].isalnum()) right_ok = i + len(nl) == len(hl) or not (hl[i + len(nl)].isalnum()) if left_ok and right_ok: spots.append((i, i + len(nl))) start = i + 1 return spots def ground_truth_spans(record: NoteRecord) -> list[Span]: """Locate each known PII value (and its surface variants) inside the note.""" occ: list[Span] = [] for gt in record.ground_truth: for variant in value_variants(gt.text, gt.entity_type): if len(variant) < 2: continue for s, e in _find_all(record.text, variant): occ.append(Span(s, e, gt.entity_type, record.text[s:e])) return _dedupe(occ) def _dedupe(spans: list[Span]) -> list[Span]: seen: set[tuple[int, int]] = set() out: list[Span] = [] for s in sorted(spans, key=lambda x: (x.start, -(x.end - x.start))): if any(s.start >= a and s.end <= b for (a, b) in seen): continue seen.add((s.start, s.end)) out.append(s) return out def _overlaps(a: Span, b: Span) -> bool: return a.start < b.end and b.start < a.end @dataclass class Counter: tp: int = 0 fp: int = 0 fn: int = 0 @property def precision(self) -> float: return self.tp / (self.tp + self.fp) if (self.tp + self.fp) else 0.0 @property def recall(self) -> float: return self.tp / (self.tp + self.fn) if (self.tp + self.fn) else 0.0 @property def f1(self) -> float: p, r = self.precision, self.recall return 2 * p * r / (p + r) if (p + r) else 0.0 @dataclass class EvalResult: notes: int = 0 per_entity: dict[str, Counter] = field(default_factory=lambda: defaultdict(Counter)) overall: Counter = field(default_factory=Counter) total_gt_occurrences: int = 0 residual_leaks: int = 0 transform_method: str = REDACTION detector_name: str = "" @property def leakage_rate(self) -> float: return self.residual_leaks / self.total_gt_occurrences if self.total_gt_occurrences else 0.0 def to_dict(self) -> dict: return { "detector": self.detector_name, "transform": self.transform_method, "notes_evaluated": self.notes, "detection": { "overall": { "precision": round(self.overall.precision, 4), "recall": round(self.overall.recall, 4), "f1": round(self.overall.f1, 4), "tp": self.overall.tp, "fp": self.overall.fp, "fn": self.overall.fn, }, "per_entity": { et: { "precision": round(c.precision, 4), "recall": round(c.recall, 4), "f1": round(c.f1, 4), "support": c.tp + c.fn, } for et, c in sorted(self.per_entity.items()) }, }, "leakage": { "total_known_pii_occurrences": self.total_gt_occurrences, "residual_leaks_after_sanitisation": self.residual_leaks, "leakage_rate": round(self.leakage_rate, 4), "leakage_rate_pct": round(100 * self.leakage_rate, 2), }, } def evaluate( records: list[NoteRecord], detector: Detector, transform_method: str = REDACTION, ) -> EvalResult: res = EvalResult(transform_method=transform_method, detector_name=getattr(detector, "name", "?")) for rec in records: if not rec.text: continue res.notes += 1 gt = ground_truth_spans(rec) detected = detector.detect(rec.text) # ---- detection precision / recall (overlap-based) ---- matched_det: set[int] = set() for g in gt: hit = next((i for i, d in enumerate(detected) if i not in matched_det and _overlaps(g, d)), None) if hit is not None: matched_det.add(hit) res.per_entity[g.entity_type].tp += 1 res.overall.tp += 1 else: res.per_entity[g.entity_type].fn += 1 res.overall.fn += 1 for i, d in enumerate(detected): if i not in matched_det: res.per_entity[d.entity_type].fp += 1 res.overall.fp += 1 # ---- residual leakage after sanitisation ---- vault = PseudonymVault() sanitised, _ = apply_transform( rec.text, detected, transform_method, vault, rec.person_id ) res.total_gt_occurrences += len(gt) # a known value leaks if any of its surface variants survives in output for g in gt: leaked = False for variant in value_variants(g.text, g.entity_type): if len(variant) >= 2 and _find_all(sanitised, variant): leaked = True break if leaked: res.residual_leaks += 1 return res