File size: 7,812 Bytes
abfd704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84981a4
abfd704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Evaluation harness — NoteGuard's 'reliable' pillar.

Because the dataset's PII lives in structured tables, every note has ground-truth
identifiers. We measure two things Presidio alone never reports:

  1. Detection quality  : per-entity precision / recall / F1 against known PII.
  2. Residual leakage   : after sanitisation, how many KNOWN identifiers still
                          appear in the output text. This is the headline number —
                          an honest, measurable re-identification risk.

Caveat we state openly: precision is measured against *structured* PII only.
A note may contain PII not present in the tables (e.g. a clinician's name); a
correct detection of it counts here as a false positive, so reported precision is
a conservative lower bound. Recall and leakage are unaffected.
"""
from __future__ import annotations

import re
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime

from .data import NoteRecord
from .detect import Detector
from .recognisers import Span
from .transform import REDACTION, PseudonymVault, apply_transform

_DATE_FORMATS = ["%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%d/%m/%y", "%d-%m-%y", "%d %b %Y", "%d %B %Y"]


def _date_variants(value: str) -> list[str]:
    """Return the surface forms a known date may take in free text.

    The stripped value is parsed with the first entry of ``_DATE_FORMATS``
    that accepts it and is then re-rendered in every one of those formats.

    Args:
        value: Date string taken from the structured data.

    Returns:
        The stripped original followed by each rendering, in
        ``_DATE_FORMATS`` order and without duplicates. If no format parses
        the value, a one-element list holding ``value`` unchanged.
    """
    raw = value.strip()
    for fmt in _DATE_FORMATS:
        try:
            parsed = datetime.strptime(raw, fmt)
            rendered = [parsed.strftime(out_fmt) for out_fmt in _DATE_FORMATS]
        except ValueError:
            continue
        # strptime accepts non-padded input such as "1/2/2020" but strftime
        # always pads, so the original spelling must be kept explicitly or it
        # can never be found in a note. dict.fromkeys de-duplicates while
        # keeping a stable order (a set would be ordered by randomised hashes).
        return list(dict.fromkeys([raw, *rendered]))
    return [value]


def value_variants(value: str, entity_type: str) -> list[str]:
    """Surface forms of a known PII value as it might appear in free text.

    Args:
        value: The known value; surrounding whitespace is ignored.
        entity_type: Entity label that selects the expansion rule.
            ``"PERSON"`` adds the last and first whitespace-separated tokens
            of a multi-token name. ``"UK_NHS"`` adds the digits-only form
            and, for exactly ten digits, the 3-3-4 groupings separated by
            spaces and by hyphens. ``"DATE_TIME"`` delegates to
            ``_date_variants``. Any other label yields the value alone.

    Returns:
        Distinct, non-empty variants in a deterministic order with the
        stripped value first; an empty list when the value is blank.
    """
    value = value.strip()
    if not value:
        return []

    if entity_type == "PERSON":
        tokens = value.split()
        forms = [value]
        if len(tokens) > 1:
            forms += [tokens[-1], tokens[0]]  # surname alone, forename alone
    elif entity_type == "UK_NHS":
        digits = "".join(ch for ch in value if ch.isdigit())
        forms = [value, digits]
        if len(digits) == 10:
            head, mid, tail = digits[:3], digits[3:6], digits[6:]
            forms += [f"{head} {mid} {tail}", f"{head}-{mid}-{tail}"]
    elif entity_type == "DATE_TIME":
        forms = _date_variants(value)
    else:
        forms = [value]

    # Ordered de-duplication keeps results reproducible between runs; the
    # truthiness filter drops the empty digits-only form of a value that
    # contains no digits at all.
    return [form for form in dict.fromkeys(forms) if form]


def _find_all(haystack: str, needle: str) -> list[tuple[int, int]]:
    """Case-insensitive, word-boundary-aware occurrences of needle in haystack.

    A hit only counts when the character on each side of it (if any) is not
    alphanumeric, so "ann" is found in "Ann." but not in "Annie" or "joann".
    Overlapping hits are all reported.

    Args:
        haystack: Text to search.
        needle: Literal text to look for; an empty needle yields no hits.

    Returns:
        ``(start, end)`` offsets into ``haystack`` itself, in ascending order
        of ``start``.
    """
    if not needle:
        return []
    # Search the original string rather than a lowered copy: str.lower() can
    # change the length of the text (e.g. "İ" becomes two code points), which
    # would shift every later offset away from the caller's text.
    # [^\W_] is "word character except underscore", i.e. an alphanumeric.
    # Wrapping the needle in a lookahead makes each match zero-width so that
    # finditer tries every position and overlapping hits are not skipped.
    pattern = re.compile(
        r"(?<![^\W_])(?=(" + re.escape(needle) + r")(?![^\W_]))",
        re.IGNORECASE,
    )
    return [match.span(1) for match in pattern.finditer(haystack)]


def ground_truth_spans(record: NoteRecord) -> list[Span]:
    """Find where the record's known PII values occur in its note text.

    Each ground-truth entry is expanded with ``value_variants``; variants
    shorter than two characters are ignored. Every occurrence reported by
    ``_find_all`` becomes a ``Span`` carrying the entry's entity type and the
    exact slice of the note it covers. The collected spans are passed through
    ``_dedupe`` before being returned.
    """
    note = record.text
    found: list[Span] = []
    for truth in record.ground_truth:
        label = truth.entity_type
        usable = (form for form in value_variants(truth.text, label) if len(form) >= 2)
        for form in usable:
            found.extend(
                Span(begin, end, label, note[begin:end])
                for begin, end in _find_all(note, form)
            )
    return _dedupe(found)


def _dedupe(spans: list[Span]) -> list[Span]:
    """Drop spans that lie entirely inside a span that has already been kept.

    Spans are visited by ascending start and, for equal starts, longest
    first, so a wider span is always considered before the narrower ones it
    contains. Exact duplicates collapse to the first one seen; spans that
    only partially overlap are all kept.
    """
    ordered = sorted(spans, key=lambda sp: (sp.start, sp.start - sp.end))
    kept: list[Span] = []
    for candidate in ordered:
        covered = False
        for prior in kept:
            if prior.start <= candidate.start and candidate.end <= prior.end:
                covered = True
                break
        if not covered:
            kept.append(candidate)
    return kept


def _overlaps(a: Span, b: Span) -> bool:
    """Return True when the half-open ranges of ``a`` and ``b`` intersect."""
    # Two ranges are disjoint exactly when one ends at or before the other starts.
    return not (a.end <= b.start or b.end <= a.start)


@dataclass
class Counter:
    """Tally of true positives, false positives and false negatives.

    Each derived metric is 0.0 when its denominator is zero.
    """

    tp: int = 0
    fp: int = 0
    fn: int = 0

    @property
    def precision(self) -> float:
        """tp / (tp + fp)."""
        predicted = self.tp + self.fp
        if not predicted:
            return 0.0
        return self.tp / predicted

    @property
    def recall(self) -> float:
        """tp / (tp + fn)."""
        actual = self.tp + self.fn
        if not actual:
            return 0.0
        return self.tp / actual

    @property
    def f1(self) -> float:
        """Harmonic mean of precision and recall."""
        precision, recall = self.precision, self.recall
        total = precision + recall
        if not total:
            return 0.0
        return 2 * precision * recall / total


@dataclass
class EvalResult:
    """Accumulated outcome of one evaluation run.

    Holds the number of notes evaluated, detection tallies (overall and per
    entity type), the ground-truth occurrence and residual-leak counts, and
    the names of the transform method and detector that were used.
    """

    notes: int = 0
    # Missing entity types are created on first access (defaultdict of Counter).
    per_entity: dict[str, Counter] = field(default_factory=lambda: defaultdict(Counter))
    overall: Counter = field(default_factory=Counter)
    total_gt_occurrences: int = 0
    residual_leaks: int = 0
    transform_method: str = REDACTION
    detector_name: str = ""

    @property
    def leakage_rate(self) -> float:
        """residual_leaks / total_gt_occurrences, or 0.0 when there are none."""
        if not self.total_gt_occurrences:
            return 0.0
        return self.residual_leaks / self.total_gt_occurrences

    def to_dict(self) -> dict:
        """Return the result as a nested dict of plain values.

        Ratios are rounded to four decimal places and the leakage percentage
        to two. Per-entity entries are ordered by entity type, and their
        ``support`` is ``tp + fn``.
        """

        def scores(counter: Counter) -> dict:
            # The three ratio metrics shared by the overall and per-entity views.
            return {
                "precision": round(counter.precision, 4),
                "recall": round(counter.recall, 4),
                "f1": round(counter.f1, 4),
            }

        overall = scores(self.overall)
        overall["tp"] = self.overall.tp
        overall["fp"] = self.overall.fp
        overall["fn"] = self.overall.fn

        per_entity = {}
        for entity_type in sorted(self.per_entity):
            counter = self.per_entity[entity_type]
            entry = scores(counter)
            entry["support"] = counter.tp + counter.fn
            per_entity[entity_type] = entry

        rate = self.leakage_rate
        leakage = {
            "total_known_pii_occurrences": self.total_gt_occurrences,
            "residual_leaks_after_sanitisation": self.residual_leaks,
            "leakage_rate": round(rate, 4),
            "leakage_rate_pct": round(100 * rate, 2),
        }

        return {
            "detector": self.detector_name,
            "transform": self.transform_method,
            "notes_evaluated": self.notes,
            "detection": {"overall": overall, "per_entity": per_entity},
            "leakage": leakage,
        }


def evaluate(
    records: list[NoteRecord],
    detector: Detector,
    transform_method: str = REDACTION,
) -> EvalResult:
    """Score a detector, and the sanitised output it leads to, over records.

    Records with empty text are skipped. For every other record:

    * Detection: each ground-truth span is paired with the first not yet
      paired detected span that overlaps it (one-to-one, entity types are not
      compared). A paired ground-truth span is a true positive and an
      unpaired one a false negative, both tallied under the ground-truth
      entity type. Each unpaired detection is a false positive tallied under
      the detection's entity type.
    * Leakage: the note is transformed with ``apply_transform`` using a fresh
      ``PseudonymVault``. A ground-truth span counts as leaked when any
      variant (two characters or longer) of its text is still found anywhere
      in the sanitised output.

    Args:
        records: Notes to evaluate.
        detector: Object whose ``detect(text)`` returns the detected spans;
            its ``name`` attribute, if present, is recorded in the result.
        transform_method: Method name forwarded to ``apply_transform``.

    Returns:
        An ``EvalResult`` holding the accumulated counts.
    """
    result = EvalResult(
        transform_method=transform_method,
        detector_name=getattr(detector, "name", "?"),
    )
    for record in records:
        if not record.text:
            continue
        result.notes += 1
        truth = ground_truth_spans(record)
        found = detector.detect(record.text)

        # ---- detection precision / recall (overlap-based) ----
        claimed: set[int] = set()
        for truth_span in truth:
            partner = None
            for idx, det_span in enumerate(found):
                if idx in claimed:
                    continue
                if _overlaps(truth_span, det_span):
                    partner = idx
                    break
            tally = result.per_entity[truth_span.entity_type]
            if partner is None:
                tally.fn += 1
                result.overall.fn += 1
            else:
                claimed.add(partner)
                tally.tp += 1
                result.overall.tp += 1
        for idx, det_span in enumerate(found):
            if idx in claimed:
                continue
            result.per_entity[det_span.entity_type].fp += 1
            result.overall.fp += 1

        # ---- residual leakage after sanitisation ----
        vault = PseudonymVault()
        sanitised, _ = apply_transform(
            record.text, found, transform_method, vault, record.person_id
        )
        result.total_gt_occurrences += len(truth)
        # a known value leaks if any of its surface variants survives in output
        result.residual_leaks += sum(
            1
            for truth_span in truth
            if any(
                len(form) >= 2 and _find_all(sanitised, form)
                for form in value_variants(truth_span.text, truth_span.entity_type)
            )
        )
    return result