"""
main.py — RareDx FastAPI application.

Run with:
    uvicorn backend.api.main:app --reload --port 8080

Or from the project root:
    python -m uvicorn backend.api.main:app --reload --port 8080

Endpoints:
    POST /diagnose          — clinical note → differential diagnosis
    GET  /health            — liveness check
    GET  /hpo/search?q=...  — debug: find HPO terms by keyword
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Put the ``scripts`` directory that sits one level above this file's package
# directory at the very front of the module search path, so modules living
# there can be imported by the pipeline code below.
_SCRIPTS_DIR = Path(__file__).parents[1].joinpath("scripts")
sys.path[0:0] = [str(_SCRIPTS_DIR)]
from .models import DiagnoseRequest, DiagnoseResponse, Candidate, HPOMatch
from .pipeline import DiagnosisPipeline
# Pipeline is loaded once at startup (model loading takes ~3s) and shared by
# every request handler. It is ``None`` until ``lifespan`` has run (and again
# after shutdown), which is why the endpoints answer 503 while it is ``None``.
pipeline: DiagnosisPipeline | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared diagnosis pipeline on startup and drop it on shutdown.

    Code before ``yield`` runs once before the application serves requests;
    code after it runs when the application shuts down.

    Args:
        app: The FastAPI application (unused; required by the lifespan
            protocol).
    """
    global pipeline
    print("Starting up RareDx API...")
    pipeline = DiagnosisPipeline()
    print("API ready.")
    try:
        yield
    finally:
        # Clear the global even if the app exits via an exception, so /health
        # stops reporting the pipeline as ready and the loaded models can be
        # garbage-collected (assuming nothing else still references them).
        pipeline = None
        print("Shutting down.")
# Human-readable summary shown in the generated OpenAPI docs.
_API_DESCRIPTION = (
    "Multi-agent clinical AI for rare disease diagnosis. "
    "Combines knowledge graph (Orphanet/HPO) with BioLORD-2023 "
    "semantic embeddings to generate differential diagnoses from "
    "plain-text clinical notes."
)

# The ASGI application; ``lifespan`` builds the shared pipeline at startup.
app = FastAPI(
    title="RareDx API",
    description=_API_DESCRIPTION,
    version="0.2.0",
    lifespan=lifespan,
)

# CORS is wide open: any origin, any HTTP method and any request header are
# accepted. Credentials are not explicitly enabled here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Liveness check.

    Always reports ``status: "ok"``. It also reports whether the diagnosis
    pipeline has been built and, when it has, which graph and Chroma backends
    it is using (``None`` otherwise).
    """
    graph_backend = None
    chroma_backend = None
    if pipeline:
        graph_backend = pipeline.graph_backend
        chroma_backend = pipeline.chroma_backend
    return {
        "status": "ok",
        "pipeline_ready": pipeline is not None,
        "graph_backend": graph_backend,
        "chroma_backend": chroma_backend,
    }
def _to_candidate(raw: dict) -> Candidate:
    """Convert one candidate dict produced by the pipeline into a Candidate.

    ``rank``, ``orpha_code``, ``name`` and ``rrf_score`` are required keys.
    Every other field is optional and falls back to the defaults below.
    An empty/falsy ``definition`` is normalised to ``None``.
    """
    return Candidate(
        rank=raw["rank"],
        orpha_code=raw["orpha_code"],
        name=raw["name"],
        definition=raw.get("definition") or None,
        rrf_score=raw["rrf_score"],
        graph_rank=raw.get("graph_rank"),
        chroma_rank=raw.get("chroma_rank"),
        graph_matches=raw.get("graph_matches"),
        chroma_sim=raw.get("chroma_sim"),
        matched_hpo=raw.get("matched_hpo", []),
        hallucination_flag=raw.get("hallucination_flag", False),
        flag_reason=raw.get("flag_reason"),
        evidence_score=raw.get("evidence_score", 0.0),
    )


@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(request: DiagnoseRequest):
    """Turn a clinical note into a ranked differential diagnosis.

    Runs the shared pipeline on ``request.note`` with the request's ``top_n``
    and ``threshold``, then converts the raw result dict into a
    ``DiagnoseResponse``.

    ``top_diagnosis`` is the first candidate that passed the pipeline's checks.
    If none passed, it falls back to the first candidate overall, which may be
    a flagged one. If there are no candidates at all, it is ``None``.

    Raises:
        HTTPException: 503 if the pipeline has not been initialised yet.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    result = pipeline.diagnose(
        note=request.note,
        top_n=request.top_n,
        threshold=request.threshold,
    )

    candidates = list(map(_to_candidate, result["candidates"]))
    passed = list(map(_to_candidate, result["passed_candidates"]))
    flagged = list(map(_to_candidate, result["flagged_candidates"]))
    hpo_matches = [HPOMatch(**match) for match in result["hpo_matches"]]

    if passed:
        top = passed[0]
    elif candidates:
        top = candidates[0]
    else:
        top = None

    return DiagnoseResponse(
        note=result["note"],
        phrases_extracted=result["phrases_extracted"],
        hpo_matches=hpo_matches,
        hpo_ids_used=result["hpo_ids_used"],
        candidates=candidates,
        passed_candidates=passed,
        flagged_candidates=flagged,
        top_diagnosis=top,
        graph_backend=result["graph_backend"],
        chroma_backend=result["chroma_backend"],
        elapsed_seconds=result["elapsed_seconds"],
    )
@app.get("/hpo/search")
def hpo_search(q: str, limit: int = 10):
    """Debug endpoint: find HPO terms by keyword in the graph store.

    Args:
        q: Case-insensitive substring to look for in each HPO term's label.
            An empty string matches every HPO term.
        limit: Maximum number of matches to return. Values ``<= 0`` return
            an empty result list.

    Returns:
        ``{"query": q, "results": [{"hpo_id": ..., "term": ...}, ...]}``.
        Matches appear in the graph's node-iteration order.

    Raises:
        HTTPException: 503 if the pipeline has not been initialised yet.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")
    results = []
    # The loop below appends before it checks the limit. Without this guard,
    # limit=0 or a negative limit would still return one match.
    if limit <= 0:
        return {"query": q, "results": results}
    store = pipeline.graph_store
    q_lower = q.lower()
    for _, attrs in store.graph.nodes(data=True):
        if attrs.get("type") != "HPOTerm":
            continue
        # A missing or None "term" counts as an empty label, so the request
        # does not crash on ``None.lower()``.
        term = attrs.get("term") or ""
        if q_lower not in term.lower():
            continue
        results.append({"hpo_id": attrs["hpo_id"], "term": term})
        if len(results) >= limit:
            break
    return {"query": q, "results": results}
|