File size: 9,606 Bytes
eb83689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
"""NoteGuard FastAPI backend — PHI-safe REST endpoint for the LangGraph agent.

Exposes:
  GET  /              -> index.html (clinician web UI)
  GET  /health        -> {"status": "ok", "notes_loaded": <int>}
  GET  /samples       -> paginated list of synthetic notes (requires data/ dir)
  GET  /sample/random -> one random synthetic note
  GET  /sample/{id}   -> full note by clinical_note_id
  POST /summarise     -> {clinician_answer, identifiers_removed, residual_risk,
                          deidentified_excerpt, ok}
  POST /process       -> {clinician_note, ai_note, identifiers, discharge_summary, metrics}

The assert_clean() guarantee is preserved: the graph raises ValueError if any
identifier survives de-identification, which surfaces here as HTTP 422.

Run:  uvicorn app.api:app --reload --port 8000
"""

from __future__ import annotations

import csv
import logging
import random
import threading
from collections import OrderedDict
from pathlib import Path

from dotenv import load_dotenv

load_dotenv(override=True)

from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from src.deid import NoteGuard, load_known_from_csv

# Directory that contains this module; the UI assets sit beside it and the
# synthetic dataset lives one level up.
_APP_ROOT = Path(__file__).parent
STATIC_DIR = _APP_ROOT / "static"
_DATA_DIR = _APP_ROOT.parent / "data"

app = FastAPI(title="NoteGuard API", version="1.2.4")

# ---------------------------------------------------------------------------
# Dataset — loaded once at startup; degrades gracefully when data/ is absent
# ---------------------------------------------------------------------------

_logger = logging.getLogger(__name__)

_NOTES: list[dict] = []
_DEFAULT_KNOWN: dict | None = None


def _load_dataset(data_dir: Path) -> tuple[dict | None, list[dict]]:
    """Load the default identifier vault and the synthetic notes from *data_dir*.

    Returns:
        ``(known, notes)``. *known* is whatever ``load_known_from_csv`` builds
        from ``patients.csv`` and ``admissions.csv``. *notes* holds one dict per
        row of ``synthetic_clinical_notes.csv``, with the note text passed
        through ``NoteGuard._fix_mojibake`` and a 120-character excerpt.

    Raises:
        FileNotFoundError: from ``open()`` when the notes CSV is missing.
        KeyError: when a row lacks a required column
            (``clinical_note_id``, ``person_id`` or ``clean_note_text``).
        Whatever ``load_known_from_csv`` raises for a missing or invalid file
        (not visible here). Nothing is returned partially on failure.
    """
    known = load_known_from_csv(
        str(data_dir / "patients.csv"),
        str(data_dir / "admissions.csv"),
    )
    notes: list[dict] = []
    # utf-8-sig strips a leading BOM if the CSV was exported with one.
    with open(data_dir / "synthetic_clinical_notes.csv", newline="", encoding="utf-8-sig") as fh:
        for row in csv.DictReader(fh):
            text = NoteGuard._fix_mojibake(row["clean_note_text"])
            notes.append(
                {
                    "clinical_note_id": row["clinical_note_id"],
                    "person_id": row["person_id"],
                    "note_type": row.get("note_type", ""),
                    "note_subject": row.get("note_subject", ""),
                    "excerpt": text[:120].strip(),
                    "note_text": text,
                }
            )
    return known, notes


try:
    # Assign both globals together so a failure part-way through never leaves a
    # vault without notes or a truncated note list.
    _DEFAULT_KNOWN, _NOTES = _load_dataset(_DATA_DIR)
except FileNotFoundError as exc:
    # Expected when data/ is absent: /samples returns empty, /process still works.
    _logger.info("Dataset not loaded (%s); sample endpoints will be empty.", exc)
except Exception:
    # Anything else (bad CSV schema, decode error, ...) is a real problem. Keep
    # the API up, but make the failure visible instead of silently ignoring it.
    _logger.exception("Failed to load dataset from %s; sample endpoints will be empty.", _DATA_DIR)

# ---------------------------------------------------------------------------
# Per-vault graph cache — key is a hashable snapshot of the known-identifier dict.
# ---------------------------------------------------------------------------

_graph_cache: dict = {}


def _vault_key(known: dict | None) -> tuple | None:
    if not known:
        return None
    return tuple(sorted((k, tuple(sorted(v))) for k, v in known.items()))


def _get_graph(known: dict | None):
    """Return a compiled NoteGuard graph, building it once per distinct vault."""
    key = _vault_key(known)
    if key not in _graph_cache:
        from agent.graph import build_graph

        _graph_cache[key] = build_graph(known=known)
    return _graph_cache[key]


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------


class SummariseRequest(BaseModel):
    """Request body for POST /summarise."""

    # Raw clinical note; sent to the agent followed by a blank line and `question`.
    note: str
    # Instruction appended after the note in the same agent message.
    question: str = "Draft an NHS eDischarge summary."
    # Identifier vault for the graph; None -> the vault loaded from data/ at startup.
    # NOTE(review): the expected key/value shape is defined by src.deid — not visible here.
    known: dict | None = None


class SummariseResponse(BaseModel):
    """Response body for POST /summarise."""

    # The agent's answer, read from the graph state key "clinician_answer".
    clinician_answer: str
    # Number of entries in the graph state's "forward" mapping.
    identifiers_removed: int
    # Binary flag, not a probability: 0.0 when `ok`, otherwise 1.0.
    residual_risk: float
    # First 400 characters of the de-identified text ("deid_text").
    deidentified_excerpt: str
    # True iff the state reports no residual PII and no leaked surrogate tokens.
    ok: bool


class ProcessRequest(BaseModel):
    """Request body for POST /process (the clinician web UI)."""

    # Raw clinical note; sent to the agent followed by a blank line and `question`.
    note: str
    # Instruction appended after the note in the same agent message.
    question: str = "Draft an NHS eDischarge summary."
    # Identifier vault for the graph; None -> the vault loaded from data/ at startup.
    # NOTE(review): the expected key/value shape is defined by src.deid — not visible here.
    known: dict | None = None
    person_id: str | None = None  # accepted for UI compatibility; unused (patient is never named)


class ProcessResponse(BaseModel):
    """Response body for POST /process."""

    # The request's original note, echoed back unchanged.
    clinician_note: str
    # De-identified note text from the graph state ("deid_text").
    ai_note: str
    # Keys of the graph state's "forward" mapping.
    identifiers: list[str]
    # The agent's answer ("clinician_answer").
    discharge_summary: str
    # De-identification metrics: deid_ok, identifiers_removed, residual_pii,
    # residual_pii_count, reversible, leaked_tokens.
    metrics: dict


class SampleItem(BaseModel):
    """One row of the GET /samples listing (no full note text)."""

    clinical_note_id: str
    person_id: str
    note_type: str
    # First 120 characters of the note text, whitespace-stripped.
    excerpt: str


class SamplesResponse(BaseModel):
    """Paginated result of GET /samples."""

    # Number of notes matching the filters, before `limit`/`offset` are applied.
    total: int
    # The requested page of matches.
    items: list[SampleItem]


class SampleDetail(BaseModel):
    """A full synthetic note, returned by GET /sample/random and /sample/{id}."""

    # Field names double as keys copied out of the loaded note dicts, so a
    # rename here must be matched in the dataset loader.
    clinical_note_id: str
    person_id: str
    note_type: str
    note_subject: str
    note_text: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/")
def index():
    """Serve the clinician web UI (static/index.html).

    Raises:
        HTTPException 404: the UI page is not installed next to this module.
    """
    page = STATIC_DIR / "index.html"
    if not page.is_file():
        # FileResponse only discovers a missing file while sending, where it
        # raises RuntimeError and the client sees an opaque HTTP 500.
        raise HTTPException(status_code=404, detail="Web UI not installed (static/index.html missing).")
    return FileResponse(page)


@app.get("/health")
def health():
    """Liveness probe — no API keys required.

    Also reports how many synthetic notes were loaded at startup.
    """
    payload = {"status": "ok"}
    payload["notes_loaded"] = len(_NOTES)
    return payload


@app.get("/samples", response_model=SamplesResponse)
def samples(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    q: str = Query(""),
    note_type: str = Query(""),
):
    """List synthetic notes one page at a time.

    Both filters are optional and combine. ``note_type`` must match exactly.
    ``q`` is a case-insensitive substring searched in the note text and subject.
    ``total`` counts every match, not just the returned page.
    """
    needle = q.lower()

    def _matches(note: dict) -> bool:
        if note_type and note["note_type"] != note_type:
            return False
        if not needle:
            return True
        return needle in note["note_text"].lower() or needle in note["note_subject"].lower()

    matched = [note for note in _NOTES if _matches(note)]
    window = matched[offset : offset + limit]

    summaries = []
    for note in window:
        summaries.append(
            SampleItem(
                clinical_note_id=note["clinical_note_id"],
                person_id=note["person_id"],
                note_type=note["note_type"],
                excerpt=note["excerpt"],
            )
        )
    return SamplesResponse(total=len(matched), items=summaries)


@app.get("/sample/random", response_model=SampleDetail)
def sample_random():
    """Return one synthetic note picked at random; 404 when none are loaded."""
    try:
        picked = random.choice(_NOTES)
    except IndexError:
        # random.choice raises IndexError on an empty sequence.
        raise HTTPException(status_code=404, detail="No notes loaded — run src/fetch_dataset.py first.") from None
    detail_fields = SampleDetail.model_fields
    return SampleDetail(**{name: picked[name] for name in detail_fields})


@app.get("/sample/{clinical_note_id}", response_model=SampleDetail)
def sample_by_id(clinical_note_id: str):
    """Look up one synthetic note by ``clinical_note_id``; 404 if it is unknown."""
    found = next((n for n in _NOTES if n["clinical_note_id"] == clinical_note_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Note {clinical_note_id!r} not found.")
    return SampleDetail(**{field: found[field] for field in SampleDetail.model_fields})


@app.post("/summarise", response_model=SummariseResponse)
def summarise(req: SummariseRequest):
    """Run the NoteGuard agent and return a PHI-safe discharge summary.

    The note and question are sent to the agent as one message. When
    ``req.known`` is omitted, the vault loaded from data/ at startup is used
    (``None`` if that load did not happen).

    Raises:
        HTTPException 422: the agent raised ``ValueError``; this is how
            assert_clean() reports surviving PHI.
        HTTPException 500: unexpected agent error.
    """
    known = req.known if req.known is not None else _DEFAULT_KNOWN
    try:
        g = _get_graph(known)
        state = g.invoke({"messages": [HumanMessage(content=req.note + "\n\n" + req.question)]})
    except ValueError as exc:
        # NOTE(review): the exception text is returned to the client verbatim;
        # confirm assert_clean() never quotes the surviving identifier in it.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Treat missing and None-valued state keys alike (as /process does), so
    # len() never sees None and the str fields always validate.
    forward = state.get("forward") or {}
    residual_pii = state.get("residual_pii") or []
    leaked = state.get("leaked_tokens") or []
    # De-id is correct iff nothing leaked to the model AND every surrogate reverses.
    ok = not residual_pii and not leaked
    return SummariseResponse(
        clinician_answer=state.get("clinician_answer") or "",
        identifiers_removed=len(forward),
        residual_risk=0.0 if ok else 1.0,
        deidentified_excerpt=(state.get("deid_text") or "")[:400],
        ok=ok,
    )


@app.post("/process", response_model=ProcessResponse)
def process(req: ProcessRequest):
    """Run NoteGuard and return rich output for the clinician UI.

    When req.known is omitted, uses the vault built at startup from
    data/patients.csv and data/admissions.csv, so residual leakage is measured
    against ground-truth identifiers. ``req.person_id`` is accepted but ignored.

    Raises:
        HTTPException 422: the agent raised ``ValueError``; this is how
            assert_clean() reports surviving PHI.
        HTTPException 500: unexpected agent error.
    """
    known = req.known if req.known is not None else _DEFAULT_KNOWN
    try:
        g = _get_graph(known)
        state = g.invoke({"messages": [HumanMessage(content=req.note + "\n\n" + req.question)]})
    except ValueError as exc:
        # NOTE(review): the exception text is returned to the client verbatim;
        # confirm assert_clean() never quotes the surviving identifier in it.
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    forward = state.get("forward") or {}
    leaked = state.get("leaked_tokens") or []
    residual_pii = state.get("residual_pii") or []
    reversible = not leaked
    deid_ok = not residual_pii and reversible

    return ProcessResponse(
        clinician_note=req.note,
        # `or ""` (not a .get default): a key present with None would otherwise
        # fail response validation and surface as HTTP 500.
        ai_note=state.get("deid_text") or "",
        identifiers=list(forward.keys()),
        discharge_summary=state.get("clinician_answer") or "",
        metrics={
            # Every metric reports whether reversible pseudonymisation was done correctly.
            "deid_ok": deid_ok,  # overall verdict: nothing leaked AND fully reversible
            "identifiers_removed": len(forward),  # PII spans pseudonymised this turn
            "residual_pii": residual_pii,  # [{type, text}] PII the model still saw
            "residual_pii_count": len(residual_pii),
            "reversible": reversible,  # every surrogate restores to a real value
            "leaked_tokens": leaked,  # orphaned/unresolved surrogate tokens
        },
    )


# Serve bundled assets only when static/ ships with the app. Use is_dir()
# rather than exists(): StaticFiles checks at construction that its path is a
# directory and raises otherwise, which would abort module import.
if STATIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")