File size: 7,812 Bytes
"""Evaluation harness — NoteGuard's 'reliable' pillar.
Because the dataset's PII lives in structured tables, every note has ground-truth
identifiers. We measure two things Presidio alone never reports:
1. Detection quality : per-entity precision / recall / F1 against known PII.
2. Residual leakage : after sanitisation, how many KNOWN identifiers still
appear in the output text. This is the headline number —
an honest, measurable re-identification risk.
Caveat we state openly: precision is measured against *structured* PII only.
A note may contain PII not present in the tables (e.g. a clinician's name); a
correct detection of it counts here as a false positive, so reported precision is
a conservative lower bound. Recall and leakage are unaffected.
"""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from .data import NoteRecord
from .detect import Detector
from .recognisers import Span
from .transform import REDACTION, PseudonymVault, apply_transform
_DATE_FORMATS = ["%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%d/%m/%y", "%d-%m-%y", "%d %b %Y", "%d %B %Y"]
def _date_variants(value: str) -> list[str]:
for fmt in _DATE_FORMATS:
try:
dt = datetime.strptime(value.strip(), fmt)
return list({dt.strftime(f) for f in _DATE_FORMATS})
except ValueError:
continue
return [value]
def value_variants(value: str, entity_type: str) -> list[str]:
    """Surface forms of a known PII value as it might appear in free text.

    Args:
        value: The known identifier; surrounding whitespace is ignored.
        entity_type: Entity label. ``"PERSON"``, ``"UK_NHS"`` and
            ``"DATE_TIME"`` get extra variants; any other label yields only
            the value itself.

    Returns:
        The distinct, non-empty variants in a stable order with the stripped
        value first. An empty or whitespace-only value yields ``[]``.
    """
    value = value.strip()
    if not value:
        return []

    if entity_type == "PERSON":
        parts = value.split()
        candidates = [value]
        if len(parts) > 1:
            candidates.append(parts[-1])  # surname alone
            candidates.append(parts[0])  # forename alone
    elif entity_type == "UK_NHS":
        digits = "".join(ch for ch in value if ch.isdigit())
        candidates = [value, digits]
        if len(digits) == 10:
            # 3-3-4 grouping, space- and hyphen-separated.
            candidates.append(f"{digits[:3]} {digits[3:6]} {digits[6:]}")
            candidates.append(f"{digits[:3]}-{digits[3:6]}-{digits[6:]}")
    elif entity_type == "DATE_TIME":
        candidates = _date_variants(value)
    else:
        candidates = [value]

    # Order-preserving de-duplication; the truthiness filter drops the empty
    # digit string produced when an NHS value contains no digits at all.
    return [variant for variant in dict.fromkeys(candidates) if variant]
def _find_all(haystack: str, needle: str) -> list[tuple[int, int]]:
"""Case-insensitive, word-boundary-aware occurrences of needle in haystack."""
if not needle:
return []
hl, nl = haystack.lower(), needle.lower()
spots: list[tuple[int, int]] = []
start = 0
while True:
i = hl.find(nl, start)
if i == -1:
break
left_ok = i == 0 or not (hl[i - 1].isalnum())
right_ok = i + len(nl) == len(hl) or not (hl[i + len(nl)].isalnum())
if left_ok and right_ok:
spots.append((i, i + len(nl)))
start = i + 1
return spots
def ground_truth_spans(record: NoteRecord) -> list[Span]:
    """Find where each known PII value, or a variant of it, occurs in the note.

    Every ground-truth entry of ``record`` is expanded with ``value_variants``;
    variants shorter than two characters are ignored. Each occurrence located
    by ``_find_all`` becomes a ``Span`` carrying the entry's entity type and
    the matched slice of the note text. The collected spans are passed through
    ``_dedupe`` before being returned.
    """
    note = record.text
    found: list[Span] = []
    for truth in record.ground_truth:
        label = truth.entity_type
        for variant in value_variants(truth.text, label):
            if len(variant) >= 2:
                found.extend(
                    Span(begin, stop, label, note[begin:stop])
                    for begin, stop in _find_all(note, variant)
                )
    return _dedupe(found)
def _dedupe(spans: list[Span]) -> list[Span]:
seen: set[tuple[int, int]] = set()
out: list[Span] = []
for s in sorted(spans, key=lambda x: (x.start, -(x.end - x.start))):
if any(s.start >= a and s.end <= b for (a, b) in seen):
continue
seen.add((s.start, s.end))
out.append(s)
return out
def _overlaps(a: Span, b: Span) -> bool:
return a.start < b.end and b.start < a.end
@dataclass
class Counter:
    """Tally of true positives, false positives and false negatives.

    The derived ratios return 0.0 instead of dividing by zero when their
    denominator is empty.
    """

    tp: int = 0
    fp: int = 0
    fn: int = 0

    @property
    def precision(self) -> float:
        """Fraction of reported items that were correct: tp / (tp + fp)."""
        reported = self.tp + self.fp
        if not reported:
            return 0.0
        return self.tp / reported

    @property
    def recall(self) -> float:
        """Fraction of expected items that were found: tp / (tp + fn)."""
        expected = self.tp + self.fn
        if not expected:
            return 0.0
        return self.tp / expected

    @property
    def f1(self) -> float:
        """Harmonic mean of precision and recall."""
        p = self.precision
        r = self.recall
        total = p + r
        if not total:
            return 0.0
        return 2 * p * r / total
@dataclass
class EvalResult:
    """Accumulated detection and leakage figures for one evaluation run."""

    # Number of notes that were actually scored.
    notes: int = 0
    # Detection counts keyed by entity type; missing keys start as an empty Counter.
    per_entity: dict[str, Counter] = field(default_factory=lambda: defaultdict(Counter))
    # Detection counts pooled over all entity types.
    overall: Counter = field(default_factory=Counter)
    # How many ground-truth occurrences were located across the scored notes.
    total_gt_occurrences: int = 0
    # How many of those occurrences were still found after sanitisation.
    residual_leaks: int = 0
    transform_method: str = REDACTION
    detector_name: str = ""

    @property
    def leakage_rate(self) -> float:
        """Residual leaks divided by ground-truth occurrences (0.0 if none)."""
        if not self.total_gt_occurrences:
            return 0.0
        return self.residual_leaks / self.total_gt_occurrences

    def to_dict(self) -> dict:
        """Return the results as a nested dict of plain values.

        Ratios are rounded to four decimal places and the percentage to two.
        Per-entity entries are ordered by entity type and report ``support``
        (tp + fn) in place of the raw counts.
        """
        pooled = self.overall
        overall_block = {
            "precision": round(pooled.precision, 4),
            "recall": round(pooled.recall, 4),
            "f1": round(pooled.f1, 4),
            "tp": pooled.tp,
            "fp": pooled.fp,
            "fn": pooled.fn,
        }

        per_entity_block = {}
        for entity_type, counts in sorted(self.per_entity.items()):
            per_entity_block[entity_type] = {
                "precision": round(counts.precision, 4),
                "recall": round(counts.recall, 4),
                "f1": round(counts.f1, 4),
                "support": counts.tp + counts.fn,
            }

        rate = self.leakage_rate
        leakage_block = {
            "total_known_pii_occurrences": self.total_gt_occurrences,
            "residual_leaks_after_sanitisation": self.residual_leaks,
            "leakage_rate": round(rate, 4),
            "leakage_rate_pct": round(100 * rate, 2),
        }

        return {
            "detector": self.detector_name,
            "transform": self.transform_method,
            "notes_evaluated": self.notes,
            "detection": {
                "overall": overall_block,
                "per_entity": per_entity_block,
            },
            "leakage": leakage_block,
        }
def evaluate(
    records: list[NoteRecord],
    detector: Detector,
    transform_method: str = REDACTION,
) -> EvalResult:
    """Score a detector on detection quality and post-sanitisation leakage.

    Records with empty text are skipped. For every other record:

    * Detection: each ground-truth span counts as a true positive if it
      overlaps a detection that no earlier ground-truth span has claimed,
      otherwise as a false negative. Entity types are not compared when
      matching. Detections left unclaimed count as false positives under
      their own entity type.
    * Leakage: the note is transformed with ``apply_transform`` using a fresh
      ``PseudonymVault``. A ground-truth span counts as leaked when its text,
      or any variant of it at least two characters long, is still found
      anywhere in the sanitised output.

    Args:
        records: Notes to evaluate.
        detector: Object whose ``detect(text)`` yields spans; its ``name``
            attribute, when present, is recorded in the result.
        transform_method: Transform passed through to ``apply_transform``.

    Returns:
        The accumulated ``EvalResult``.
    """
    res = EvalResult(
        transform_method=transform_method,
        detector_name=getattr(detector, "name", "?"),
    )

    for rec in records:
        if not rec.text:
            continue
        res.notes += 1

        truth_spans = ground_truth_spans(rec)
        detected = detector.detect(rec.text)

        # ---- detection precision / recall (overlap-based) ----
        claimed: set[int] = set()
        for truth in truth_spans:
            bucket = res.per_entity[truth.entity_type]
            for idx, det in enumerate(detected):
                if idx not in claimed and _overlaps(truth, det):
                    # First unclaimed overlapping detection wins.
                    claimed.add(idx)
                    bucket.tp += 1
                    res.overall.tp += 1
                    break
            else:
                bucket.fn += 1
                res.overall.fn += 1

        for idx, det in enumerate(detected):
            if idx in claimed:
                continue
            res.per_entity[det.entity_type].fp += 1
            res.overall.fp += 1

        # ---- residual leakage after sanitisation ----
        vault = PseudonymVault()
        sanitised, _ = apply_transform(
            rec.text, detected, transform_method, vault, rec.person_id
        )
        res.total_gt_occurrences += len(truth_spans)

        # a known value leaks if any of its surface variants survives in output
        for truth in truth_spans:
            survives = any(
                len(variant) >= 2 and _find_all(sanitised, variant)
                for variant in value_variants(truth.text, truth.entity_type)
            )
            if survives:
                res.residual_leaks += 1

    return res
|