| """ |
| main.py — RareDx FastAPI application. |
| |
| Run with: |
| uvicorn backend.api.main:app --reload --port 8080 |
| |
| Or from the project root: |
| python -m uvicorn backend.api.main:app --reload --port 8080 |
| |
Endpoints:
    POST /diagnose                  — clinical note → differential diagnosis
    GET  /health                    — liveness check; also reports pipeline readiness and backends
    GET  /hpo/search?q=...&limit=N  — debug: case-insensitive HPO term search by keyword (limit defaults to 10)
| """ |
|
|
| import sys |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| |
# Put <backend>/scripts at the front of the import path *before* the local
# imports below run; presumably `.pipeline` (or something it imports) loads
# modules from that directory — confirm before reordering anything here.
_SCRIPTS_DIR = Path(__file__).parents[1] / "scripts"
sys.path.insert(0, str(_SCRIPTS_DIR))

from .models import Candidate, DiagnoseRequest, DiagnoseResponse, HPOMatch
from .pipeline import DiagnosisPipeline

# Process-wide diagnosis pipeline. It stays None until the application's
# startup hook assigns it; request handlers answer 503 while it is None.
pipeline: DiagnosisPipeline | None = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared diagnosis pipeline at startup and drop it at shutdown.

    FastAPI runs the code before ``yield`` once before serving requests and the
    code after it once when the application stops.

    Args:
        app: The application instance. FastAPI's lifespan protocol requires
            this parameter; it is not used here.
    """
    global pipeline
    print("Starting up RareDx API...")
    pipeline = DiagnosisPipeline()
    print("API ready.")
    try:
        yield
    finally:
        # Runs even when the server is torn down by an exception. Clearing the
        # module-level reference lets the pipeline be garbage-collected and
        # makes /health report pipeline_ready=False after shutdown.
        pipeline = None
        print("Shutting down.")
|
|
|
|
# Human-readable summary shown in the generated OpenAPI docs.
_API_DESCRIPTION = (
    "Multi-agent clinical AI for rare disease diagnosis. "
    "Combines knowledge graph (Orphanet/HPO) with BioLORD-2023 "
    "semantic embeddings to generate differential diagnoses from "
    "plain-text clinical notes."
)

# The ASGI application; ``lifespan`` builds the shared pipeline on startup.
app = FastAPI(
    lifespan=lifespan,
    version="0.2.0",
    title="RareDx API",
    description=_API_DESCRIPTION,
)

# Wide-open CORS: any origin, method and header is accepted. allow_credentials
# is not passed, so the middleware's default applies. Narrow allow_origins
# before exposing this service beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
| |
| |
| |
|
|
@app.get("/health")
def health():
    """Liveness probe that also reports whether the pipeline is ready.

    ``graph_backend`` and ``chroma_backend`` echo the pipeline's attributes of
    the same name, and are None while no pipeline exists.
    """
    snapshot = pipeline
    graph_backend = snapshot.graph_backend if snapshot else None
    chroma_backend = snapshot.chroma_backend if snapshot else None
    return {
        "status": "ok",
        "pipeline_ready": snapshot is not None,
        "graph_backend": graph_backend,
        "chroma_backend": chroma_backend,
    }
|
|
|
|
def _to_candidate(c: dict) -> Candidate:
    """Build a ``Candidate`` response model from one pipeline candidate dict.

    ``rank``, ``orpha_code``, ``name`` and ``rrf_score`` are required; a
    missing one raises ``KeyError``. Missing optional keys become None, except
    ``matched_hpo`` (``[]``), ``hallucination_flag`` (``False``) and
    ``evidence_score`` (``0.0``). A falsy ``definition`` (e.g. ``""``) is
    normalised to None.
    """
    return Candidate(
        rank=c["rank"],
        orpha_code=c["orpha_code"],
        name=c["name"],
        definition=c.get("definition") or None,
        rrf_score=c["rrf_score"],
        graph_rank=c.get("graph_rank"),
        chroma_rank=c.get("chroma_rank"),
        graph_matches=c.get("graph_matches"),
        chroma_sim=c.get("chroma_sim"),
        matched_hpo=c.get("matched_hpo", []),
        hallucination_flag=c.get("hallucination_flag", False),
        flag_reason=c.get("flag_reason"),
        evidence_score=c.get("evidence_score", 0.0),
    )


@app.post("/diagnose", response_model=DiagnoseResponse)
def diagnose(request: DiagnoseRequest):
    """Turn a plain-text clinical note into a ranked differential diagnosis.

    Declared with plain ``def`` (not ``async def``), so FastAPI runs it in its
    threadpool and the pipeline call does not block the event loop.

    Args:
        request: The note plus the ``top_n`` and ``threshold`` settings, which
            are forwarded unchanged to ``DiagnosisPipeline.diagnose``.

    Returns:
        A ``DiagnoseResponse`` assembled from the pipeline's result dict. Each
        candidate list is converted with ``_to_candidate``.

    Raises:
        HTTPException: 503 while the pipeline has not been initialised.
    """
    # Read the module global once so the readiness check and the call below
    # are guaranteed to use the same object.
    active = pipeline
    if active is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    result = active.diagnose(
        note=request.note,
        top_n=request.top_n,
        threshold=request.threshold,
    )

    candidates = [_to_candidate(c) for c in result["candidates"]]
    passed_candidates = [_to_candidate(c) for c in result["passed_candidates"]]
    flagged_candidates = [_to_candidate(c) for c in result["flagged_candidates"]]
    hpo_matches = [HPOMatch(**m) for m in result["hpo_matches"]]

    # Headline diagnosis: the first passed (presumably unflagged) candidate if
    # there is one, otherwise the first overall candidate, otherwise None.
    if passed_candidates:
        top = passed_candidates[0]
    elif candidates:
        top = candidates[0]
    else:
        top = None

    return DiagnoseResponse(
        note=result["note"],
        phrases_extracted=result["phrases_extracted"],
        hpo_matches=hpo_matches,
        hpo_ids_used=result["hpo_ids_used"],
        candidates=candidates,
        passed_candidates=passed_candidates,
        flagged_candidates=flagged_candidates,
        top_diagnosis=top,
        graph_backend=result["graph_backend"],
        chroma_backend=result["chroma_backend"],
        elapsed_seconds=result["elapsed_seconds"],
    )
|
|
|
|
@app.get("/hpo/search")
def hpo_search(q: str, limit: int = 10):
    """Debug endpoint: find HPO terms by keyword in graph store.

    Does a case-insensitive substring match of ``q`` against the ``term``
    attribute of every graph node whose ``type`` is ``"HPOTerm"``. Matches are
    returned in graph iteration order.

    Args:
        q: Keyword to look for. An empty string matches every term.
        limit: Maximum number of results. Values <= 0 return no results.

    Returns:
        ``{"query": q, "results": [{"hpo_id": ..., "term": ...}, ...]}``.

    Raises:
        HTTPException: 503 while the pipeline has not been initialised.
    """
    # Read the module global once so the check and the use see the same object.
    active = pipeline
    if active is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialised.")

    q_lower = q.lower()
    results = []
    for _, attrs in active.graph_store.graph.nodes(data=True):
        # Enforce the cap before looking at the next node, so limit <= 0
        # yields an empty list instead of one result.
        if len(results) >= limit:
            break
        if attrs.get("type") != "HPOTerm":
            continue
        # Tolerate a missing or None ``term`` rather than raising.
        term = attrs.get("term") or ""
        if q_lower in term.lower():
            results.append({"hpo_id": attrs["hpo_id"], "term": term})
    return {"query": q, "results": results}
|
|