""" milestone_2b.py --------------- Week 2B Milestone: Free-text clinical note → differential diagnosis. Tests the full pipeline end-to-end: Clinical note -> SymptomParser (BioLORD semantic HPO mapping) -> Graph traversal (MANIFESTS_AS phenotype matching) -> ChromaDB semantic search (HPO-enriched embeddings) -> RRF fusion -> Ranked differential diagnosis Target note: "18 year old male, extremely tall, displaced lens in left eye, heart murmur, flexible joints, scoliosis" Expected: Marfan syndrome (ORPHA:558) in top 3. """ import io import sys import time from pathlib import Path # UTF-8 output for Windows sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace") sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace") # Make sure both scripts/ and api/ are importable ROOT = Path(__file__).parents[2] sys.path.insert(0, str(ROOT / "backend" / "scripts")) sys.path.insert(0, str(ROOT / "backend")) from api.pipeline import DiagnosisPipeline # --------------------------------------------------------------------------- # Test case # --------------------------------------------------------------------------- NOTE = ( "18 year old male, extremely tall, displaced lens in left eye, " "heart murmur, flexible joints, scoliosis" ) # --------------------------------------------------------------------------- # Display helpers # --------------------------------------------------------------------------- BOLD = "\033[1m" CYAN = "\033[96m" GREEN = "\033[92m" YELLOW = "\033[93m" MAGENTA = "\033[95m" RED = "\033[91m" DIM = "\033[2m" RESET = "\033[0m" LINE = "-" * 68 def section(title: str, color: str) -> None: print(f"\n{BOLD}{color}{title}{RESET}") print(LINE) def print_hpo_matches(matches: list[dict]) -> None: section("[ Step 1 — Symptom Parser: Free-text -> HPO Terms ]", CYAN) if not matches: print(f" {YELLOW}No HPO terms resolved.{RESET}") return print(f" {'Score':>6} {'HPO ID':<12} {'HPO Term':<38} Phrase") print(f" {'-'*6} {'-'*12} {'-'*38} {'-'*28}") for m in matches: print(f" {m['score']:>6.4f} {m['hpo_id']:<12} {m['term']:<38} \"{m['phrase']}\"") def print_candidates(candidates: list[dict], n: int = 10) -> None: section("[ Step 4 — Fused Differential Diagnosis (RRF) ]", MAGENTA) print(f" {'#':<4} {'RRF':>7} {'Graph':>6} {'Vec':>5} {'Match':>5} Disease") print(f" {'-'*4} {'-'*7} {'-'*6} {'-'*5} {'-'*5} {'-'*38}") for c in candidates[:n]: gr = f"#{c['graph_rank']}" if c.get("graph_rank") else " - " cr = f"#{c['chroma_rank']}" if c.get("chroma_rank") else " - " mc = str(c.get("graph_matches", "-")) if c.get("graph_matches") is not None else " - " name = c["name"][:42] # Highlight Marfan highlight = BOLD + GREEN if "Marfan" in c["name"] else "" reset_hl = RESET if highlight else "" print( f" {c['rank']:<4} {c['rrf_score']:>7.5f} {gr:>6} {cr:>5} {mc:>5} " f"{highlight}{name}{reset_hl}" ) # Show matched phenotypes for top 3 if c["rank"] <= 3 and c.get("matched_hpo"): terms = ", ".join(h["term"] for h in c["matched_hpo"][:5]) print(f" {DIM}Phenotypes: {terms}{RESET}") # --------------------------------------------------------------------------- # Milestone validation # --------------------------------------------------------------------------- def validate(result: dict) -> bool: """Pass if Marfan syndrome appears in top 5.""" candidates = result.get("candidates", []) for c in candidates[:5]: if "558" in str(c.get("orpha_code", "")) or "Marfan syndrome" == c.get("name", ""): return True return False # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: print("=" * 68) print("RareDx — Week 2B Milestone: Clinical Note -> Diagnosis") print("=" * 68) print(f"\n{BOLD}Clinical note:{RESET}") print(f" \"{NOTE}\"\n") # Initialise pipeline (loads model + HPO index + graph + ChromaDB) t0 = time.time() pipeline = DiagnosisPipeline() print(f"\nPipeline initialised in {time.time()-t0:.1f}s\n") # Run diagnosis print(f"Running diagnosis...") result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52) print(f" Completed in {result['elapsed_seconds']}s") # Display print_hpo_matches(result["hpo_matches"]) section("[ Step 2+3 — Graph + Semantic Search Summary ]", CYAN) hpo_used = result["hpo_ids_used"] print(f" HPO IDs fed to graph: {', '.join(hpo_used) if hpo_used else 'none'}") print(f" Graph candidates: {sum(1 for c in result['candidates'] if c.get('graph_rank'))}") print(f" ChromaDB candidates: {sum(1 for c in result['candidates'] if c.get('chroma_rank'))}") print(f" Overlap (both): {sum(1 for c in result['candidates'] if c.get('graph_rank') and c.get('chroma_rank'))}") print_candidates(result["candidates"]) # Summary passed = validate(result) top = result.get("top_diagnosis", {}) print(f"\n{LINE}") print(f"{BOLD}Week 2B Milestone Summary{RESET}") print(LINE) print(f" HPO terms resolved : {len(result['hpo_matches'])} / {len(result['phrases_extracted'])} phrases matched") print(f" Total candidates : {len(result['candidates'])} unique diseases") print(f" Graph backend : {result['graph_backend']}") print(f" ChromaDB backend : {result['chroma_backend']}") print(f" Elapsed : {result['elapsed_seconds']}s") print() if passed: marfan_rank = next( (c["rank"] for c in result["candidates"] if "Marfan syndrome" == c.get("name") or "558" in str(c.get("orpha_code", ""))), "?", ) print(f" {BOLD}{GREEN}PASSED{RESET} — Marfan syndrome (ORPHA:558) at rank #{marfan_rank}") else: print(f" {RED}FAILED{RESET} — Marfan syndrome not in top 5") print(f" Top result: {top.get('name')} (ORPHA:{top.get('orpha_code')})") sys.exit(1) print() print(f" {BOLD}Top diagnosis:{RESET} {top.get('name')} (ORPHA:{top.get('orpha_code')})") if top.get("definition"): words = top["definition"].split() snippet = " ".join(words[:30]) + ("..." if len(words) > 30 else "") print(f" {DIM}{snippet}{RESET}") print() if __name__ == "__main__": main()