# Source: raredx/backend/scripts/test_week3p1.py
# Uploaded by Aswin92 using huggingface_hub (commit 89c6379, verified)
"""
test_week3p1.py
---------------
Week 3 Part 1 test:
1. Single-word HPO extraction — confirm "scoliosis" is now extracted
2. Hallucination guard — show which candidates pass / are flagged
3. Marfan validation — confirm ORPHA:558 is in the passed set
Clinical note:
"18 year old male, extremely tall, displaced lens in left eye,
heart murmur, flexible joints, scoliosis"
"""
import io
import sys
from pathlib import Path
# Force UTF-8 output with replacement of unencodable characters so the
# non-ASCII text printed by this script (e.g. the em dash in the banner)
# cannot raise UnicodeEncodeError on consoles with a narrower encoding.
# Prefer reconfiguring the existing stream in place, which keeps its
# buffering mode; fall back to re-wrapping the raw buffer when
# ``reconfigure`` is unavailable, and skip streams that offer neither.
for _stream_name in ("stdout", "stderr"):
    _stream = getattr(sys, _stream_name)
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
    elif hasattr(_stream, "buffer"):
        setattr(
            sys,
            _stream_name,
            io.TextIOWrapper(_stream.buffer, encoding="utf-8", errors="replace"),
        )

# Two directory levels above the directory holding this file. ``resolve()``
# makes the path absolute first, so ``parents[2]`` cannot raise IndexError
# when ``__file__`` is a bare relative filename.
ROOT = Path(__file__).resolve().parents[2]

# Insertion order matters: each insert goes to index 0, so after both calls
# ``ROOT/backend`` is searched before ``ROOT/backend/scripts``.
sys.path.insert(0, str(ROOT / "backend" / "scripts"))
sys.path.insert(0, str(ROOT / "backend"))
from api.pipeline import DiagnosisPipeline
# Clinical note fed to the pipeline; the comma-separated findings are the
# same ones quoted in the module docstring.
NOTE = ", ".join(
    [
        "18 year old male",
        "extremely tall",
        "displaced lens in left eye",
        "heart murmur",
        "flexible joints",
        "scoliosis",
    ]
)


def _sgr(code: int) -> str:
    """Return the ANSI escape sequence ``ESC[<code>m``."""
    return f"\033[{code}m"


# ANSI escape sequences used to style the console report.
BOLD = _sgr(1)
CYAN = _sgr(96)
GREEN = _sgr(92)
YELLOW = _sgr(93)
MAGENTA = _sgr(95)
RED = _sgr(91)
DIM = _sgr(2)
RESET = _sgr(0)

# Horizontal rule printed under section titles and around the summary.
LINE = 70 * "-"
def section(title: str, color: str = CYAN) -> None:
    """Print a section heading followed by a horizontal rule.

    Args:
        title: Heading text, printed in bold after one blank line.
        color: ANSI colour sequence applied to the heading (cyan by default).
    """
    heading = "".join(("\n", BOLD, color, title, RESET))
    print(heading, LINE, sep="\n")
# Identifiers the checks look for in the pipeline output.
_MARFAN_ORPHA = "558"
_SCOLIOSIS_HPO = "HP:0002650"


def _is_marfan(candidate: dict) -> bool:
    """Return True when *candidate* carries exactly the Marfan ORPHA code."""
    # Exact comparison: a substring test would also match codes such as
    # 1558 or 85580. str() lets an integer-typed code match as well.
    return str(candidate.get("orpha_code")) == _MARFAN_ORPHA


def _rank_cells(candidate: dict) -> tuple:
    """Return the (graph rank, chroma rank, graph matches) table cells."""
    graph = f"#{candidate['graph_rank']}" if candidate.get("graph_rank") else " -"
    chroma = f"#{candidate['chroma_rank']}" if candidate.get("chroma_rank") else " -"
    matches = str(candidate.get("graph_matches", "-"))
    return graph, chroma, matches


def main() -> None:
    """Run the Week 3 Part 1 checks on ``NOTE`` and print a coloured report.

    Creates a ``DiagnosisPipeline``, calls ``diagnose(NOTE, top_n=15,
    threshold=0.52)`` and reads these keys from the returned mapping:
    ``hpo_matches``, ``passed_candidates``, ``flagged_candidates``,
    ``hpo_ids_used``, ``candidates``, ``top_diagnosis`` and
    ``elapsed_seconds``.

    Three sections are printed (HPO extraction, hallucination guard,
    Marfan validation), followed by a PASS/FAIL summary of four checks.

    Raises:
        SystemExit: with status 1 when at least one check fails.
    """
    print("=" * 70)
    print("RareDx — Week 3 Part 1 Test")
    print("=" * 70)
    print(f"\n{BOLD}Note:{RESET} \"{NOTE}\"\n")

    pipeline = DiagnosisPipeline()
    result = pipeline.diagnose(NOTE, top_n=15, threshold=0.52)

    # -----------------------------------------------------------------------
    # 1. Single-word extraction
    # -----------------------------------------------------------------------
    section("[ Fix 1 — Single-word HPO Extraction ]")
    matches = result["hpo_matches"]
    print(f" {'Score':>6} {'HPO ID':<12} {'Term':<35} Phrase")
    print(f" {'-'*6} {'-'*12} {'-'*35} {'-'*28}")
    for m in matches:
        # Tag phrases consisting of a single whitespace-delimited token.
        tag = f"{DIM}(single word){RESET}" if len(m["phrase"].split()) == 1 else ""
        print(f" {m['score']:>6.4f} {m['hpo_id']:<12} {m['term']:<35} \"{m['phrase']}\" {tag}")

    scoliosis_found = any(m["hpo_id"] == _SCOLIOSIS_HPO for m in matches)
    status = f"{GREEN}EXTRACTED{RESET}" if scoliosis_found else f"{RED}MISSING{RESET}"
    print(f"\n Scoliosis (HP:0002650): {status}")

    # -----------------------------------------------------------------------
    # 2. Hallucination guard results
    # -----------------------------------------------------------------------
    passed = result["passed_candidates"]
    flagged = result["flagged_candidates"]
    total_q = len(result["hpo_ids_used"])

    section("[ Fix 2 — FusionNode Hallucination Guard ]", MAGENTA)
    print(f" Query HPO terms: {total_q} | "
          f"Passed: {GREEN}{len(passed)}{RESET} | "
          f"Flagged: {YELLOW}{len(flagged)}{RESET}\n")

    print(f" {BOLD}{GREEN}PASSED candidates:{RESET}")
    print(f" {'#':<4} {'Ev':>5} {'G':>3} {'V':>3} {'M':>3} Disease")
    print(f" {'-'*4} {'-'*5} {'-'*3} {'-'*3} {'-'*3} {'-'*40}")
    for c in passed:
        gr, cr, mc = _rank_cells(c)
        # Highlight only the Marfan row itself.
        hi = BOLD + GREEN if _is_marfan(c) else ""
        rs = RESET if hi else ""
        print(f" {c['rank']:<4} {c['evidence_score']:>5.3f} {gr:>3} {cr:>3} {mc:>3} {hi}{c['name'][:44]}{rs}")

    if flagged:
        print(f"\n {BOLD}{YELLOW}FLAGGED candidates:{RESET}")
        print(f" {'#':<4} {'Ev':>5} {'G':>3} {'V':>3} {'M':>3} Disease | Reason")
        print(f" {'-'*4} {'-'*5} {'-'*3} {'-'*3} {'-'*3} {'-'*40}")
        for c in flagged:
            gr, cr, mc = _rank_cells(c)
            # A present-but-None reason must not break the slice below.
            reason = c.get("flag_reason") or ""
            print(f" {c['rank']:<4} {c['evidence_score']:>5.3f} {gr:>3} {cr:>3} {mc:>3} "
                  f"{c['name'][:30]} | {DIM}{reason[:50]}{RESET}")

    # -----------------------------------------------------------------------
    # 3. Marfan validation
    # -----------------------------------------------------------------------
    section("[ Validation — Marfan Syndrome (ORPHA:558) ]", GREEN)
    marfan_all = next((c for c in result["candidates"] if _is_marfan(c)), None)
    marfan_passed = next((c for c in passed if _is_marfan(c)), None)

    if marfan_all:
        print(f" Overall rank : #{marfan_all['rank']} (RRF {marfan_all['rrf_score']:.5f})")
        print(f" Evidence score: {marfan_all.get('evidence_score', 0):.3f}")
        if marfan_all.get("graph_rank"):
            print(f" Graph rank : #{marfan_all['graph_rank']}")
        else:
            print(" Graph rank : not in graph results")
        if marfan_all.get("chroma_rank"):
            print(f" Chroma rank : #{marfan_all['chroma_rank']}")
        else:
            print(" Chroma rank : not in vector results")

        hpo_matched = marfan_all.get("matched_hpo", [])
        if hpo_matched:
            print(f" Matched HPO : {', '.join(h['term'] for h in hpo_matched)}")

        if marfan_all.get("hallucination_flag", False):
            # Fall back to "?" when the reason is missing or None so the
            # string concatenation cannot raise TypeError.
            reason = marfan_all.get("flag_reason") or "?"
            verdict = RED + "FLAGGED — " + reason + RESET
        else:
            verdict = GREEN + "PASSED" + RESET
        print(f" Guard result : {verdict}")
    else:
        print(f" {RED}Marfan syndrome not found in any candidates.{RESET}")

    # -----------------------------------------------------------------------
    # Summary
    # -----------------------------------------------------------------------
    top = result["top_diagnosis"]
    print(f"\n{LINE}")
    print(f"{BOLD}Week 3 Part 1 Summary{RESET}")
    print(LINE)

    checks = {
        "Single-word extraction (scoliosis)": scoliosis_found,
        "Hallucination guard active": len(flagged) > 0 or len(passed) > 0,
        "Marfan in candidates": marfan_all is not None,
        "Marfan passes guard": marfan_passed is not None,
    }
    for label, ok in checks.items():
        icon = f"{GREEN}PASS{RESET}" if ok else f"{RED}FAIL{RESET}"
        print(f" {icon} {label}")
    all_pass = all(checks.values())

    print()
    if all_pass:
        print(f" {BOLD}{GREEN}ALL CHECKS PASSED{RESET}")
    else:
        print(f" {RED}SOME CHECKS FAILED — review above{RESET}")
        sys.exit(1)

    print(f"\n Top diagnosis : {top['name']} (ORPHA:{top['orpha_code']})")
    print(f" Elapsed : {result['elapsed_seconds']}s")
    print()


# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()