# File: raredx/backend/api/main.py
# Hugging Face page header captured together with the file (not part of the program):
#   Aswin92's picture / Upload folder using huggingface_hub / 89c6379 verified
"""
main.py — RareDx FastAPI application.
Run with:
uvicorn backend.api.main:app --reload --port 8080
Or from the project root:
python -m uvicorn backend.api.main:app --reload --port 8080
Endpoints:
POST /diagnose — clinical note → differential diagnosis
GET /health — liveness check
GET /hpo/search?q=... — debug: find HPO terms by keyword
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Ensure scripts/ importable
sys.path.insert(0, str(Path(__file__).parents[1] / "scripts"))
from .models import DiagnoseRequest, DiagnoseResponse, Candidate, HPOMatch
from .pipeline import DiagnosisPipeline
# Shared DiagnosisPipeline instance. It is loaded once at startup by
# `lifespan` (model loading takes ~3s) and then reused by the request handlers.
# It stays None until startup has run; the handlers treat None as "not ready"
# (/diagnose and /hpo/search answer 503, /health reports pipeline_ready=False).
pipeline: DiagnosisPipeline | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the shared pipeline, then serve.

    The code before ``yield`` constructs the module-level ``pipeline`` and
    prints progress messages; the code after ``yield`` only prints a
    shutdown message.
    """
    global pipeline

    print("Starting up RareDx API...")
    loaded = DiagnosisPipeline()
    pipeline = loaded
    print("API ready.")

    # Control returns here when the application is shutting down.
    yield

    print("Shutting down.")
# The ASGI application object: this is the `app` referred to by
# `uvicorn backend.api.main:app`. Title, description and version are the
# API metadata handed to FastAPI.
app = FastAPI(
    title="RareDx API",
    description=(
        "Multi-agent clinical AI for rare disease diagnosis. "
        "Combines knowledge graph (Orphanet/HPO) with BioLORD-2023 "
        "semantic embeddings to generate differential diagnoses from "
        "plain-text clinical notes."
    ),
    version="0.2.0",
    # Startup/shutdown hook that constructs the shared pipeline.
    lifespan=lifespan,
)
# CORS: accept cross-origin requests from any origin, with any HTTP method
# and any request header.
# NOTE(review): the wildcard origin looks like a development convenience;
# consider restricting allow_origins before exposing the API publicly —
# confirm against the intended deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    # Liveness check: always answers "ok" and reports whether the shared
    # pipeline has been built, plus which backends it is using.
    # (Kept as comments: a docstring here would be published in the OpenAPI schema.)
    active = pipeline

    if active:
        graph_backend = active.graph_backend
        chroma_backend = active.chroma_backend
    else:
        # Pipeline not available: backend names are reported as null.
        graph_backend = None
        chroma_backend = None

    return {
        "status": "ok",
        "pipeline_ready": active is not None,
        "graph_backend": graph_backend,
        "chroma_backend": chroma_backend,
    }
@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(request: DiagnoseRequest):
    # Clinical note -> differential diagnosis.
    # Runs the shared pipeline on the request and repackages its dict result
    # into the response models. Answers 503 while the pipeline is not built.
    # (Kept as comments: a docstring here would be published in the OpenAPI schema.)
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    outcome = pipeline.diagnose(
        note=request.note,
        top_n=request.top_n,
        threshold=request.threshold,
    )

    def as_candidate(raw: dict) -> Candidate:
        # Required keys are indexed directly (a missing one raises KeyError);
        # optional keys fall back to None or the defaults shown.
        fields = {
            "rank": raw["rank"],
            "orpha_code": raw["orpha_code"],
            "name": raw["name"],
            # An empty/falsy definition is normalised to None.
            "definition": raw.get("definition") or None,
            "rrf_score": raw["rrf_score"],
            "graph_rank": raw.get("graph_rank"),
            "chroma_rank": raw.get("chroma_rank"),
            "graph_matches": raw.get("graph_matches"),
            "chroma_sim": raw.get("chroma_sim"),
            "matched_hpo": raw.get("matched_hpo", []),
            "hallucination_flag": raw.get("hallucination_flag", False),
            "flag_reason": raw.get("flag_reason"),
            "evidence_score": raw.get("evidence_score", 0.0),
        }
        return Candidate(**fields)

    def as_candidates(key: str) -> list:
        # Convert every raw candidate stored under `key` in the pipeline result.
        return list(map(as_candidate, outcome[key]))

    ranked = as_candidates("candidates")
    passed = as_candidates("passed_candidates")
    flagged = as_candidates("flagged_candidates")
    hpo_hits = [HPOMatch(**match) for match in outcome["hpo_matches"]]

    # Top diagnosis: first passed candidate if any, otherwise the first
    # candidate overall, otherwise nothing.
    if passed:
        best = passed[0]
    elif ranked:
        best = ranked[0]
    else:
        best = None

    return DiagnoseResponse(
        note=outcome["note"],
        phrases_extracted=outcome["phrases_extracted"],
        hpo_matches=hpo_hits,
        hpo_ids_used=outcome["hpo_ids_used"],
        candidates=ranked,
        passed_candidates=passed,
        flagged_candidates=flagged,
        top_diagnosis=best,
        graph_backend=outcome["graph_backend"],
        chroma_backend=outcome["chroma_backend"],
        elapsed_seconds=outcome["elapsed_seconds"],
    )
@app.get("/hpo/search")
def hpo_search(q: str, limit: int = 10):
    """Debug endpoint: find HPO terms by keyword in graph store.

    Case-insensitive substring match of ``q`` against the ``term`` attribute
    of graph nodes whose ``type`` is ``"HPOTerm"``. At most ``limit`` matches
    are returned; a ``limit`` of zero or less yields no results.
    """
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    results = []
    # A non-positive limit means "no results". Without this guard the cap is
    # only checked after an append, so one match could still be returned.
    if limit <= 0:
        return {"query": q, "results": results}

    store = pipeline.graph_store
    q_lower = q.lower()
    for _, attrs in store.graph.nodes(data=True):
        if attrs.get("type") != "HPOTerm":
            continue
        # Tolerate a node whose term is missing or explicitly None instead of
        # failing on `.lower()`.
        term = attrs.get("term") or ""
        if q_lower not in term.lower():
            continue
        results.append({"hpo_id": attrs["hpo_id"], "term": term})
        if len(results) >= limit:
            break

    return {"query": q, "results": results}