| """ |
| milestone_2b.py |
| --------------- |
| Week 2B Milestone: Free-text clinical note → differential diagnosis. |
| |
| Tests the full pipeline end-to-end: |
| Clinical note |
| -> SymptomParser (BioLORD semantic HPO mapping) |
| -> Graph traversal (MANIFESTS_AS phenotype matching) |
| -> ChromaDB semantic search (HPO-enriched embeddings) |
| -> RRF fusion |
| -> Ranked differential diagnosis |
| |
| Target note: |
| "18 year old male, extremely tall, displaced lens in left eye, |
| heart murmur, flexible joints, scoliosis" |
| |
| Expected: Marfan syndrome (ORPHA:558) in the top 5 candidates (the milestone pass criterion). |
| """ |
|
|
| import io |
| import sys |
| import time |
| from pathlib import Path |
|
|
| |
# Force UTF-8 console output, replacing unencodable characters, so the
# non-ASCII section titles (they contain an em dash) print on consoles whose
# default encoding cannot represent them.
#
# ``reconfigure`` adjusts the existing streams in place. Unlike stacking a new
# TextIOWrapper on ``.buffer``, it keeps each stream's line buffering, so
# progress lines such as "Running diagnosis..." show up immediately on a
# terminal. The hasattr guard skips streams that are None (e.g. pythonw) or
# were replaced by objects without ``reconfigure`` or ``.buffer``; the
# original code crashed on those with AttributeError.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
del _stream
|
|
| |
# Make the backend packages importable when this script is run directly.
# ROOT is the directory two levels above the one containing this file.
ROOT = Path(__file__).parents[2]

# Each entry is pushed to the front of sys.path. "backend" is inserted last,
# so it ends up ahead of "backend/scripts" in the search order.
for _import_dir in (ROOT / "backend" / "scripts", ROOT / "backend"):
    sys.path.insert(0, str(_import_dir))
del _import_dir
|
|
| from api.pipeline import DiagnosisPipeline |
|
|
| |
| |
| |
|
|
# The free-text clinical note fed to the pipeline. The module docstring gives
# the expected outcome (Marfan syndrome near the top of the ranking).
# Findings are listed one per line and joined with ", " into a single note.
NOTE = ", ".join((
    "18 year old male",
    "extremely tall",
    "displaced lens in left eye",
    "heart murmur",
    "flexible joints",
    "scoliosis",
))
|
|
| |
| |
| |
|
|
# ANSI SGR escape sequences used to style the console report.
_CSI = "\033["

BOLD = _CSI + "1m"
DIM = _CSI + "2m"
RESET = _CSI + "0m"

RED = _CSI + "91m"
GREEN = _CSI + "92m"
YELLOW = _CSI + "93m"
MAGENTA = _CSI + "95m"
CYAN = _CSI + "96m"

# Horizontal rule printed under section headings and around the summary.
LINE = 68 * "-"
|
|
|
|
def section(title: str, color: str) -> None:
    """Print a section heading.

    Prints a blank line, then *title* in bold using the ANSI *color* code,
    then the horizontal rule ``LINE``.
    """
    heading = f"{BOLD}{color}{title}{RESET}"
    # The empty first item yields the leading blank line; sep puts each piece
    # on its own line.
    print("", heading, LINE, sep="\n")
|
|
|
|
def print_hpo_matches(matches: list[dict]) -> None:
    """Print Step 1 of the report: note phrases mapped to HPO terms.

    Args:
        matches: One dict per resolved phrase. Each is read for the keys
            ``score`` (formatted as a 4-decimal float), ``hpo_id``, ``term``
            and ``phrase``. When the list is empty, a warning line is printed
            instead of the table.
    """
    section("[ Step 1 — Symptom Parser: Free-text -> HPO Terms ]", CYAN)

    if matches:
        print(f" {'Score':>6} {'HPO ID':<12} {'HPO Term':<38} Phrase")
        # Dash underline, one run per column width.
        print(" " + " ".join("-" * width for width in (6, 12, 38, 28)))
        for match in matches:
            print(f" {match['score']:>6.4f} {match['hpo_id']:<12} {match['term']:<38} \"{match['phrase']}\"")
    else:
        print(f" {YELLOW}No HPO terms resolved.{RESET}")
|
|
|
|
def print_candidates(candidates: list[dict], n: int = 10) -> None:
    """Print Step 4 of the report: the first *n* fused (RRF) candidates.

    Each candidate dict is read for ``rank``, ``rrf_score`` and ``name``, and
    optionally ``graph_rank``, ``chroma_rank``, ``graph_matches`` and
    ``matched_hpo``.
    - Names are cut to 42 characters.
    - Names containing "Marfan" are highlighted in bold green.
    - Candidates ranked 3 or better also get a line listing up to five of
      their matched phenotype terms.
    """

    def rank_cell(value) -> str:
        # Falsy ranks (missing key, None, 0) render as a placeholder dash.
        return f"#{value}" if value else " - "

    section("[ Step 4 — Fused Differential Diagnosis (RRF) ]", MAGENTA)
    print(f" {'#':<4} {'RRF':>7} {'Graph':>6} {'Vec':>5} {'Match':>5} Disease")
    print(f" {'-'*4} {'-'*7} {'-'*6} {'-'*5} {'-'*5} {'-'*38}")

    for cand in candidates[:n]:
        graph_cell = rank_cell(cand.get("graph_rank"))
        chroma_cell = rank_cell(cand.get("chroma_rank"))

        graph_hits = cand.get("graph_matches")
        match_cell = " - " if graph_hits is None else str(graph_hits)

        full_name = cand["name"]
        shown_name = full_name[:42]
        if "Marfan" in full_name:
            open_hl, close_hl = BOLD + GREEN, RESET
        else:
            open_hl, close_hl = "", ""

        print(
            f" {cand['rank']:<4} {cand['rrf_score']:>7.5f} {graph_cell:>6} "
            f"{chroma_cell:>5} {match_cell:>5} {open_hl}{shown_name}{close_hl}"
        )

        phenotypes = cand.get("matched_hpo")
        if cand["rank"] <= 3 and phenotypes:
            joined = ", ".join(hpo["term"] for hpo in phenotypes[:5])
            print(f" {DIM}Phenotypes: {joined}{RESET}")
|
|
|
|
| |
| |
| |
|
|
| def validate(result: dict) -> bool: |
| """Pass if Marfan syndrome appears in top 5.""" |
| candidates = result.get("candidates", []) |
| for c in candidates[:5]: |
| if "558" in str(c.get("orpha_code", "")) or "Marfan syndrome" == c.get("name", ""): |
| return True |
| return False |
|
|
|
|
| |
| |
| |
|
|
| def main() -> None: |
| print("=" * 68) |
| print("RareDx — Week 2B Milestone: Clinical Note -> Diagnosis") |
| print("=" * 68) |
| print(f"\n{BOLD}Clinical note:{RESET}") |
| print(f" \"{NOTE}\"\n") |
|
|
| |
| t0 = time.time() |
| pipeline = DiagnosisPipeline() |
| print(f"\nPipeline initialised in {time.time()-t0:.1f}s\n") |
|
|
| |
| print(f"Running diagnosis...") |
| result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52) |
| print(f" Completed in {result['elapsed_seconds']}s") |
|
|
| |
| print_hpo_matches(result["hpo_matches"]) |
|
|
| section("[ Step 2+3 — Graph + Semantic Search Summary ]", CYAN) |
| hpo_used = result["hpo_ids_used"] |
| print(f" HPO IDs fed to graph: {', '.join(hpo_used) if hpo_used else 'none'}") |
| print(f" Graph candidates: {sum(1 for c in result['candidates'] if c.get('graph_rank'))}") |
| print(f" ChromaDB candidates: {sum(1 for c in result['candidates'] if c.get('chroma_rank'))}") |
| print(f" Overlap (both): {sum(1 for c in result['candidates'] if c.get('graph_rank') and c.get('chroma_rank'))}") |
|
|
| print_candidates(result["candidates"]) |
|
|
| |
| passed = validate(result) |
| top = result.get("top_diagnosis", {}) |
|
|
| print(f"\n{LINE}") |
| print(f"{BOLD}Week 2B Milestone Summary{RESET}") |
| print(LINE) |
| print(f" HPO terms resolved : {len(result['hpo_matches'])} / {len(result['phrases_extracted'])} phrases matched") |
| print(f" Total candidates : {len(result['candidates'])} unique diseases") |
| print(f" Graph backend : {result['graph_backend']}") |
| print(f" ChromaDB backend : {result['chroma_backend']}") |
| print(f" Elapsed : {result['elapsed_seconds']}s") |
| print() |
|
|
| if passed: |
| marfan_rank = next( |
| (c["rank"] for c in result["candidates"] |
| if "Marfan syndrome" == c.get("name") or "558" in str(c.get("orpha_code", ""))), |
| "?", |
| ) |
| print(f" {BOLD}{GREEN}PASSED{RESET} — Marfan syndrome (ORPHA:558) at rank #{marfan_rank}") |
| else: |
| print(f" {RED}FAILED{RESET} — Marfan syndrome not in top 5") |
| print(f" Top result: {top.get('name')} (ORPHA:{top.get('orpha_code')})") |
| sys.exit(1) |
|
|
| print() |
| print(f" {BOLD}Top diagnosis:{RESET} {top.get('name')} (ORPHA:{top.get('orpha_code')})") |
| if top.get("definition"): |
| words = top["definition"].split() |
| snippet = " ".join(words[:30]) + ("..." if len(words) > 30 else "") |
| print(f" {DIM}{snippet}{RESET}") |
| print() |
|
|
|
|
# Script entry point: run the milestone only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()
|
|