File size: 4,790 Bytes
89c6379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
main.py — RareDx FastAPI application.

Run with:
  uvicorn backend.api.main:app --reload --port 8080

Or from the project root:
  python -m uvicorn backend.api.main:app --reload --port 8080

Endpoints:
  POST /diagnose          — clinical note → differential diagnosis
  GET  /health            — liveness check
  GET  /hpo/search?q=...  — debug: find HPO terms by keyword
"""

import sys
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# Ensure scripts/ importable
sys.path.insert(0, str(Path(__file__).parents[1] / "scripts"))

from .models import DiagnoseRequest, DiagnoseResponse, Candidate, HPOMatch
from .pipeline import DiagnosisPipeline

# Process-wide DiagnosisPipeline instance. It stays None until ``lifespan``
# builds it at startup (model loading takes ~3s); the endpoint handlers read
# it and answer 503 while it is still None.
pipeline: DiagnosisPipeline | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Construct the shared ``DiagnosisPipeline`` before requests are served.

    Passed to ``FastAPI(lifespan=...)``: the code before ``yield`` runs on
    startup, the code after it on shutdown. The freshly built pipeline is
    published through the module-level ``pipeline`` global.
    """
    global pipeline
    print("Starting up RareDx API...")
    loaded = DiagnosisPipeline()
    pipeline = loaded
    print("API ready.")
    yield
    print("Shutting down.")


# Human-readable summary shown in the generated OpenAPI docs.
_API_DESCRIPTION = (
    "Multi-agent clinical AI for rare disease diagnosis. "
    "Combines knowledge graph (Orphanet/HPO) with BioLORD-2023 "
    "semantic embeddings to generate differential diagnoses from "
    "plain-text clinical notes."
)

# The ASGI application; ``lifespan`` loads the diagnosis pipeline on startup.
app = FastAPI(
    title="RareDx API",
    version="0.2.0",
    description=_API_DESCRIPTION,
    lifespan=lifespan,
)

# NOTE(review): wildcard origins/methods/headers let any website call this
# API from a browser. That is convenient for local development; consider
# restricting ``allow_origins`` for real deployments.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/health")
def health():
    """Liveness check: report pipeline readiness and the backends in use.

    Always answers with ``status == "ok"``; the backend fields are ``None``
    while the pipeline has not been loaded.
    """
    if pipeline:
        graph_backend = pipeline.graph_backend
        chroma_backend = pipeline.chroma_backend
    else:
        graph_backend = chroma_backend = None
    return {
        "status": "ok",
        "pipeline_ready": pipeline is not None,
        "graph_backend": graph_backend,
        "chroma_backend": chroma_backend,
    }


@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(request: DiagnoseRequest):
    """Turn a clinical note into a differential diagnosis.

    Delegates to ``pipeline.diagnose`` and reshapes its result dict into a
    ``DiagnoseResponse``. ``top_diagnosis`` is the first passed candidate,
    falling back to the first overall candidate, or ``None`` when there are
    no candidates at all.

    Raises:
        HTTPException: 503 if the pipeline has not been initialised yet.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    raw = pipeline.diagnose(
        note=request.note,
        top_n=request.top_n,
        threshold=request.threshold,
    )

    def as_candidate(entry: dict) -> Candidate:
        # Required keys are indexed directly (KeyError if absent); optional
        # keys fall back to None / empty / neutral defaults.
        return Candidate(
            rank=entry["rank"],
            orpha_code=entry["orpha_code"],
            name=entry["name"],
            # Falsy definitions (e.g. "") are normalised to None.
            definition=entry.get("definition") or None,
            rrf_score=entry["rrf_score"],
            graph_rank=entry.get("graph_rank"),
            chroma_rank=entry.get("chroma_rank"),
            graph_matches=entry.get("graph_matches"),
            chroma_sim=entry.get("chroma_sim"),
            matched_hpo=entry.get("matched_hpo", []),
            hallucination_flag=entry.get("hallucination_flag", False),
            flag_reason=entry.get("flag_reason"),
            evidence_score=entry.get("evidence_score", 0.0),
        )

    converted = {
        key: [as_candidate(entry) for entry in raw[key]]
        for key in ("candidates", "passed_candidates", "flagged_candidates")
    }
    hpo = [HPOMatch(**match) for match in raw["hpo_matches"]]

    # Prefer the first passed candidate; otherwise the first overall one.
    best = next(
        iter(converted["passed_candidates"] or converted["candidates"]), None
    )

    return DiagnoseResponse(
        note=raw["note"],
        phrases_extracted=raw["phrases_extracted"],
        hpo_matches=hpo,
        hpo_ids_used=raw["hpo_ids_used"],
        candidates=converted["candidates"],
        passed_candidates=converted["passed_candidates"],
        flagged_candidates=converted["flagged_candidates"],
        top_diagnosis=best,
        graph_backend=raw["graph_backend"],
        chroma_backend=raw["chroma_backend"],
        elapsed_seconds=raw["elapsed_seconds"],
    )


@app.get("/hpo/search")
def hpo_search(q: str, limit: int = 10):
    """Debug endpoint: find HPO terms by keyword in graph store.

    Scans ``pipeline.graph_store.graph`` for nodes whose ``type`` attribute is
    ``"HPOTerm"``. It matches ``q`` as a case-insensitive substring of the
    node's ``term`` attribute and collects hits in graph iteration order.

    Args:
        q: Keyword to search for (an empty string matches every term).
        limit: Maximum number of results to return; ``limit <= 0`` yields
            an empty result list.

    Returns:
        ``{"query": q, "results": [{"hpo_id": ..., "term": ...}, ...]}``.

    Raises:
        HTTPException: 503 if the pipeline has not been initialised yet.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    needle = q.lower()
    results = []
    # Guard up front: the cap is only checked after an append, so without
    # this a zero/negative limit would still return one match.
    if limit > 0:
        for _, attrs in pipeline.graph_store.graph.nodes(data=True):
            if attrs.get("type") != "HPOTerm":
                continue
            # Treat a missing/None term as "" so one malformed node cannot
            # turn the whole request into a 500.
            term = attrs.get("term") or ""
            if needle in term.lower():
                results.append({"hpo_id": attrs["hpo_id"], "term": term})
                if len(results) >= limit:
                    break
    return {"query": q, "results": results}