Spaces:
Running
Running
| """NoteGuard FastAPI backend — PHI-safe REST endpoint for the LangGraph agent. | |
| Exposes: | |
| GET / -> index.html (clinician web UI) | |
| GET /health -> {"status": "ok"} | |
| GET /samples -> paginated list of synthetic notes (requires data/ dir) | |
| GET /sample/random -> one random synthetic note | |
| GET /sample/{id} -> full note by clinical_note_id | |
| POST /summarise -> {clinician_answer, identifiers_removed, residual_risk, | |
| deidentified_excerpt, ok} | |
| POST /process -> {clinician_note, ai_note, identifiers, discharge_summary, metrics} | |
| The assert_clean() guarantee is preserved: the graph raises ValueError if any | |
| identifier survives de-identification, which surfaces here as HTTP 422. | |
| Run: uvicorn app.api:app --reload --port 8000 | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import random | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from langchain_core.messages import HumanMessage | |
| from pydantic import BaseModel | |
| from src.deid import NoteGuard, load_known_from_csv | |
| STATIC_DIR = Path(__file__).parent / "static" | |
| _DATA_DIR = Path(__file__).parent.parent / "data" | |
| app = FastAPI(title="NoteGuard API", version="0.1.0") | |
| # --------------------------------------------------------------------------- | |
| # Dataset — loaded once at startup; degrades gracefully when data/ is absent | |
| # --------------------------------------------------------------------------- | |
| _NOTES: list[dict] = [] | |
| _DEFAULT_KNOWN: dict | None = None | |
| _PATIENT_NAMES: dict[str, str] = {} # person_id -> full_name for {{PATIENT}} resolution | |
| try: | |
| _patients_csv = str(_DATA_DIR / "patients.csv") | |
| _admissions_csv = str(_DATA_DIR / "admissions.csv") | |
| _DEFAULT_KNOWN = load_known_from_csv(_patients_csv, _admissions_csv) | |
| # Build person_id → name lookup for {{PATIENT}} resolution | |
| with open(_patients_csv, newline="", encoding="utf-8-sig") as _pf: | |
| for _row in csv.DictReader(_pf): | |
| _pid = (_row.get("person_id") or "").strip() | |
| _name = ( | |
| _row.get("full_name") | |
| or _row.get("patient_name") | |
| or f"{_row.get('first_name', '').strip()} {_row.get('surname', '').strip()}".strip() | |
| ) | |
| if _pid and _name: | |
| _PATIENT_NAMES[_pid] = _name | |
| with open(_DATA_DIR / "synthetic_clinical_notes.csv", newline="", encoding="utf-8-sig") as _f: | |
| for _row in csv.DictReader(_f): | |
| _text = NoteGuard._fix_mojibake(_row["clean_note_text"]) | |
| _NOTES.append( | |
| { | |
| "clinical_note_id": _row["clinical_note_id"], | |
| "person_id": _row["person_id"], | |
| "note_type": _row.get("note_type", ""), | |
| "note_subject": _row.get("note_subject", ""), | |
| "excerpt": _text[:120].strip(), | |
| "note_text": _text, | |
| } | |
| ) | |
| except Exception: | |
| pass # data/ not present — /samples returns empty, /process still works | |
| # --------------------------------------------------------------------------- | |
| # Per-vault graph cache — key is a hashable snapshot of the known-identifier dict. | |
| # --------------------------------------------------------------------------- | |
| _graph_cache: dict = {} | |
| def _vault_key(known: dict | None) -> tuple | None: | |
| if not known: | |
| return None | |
| return tuple(sorted((k, tuple(sorted(v))) for k, v in known.items())) | |
| def _get_graph(known: dict | None): | |
| """Return a compiled NoteGuard graph, building it once per distinct vault.""" | |
| key = _vault_key(known) | |
| if key not in _graph_cache: | |
| from agent.graph import build_graph | |
| _graph_cache[key] = build_graph(known=known) | |
| return _graph_cache[key] | |
| # --------------------------------------------------------------------------- | |
| # Models | |
| # --------------------------------------------------------------------------- | |
| class SummariseRequest(BaseModel): | |
| note: str | |
| question: str = "Draft an NHS eDischarge summary." | |
| known: dict | None = None | |
| class SummariseResponse(BaseModel): | |
| clinician_answer: str | |
| identifiers_removed: int | |
| residual_risk: float | |
| deidentified_excerpt: str | |
| ok: bool | |
| class ProcessRequest(BaseModel): | |
| note: str | |
| question: str = "Draft an NHS eDischarge summary." | |
| known: dict | None = None | |
| person_id: str | None = None # when set, {{PATIENT}} resolves to the real name | |
| class ProcessResponse(BaseModel): | |
| clinician_note: str | |
| ai_note: str | |
| identifiers: list[str] | |
| discharge_summary: str | |
| metrics: dict | |
| class SampleItem(BaseModel): | |
| clinical_note_id: str | |
| person_id: str | |
| note_type: str | |
| excerpt: str | |
| class SamplesResponse(BaseModel): | |
| total: int | |
| items: list[SampleItem] | |
| class SampleDetail(BaseModel): | |
| clinical_note_id: str | |
| person_id: str | |
| note_type: str | |
| note_subject: str | |
| note_text: str | |
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
| def index(): | |
| return FileResponse(STATIC_DIR / "index.html") | |
| def health(): | |
| """Liveness probe — no API keys required.""" | |
| return {"status": "ok", "notes_loaded": len(_NOTES)} | |
| def samples( | |
| limit: int = Query(50, ge=1, le=200), | |
| offset: int = Query(0, ge=0), | |
| q: str = Query(""), | |
| note_type: str = Query(""), | |
| ): | |
| """Paginated list of synthetic notes with optional text/type filter.""" | |
| hits = _NOTES | |
| if note_type: | |
| hits = [n for n in hits if n["note_type"] == note_type] | |
| if q: | |
| ql = q.lower() | |
| hits = [n for n in hits if ql in n["note_text"].lower() or ql in n["note_subject"].lower()] | |
| page = hits[offset : offset + limit] | |
| return SamplesResponse( | |
| total=len(hits), | |
| items=[ | |
| SampleItem( | |
| clinical_note_id=n["clinical_note_id"], | |
| person_id=n["person_id"], | |
| note_type=n["note_type"], | |
| excerpt=n["excerpt"], | |
| ) | |
| for n in page | |
| ], | |
| ) | |
| def sample_random(): | |
| """Return one random synthetic note.""" | |
| if not _NOTES: | |
| raise HTTPException(status_code=404, detail="No notes loaded — run src/fetch_dataset.py first.") | |
| note = random.choice(_NOTES) | |
| return SampleDetail(**{k: note[k] for k in SampleDetail.model_fields}) | |
| def sample_by_id(clinical_note_id: str): | |
| """Return a single synthetic note by its clinical_note_id.""" | |
| for note in _NOTES: | |
| if note["clinical_note_id"] == clinical_note_id: | |
| return SampleDetail(**{k: note[k] for k in SampleDetail.model_fields}) | |
| raise HTTPException(status_code=404, detail=f"Note {clinical_note_id!r} not found.") | |
| def summarise(req: SummariseRequest): | |
| """Run the NoteGuard agent and return a PHI-safe discharge summary. | |
| Raises: | |
| HTTPException 422: assert_clean() detected surviving PHI. | |
| HTTPException 500: unexpected agent error. | |
| """ | |
| known = req.known if req.known is not None else _DEFAULT_KNOWN | |
| try: | |
| g = _get_graph(known) | |
| state = g.invoke({"messages": [HumanMessage(content=req.note + "\n\n" + req.question)]}) | |
| except ValueError as exc: | |
| raise HTTPException(status_code=422, detail=str(exc)) from exc | |
| except Exception as exc: | |
| raise HTTPException(status_code=500, detail=str(exc)) from exc | |
| residual = state.get("residual_count", 0) | |
| return SummariseResponse( | |
| clinician_answer=state.get("clinician_answer", ""), | |
| identifiers_removed=len(state.get("forward", {})), | |
| residual_risk=0.0 if residual == 0 else 1.0, | |
| deidentified_excerpt=(state.get("deid_text") or "")[:400], | |
| ok=residual == 0, | |
| ) | |
| def process(req: ProcessRequest): | |
| """Run NoteGuard and return rich output for the clinician UI. | |
| When req.known is omitted, uses the pre-built vault from data/patients.csv | |
| so residual-leakage is measured against ground truth identifiers. | |
| """ | |
| known = req.known if req.known is not None else _DEFAULT_KNOWN | |
| person_name = _PATIENT_NAMES.get(req.person_id, "Patient") if req.person_id else "Patient" | |
| try: | |
| g = _get_graph(known) | |
| state = g.invoke( | |
| { | |
| "messages": [HumanMessage(content=req.note + "\n\n" + req.question)], | |
| "person_name": person_name, | |
| } | |
| ) | |
| except ValueError as exc: | |
| raise HTTPException(status_code=422, detail=str(exc)) from exc | |
| except Exception as exc: | |
| raise HTTPException(status_code=500, detail=str(exc)) from exc | |
| forward = state.get("forward") or {} | |
| residual = state.get("residual_count", 0) | |
| leaked = state.get("leaked_tokens") or [] | |
| retrieved = state.get("retrieved_context") or [] | |
| faith = state.get("faithfulness_score", 0.0) | |
| has_leak = residual > 0 or bool(leaked) | |
| return ProcessResponse( | |
| clinician_note=req.note, | |
| ai_note=state.get("deid_text", ""), | |
| identifiers=list(forward.keys()), | |
| discharge_summary=state.get("clinician_answer", ""), | |
| metrics={ | |
| "identifiers_removed": len(forward), | |
| "residual_risk": 0.0 | |
| if not has_leak | |
| else min(1.0, (residual + len(leaked)) / max(len(forward), 1)), | |
| "grounded_sources": len(state.get("sources") or []), | |
| "faithfulness": faith if retrieved else None, | |
| "leaked_tokens": leaked, | |
| }, | |
| ) | |
| if STATIC_DIR.exists(): | |
| app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") | |