raredx / backend /scripts /milestone_2b.py
Aswin92's picture
Upload folder using huggingface_hub
89c6379 verified
"""
milestone_2b.py
---------------
Week 2B Milestone: Free-text clinical note → differential diagnosis.
Tests the full pipeline end-to-end:
Clinical note
-> SymptomParser (BioLORD semantic HPO mapping)
-> Graph traversal (MANIFESTS_AS phenotype matching)
-> ChromaDB semantic search (HPO-enriched embeddings)
-> RRF fusion
-> Ranked differential diagnosis
Target note:
"18 year old male, extremely tall, displaced lens in left eye,
heart murmur, flexible joints, scoliosis"
Expected: Marfan syndrome (ORPHA:558) in top 3 (validate() passes if it is in the top 5).
"""
import io
import sys
import time
from pathlib import Path
# UTF-8 output for Windows: the report below prints non-ASCII characters
# (e.g. the em dash), so unencodable output must not raise.
# reconfigure() switches the encoding of the existing streams in place, so
# their current line-buffering setting is kept and progress lines show up
# immediately; wrapping sys.stdout.buffer in a new TextIOWrapper would create
# a second wrapper with line buffering turned off.
# The hasattr guard skips replacement streams that offer no reconfigure().
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")
# Make sure both scripts/ and api/ are importable.
# ROOT is the directory two levels above the one containing this file.
ROOT = Path(__file__).parents[2]
# backend/ is inserted last, so it ends up ahead of backend/scripts/.
sys.path.insert(0, str(ROOT / "backend" / "scripts"))
sys.path.insert(0, str(ROOT / "backend"))
from api.pipeline import DiagnosisPipeline
# ---------------------------------------------------------------------------
# Test case
# ---------------------------------------------------------------------------
# Free-text clinical note used as the single input of this milestone run.
NOTE = (
    "18 year old male, extremely tall, displaced lens in left eye, heart murmur, "
    "flexible joints, scoliosis"
)
# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------
# ANSI escape sequences used to style terminal output.
BOLD = "\x1b[1m"
CYAN = "\x1b[96m"
GREEN = "\x1b[92m"
YELLOW = "\x1b[93m"
MAGENTA = "\x1b[95m"
RED = "\x1b[91m"
DIM = "\x1b[2m"
RESET = "\x1b[0m"
# Horizontal rule (68 dashes) printed under headings and around the summary.
LINE = 68 * "-"
def section(title: str, color: str) -> None:
    """Print a bold, coloured heading on a fresh line, then a rule line.

    Args:
        title: Heading text.
        color: Escape sequence emitted before the heading text.
    """
    heading = f"\n{BOLD}{color}{title}{RESET}"
    print(heading, LINE, sep="\n")
def print_hpo_matches(matches: list[dict]) -> None:
    """Print the Step 1 table of phrases resolved to HPO terms.

    Args:
        matches: One dict per resolved phrase; each row reads the keys
            ``score``, ``hpo_id``, ``term`` and ``phrase``. When empty, a
            single "no terms resolved" notice is printed instead of a table.
    """
    section("[ Step 1 — Symptom Parser: Free-text -> HPO Terms ]", CYAN)
    if matches:
        # Column header, then a dashed rule with one dash run per column.
        print(" {:>6} {:<12} {:<38} Phrase".format("Score", "HPO ID", "HPO Term"))
        print(" " + " ".join("-" * width for width in (6, 12, 38, 28)))
        for row in matches:
            print(
                f' {row["score"]:>6.4f} {row["hpo_id"]:<12} {row["term"]:<38} '
                f'"{row["phrase"]}"'
            )
    else:
        print(f" {YELLOW}No HPO terms resolved.{RESET}")
def print_candidates(candidates: list[dict], n: int = 10) -> None:
    """Print the Step 4 table of fused (RRF) candidate diagnoses.

    Args:
        candidates: Candidate dicts in display order. Each row reads ``rank``,
            ``rrf_score`` and ``name``, plus the optional ``graph_rank``,
            ``chroma_rank``, ``graph_matches`` and ``matched_hpo``.
        n: Maximum number of rows to print.
    """

    def rank_label(rank) -> str:
        # A missing (or otherwise falsy) rank is shown as a dash placeholder.
        return f"#{rank}" if rank else " - "

    section("[ Step 4 — Fused Differential Diagnosis (RRF) ]", MAGENTA)
    print(" {:<4} {:>7} {:>6} {:>5} {:>5} Disease".format("#", "RRF", "Graph", "Vec", "Match"))
    print(" " + " ".join("-" * width for width in (4, 7, 6, 5, 5, 38)))

    for cand in candidates[:n]:
        graph_label = rank_label(cand.get("graph_rank"))
        vec_label = rank_label(cand.get("chroma_rank"))
        match_count = cand.get("graph_matches")
        match_label = " - " if match_count is None else str(match_count)
        shown_name = cand["name"][:42]

        # Highlight Marfan rows; the reset code is only emitted when a
        # highlight was actually started.
        if "Marfan" in cand["name"]:
            hl_on = BOLD + GREEN
        else:
            hl_on = ""
        hl_off = RESET if hl_on else ""

        row = (
            f" {cand['rank']:<4} {cand['rrf_score']:>7.5f} "
            f"{graph_label:>6} {vec_label:>5} {match_label:>5} "
        )
        print(row + hl_on + shown_name + hl_off)

        # Only the top 3 rows get a line listing up to five matched phenotypes.
        if cand["rank"] > 3:
            continue
        phenotypes = cand.get("matched_hpo")
        if not phenotypes:
            continue
        top_terms = [hpo["term"] for hpo in phenotypes[:5]]
        print(f" {DIM}Phenotypes: {', '.join(top_terms)}{RESET}")
# ---------------------------------------------------------------------------
# Milestone validation
# ---------------------------------------------------------------------------
def validate(result: dict) -> bool:
    """Return True if Marfan syndrome (ORPHA:558) is among the top 5 candidates.

    A candidate counts as Marfan syndrome when the numeric part of its
    ``orpha_code`` is exactly ``558`` (int or string, with or without a
    non-numeric prefix such as ``ORPHA:``) or when its ``name`` is exactly
    ``"Marfan syndrome"``.

    Args:
        result: Diagnosis result; only its ``"candidates"`` list is read.
            A missing or ``None`` value is treated as an empty list.

    Returns:
        True when one of the first five candidates matches, otherwise False.
    """
    for candidate in (result.get("candidates") or [])[:5]:
        code = str(candidate.get("orpha_code", "")).strip()
        # Keep only the trailing run of digits and compare it as a whole:
        # a substring test would also accept unrelated codes such as
        # 1558 or 5580.
        number = code[len(code.rstrip("0123456789")):]
        if number == "558" or candidate.get("name") == "Marfan syndrome":
            return True
    return False
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Run the Week 2B milestone check end to end and print a report.

    Builds a ``DiagnosisPipeline``, diagnoses the module-level ``NOTE``,
    prints the HPO matches, a search summary and the ranked candidates, and
    finally reports whether ``validate`` accepted the result.

    Raises:
        SystemExit: With status 1 when ``validate`` rejects the result.
    """
    banner = "=" * 68
    print(banner)
    print("RareDx — Week 2B Milestone: Clinical Note -> Diagnosis")
    print(banner)
    print(f"\n{BOLD}Clinical note:{RESET}")
    print(f" \"{NOTE}\"\n")

    # Initialise pipeline (loads model + HPO index + graph + ChromaDB)
    # perf_counter() is the clock intended for measuring short intervals; it
    # is not affected by system clock adjustments.
    t0 = time.perf_counter()
    pipeline = DiagnosisPipeline()
    print(f"\nPipeline initialised in {time.perf_counter() - t0:.1f}s\n")

    # Run diagnosis
    print("Running diagnosis...")
    result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52)
    print(f" Completed in {result['elapsed_seconds']}s")

    # Display
    print_hpo_matches(result["hpo_matches"])
    section("[ Step 2+3 — Graph + Semantic Search Summary ]", CYAN)
    hpo_used = result["hpo_ids_used"]
    candidates = result["candidates"]
    graph_hits = sum(1 for c in candidates if c.get("graph_rank"))
    chroma_hits = sum(1 for c in candidates if c.get("chroma_rank"))
    both_hits = sum(
        1 for c in candidates if c.get("graph_rank") and c.get("chroma_rank")
    )
    print(f" HPO IDs fed to graph: {', '.join(hpo_used) if hpo_used else 'none'}")
    print(f" Graph candidates: {graph_hits}")
    print(f" ChromaDB candidates: {chroma_hits}")
    print(f" Overlap (both): {both_hits}")
    print_candidates(candidates)

    # Summary
    passed = validate(result)
    # "or {}" also covers a top_diagnosis key that is present but None.
    top = result.get("top_diagnosis") or {}
    print(f"\n{LINE}")
    print(f"{BOLD}Week 2B Milestone Summary{RESET}")
    print(LINE)
    print(f" HPO terms resolved : {len(result['hpo_matches'])} / {len(result['phrases_extracted'])} phrases matched")
    print(f" Total candidates : {len(candidates)} unique diseases")
    print(f" Graph backend : {result['graph_backend']}")
    print(f" ChromaDB backend : {result['chroma_backend']}")
    print(f" Elapsed : {result['elapsed_seconds']}s")
    print()

    if passed:
        # Find the rank to report within the same top-5 window the pass
        # criterion uses. The ORPHA code is compared as a whole number so an
        # unrelated code that merely contains "558" (e.g. 5580) is never
        # reported as Marfan syndrome; "?" is shown if no exact match exists.
        marfan_rank = "?"
        for candidate in candidates[:5]:
            code = str(candidate.get("orpha_code", "")).strip()
            number = code[len(code.rstrip("0123456789")):]
            if number == "558" or candidate.get("name") == "Marfan syndrome":
                marfan_rank = candidate["rank"]
                break
        print(f" {BOLD}{GREEN}PASSED{RESET} — Marfan syndrome (ORPHA:558) at rank #{marfan_rank}")
    else:
        print(f" {RED}FAILED{RESET} — Marfan syndrome not in top 5")
        print(f" Top result: {top.get('name')} (ORPHA:{top.get('orpha_code')})")
        # Non-zero exit status signals the failed milestone to the caller.
        sys.exit(1)

    print()
    print(f" {BOLD}Top diagnosis:{RESET} {top.get('name')} (ORPHA:{top.get('orpha_code')})")
    if top.get("definition"):
        # Show at most the first 30 words of the definition.
        words = top["definition"].split()
        snippet = " ".join(words[:30]) + ("..." if len(words) > 30 else "")
        print(f" {DIM}{snippet}{RESET}")
    print()
# Run the milestone check only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()