File size: 6,679 Bytes
89c6379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
test_week3p1.py
---------------
Week 3 Part 1 test:
  1. Single-word HPO extraction — confirm "scoliosis" is now extracted
  2. Hallucination guard — show which candidates pass / are flagged
  3. Marfan validation — confirm ORPHA:558 is in the passed set

Clinical note:
  "18 year old male, extremely tall, displaced lens in left eye,
   heart murmur, flexible joints, scoliosis"
"""

import io
import sys
from pathlib import Path

# Re-wrap stdout/stderr as UTF-8 text streams. The report printed by this
# script contains non-ASCII characters (e.g. the em dash in the banner), and
# errors="replace" substitutes anything unencodable instead of raising.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")

# Directory two levels above the folder that contains this file
# (presumably the repository root -- confirm against the project layout).
ROOT = Path(__file__).parents[2]
# Prepend the backend folders so the `api.pipeline` import below resolves.
# Each insert goes to position 0, so "backend" (inserted last) ends up ahead
# of "backend/scripts" on sys.path.
sys.path.insert(0, str(ROOT / "backend" / "scripts"))
sys.path.insert(0, str(ROOT / "backend"))

from api.pipeline import DiagnosisPipeline

# Free-text clinical note under test (same wording as in the module docstring).
NOTE = ", ".join(
    [
        "18 year old male",
        "extremely tall",
        "displaced lens in left eye",
        "heart murmur",
        "flexible joints",
        "scoliosis",
    ]
)

# ANSI escape sequences used to style the terminal report.
BOLD = "\x1b[1m"
DIM = "\x1b[2m"
RED = "\x1b[91m"
GREEN = "\x1b[92m"
YELLOW = "\x1b[93m"
MAGENTA = "\x1b[95m"
CYAN = "\x1b[96m"
RESET = "\x1b[0m"

# Horizontal rule printed under headings (70 columns wide).
LINE = 70 * "-"


def section(title: str, color: str = CYAN) -> None:
    """Print a section heading followed by a horizontal rule.

    A blank line precedes the heading, which is rendered in bold using the
    given ANSI ``color`` escape sequence and then reset.

    Args:
        title: Heading text to display.
        color: ANSI colour escape sequence applied to the heading.
    """
    heading = f"\n{BOLD}{color}{title}{RESET}"
    print(heading, LINE, sep="\n")


# ORPHA code the validation step looks for (Marfan syndrome, per the report labels).
_MARFAN_ORPHA_CODE = "558"


def _is_marfan(candidate: dict) -> bool:
    """Return True when the candidate's ORPHA code is exactly the Marfan code.

    The code is normalised through ``str`` so that an integer ``558`` matches
    too, and compared for equality so that codes which merely contain the
    digits (e.g. "5580" or "1558") do not.
    """
    return str(candidate.get("orpha_code")) == _MARFAN_ORPHA_CODE


def _rank_cells(candidate: dict) -> tuple:
    """Format the graph-rank, chroma-rank and graph-match cells of a table row.

    A missing or falsy rank is rendered as a dash placeholder; a missing match
    count is rendered as "-".
    """
    graph_rank = f"#{candidate['graph_rank']}" if candidate.get("graph_rank") else "  -"
    chroma_rank = f"#{candidate['chroma_rank']}" if candidate.get("chroma_rank") else "  -"
    match_count = str(candidate.get("graph_matches", "-"))
    return graph_rank, chroma_rank, match_count


def main() -> None:
    """Run the diagnosis pipeline on ``NOTE`` and print the Week 3 Part 1 report.

    The report has four parts:
      1. the extracted HPO matches, with a check that scoliosis (HP:0002650)
         is among them;
      2. the candidates that passed / were flagged by the hallucination guard;
      3. details for Marfan syndrome (ORPHA:558), if it is a candidate;
      4. a PASS/FAIL summary of the checks.

    Exits the process with status 1 when any check fails; otherwise prints the
    top diagnosis and the elapsed time reported by the pipeline.
    """
    print("=" * 70)
    print("RareDx — Week 3 Part 1 Test")
    print("=" * 70)
    print(f"\n{BOLD}Note:{RESET} \"{NOTE}\"\n")

    pipeline = DiagnosisPipeline()
    result   = pipeline.diagnose(NOTE, top_n=15, threshold=0.52)

    # -----------------------------------------------------------------------
    # 1. Single-word extraction
    # -----------------------------------------------------------------------
    section("[ Fix 1 — Single-word HPO Extraction ]")
    matches = result["hpo_matches"]
    print(f"  {'Score':>6}  {'HPO ID':<12}  {'Term':<35}  Phrase")
    print(f"  {'-'*6}  {'-'*12}  {'-'*35}  {'-'*28}")
    for m in matches:
        # Tag matches whose source phrase is a single whitespace-delimited word.
        tag = f"{DIM}(single word){RESET}" if len(m["phrase"].split()) == 1 else ""
        print(f"  {m['score']:>6.4f}  {m['hpo_id']:<12}  {m['term']:<35}  \"{m['phrase']}\" {tag}")

    scoliosis_found = any(m["hpo_id"] == "HP:0002650" for m in matches)
    status = f"{GREEN}EXTRACTED{RESET}" if scoliosis_found else f"{RED}MISSING{RESET}"
    print(f"\n  Scoliosis (HP:0002650): {status}")

    # -----------------------------------------------------------------------
    # 2. Hallucination guard results
    # -----------------------------------------------------------------------
    passed  = result["passed_candidates"]
    flagged = result["flagged_candidates"]
    total_q = len(result["hpo_ids_used"])

    section("[ Fix 2 — FusionNode Hallucination Guard ]", MAGENTA)
    print(f"  Query HPO terms: {total_q}  |  "
          f"Passed: {GREEN}{len(passed)}{RESET}  |  "
          f"Flagged: {YELLOW}{len(flagged)}{RESET}\n")

    print(f"  {BOLD}{GREEN}PASSED candidates:{RESET}")
    print(f"  {'#':<4}  {'Ev':>5}  {'G':>3}  {'V':>3}  {'M':>3}  Disease")
    print(f"  {'-'*4}  {'-'*5}  {'-'*3}  {'-'*3}  {'-'*3}  {'-'*40}")
    for c in passed:
        gr, cr, mc = _rank_cells(c)
        # Highlight only the exact Marfan entry, not any code containing "558".
        hi = BOLD + GREEN if _is_marfan(c) else ""
        rs = RESET if hi else ""
        print(f"  {c['rank']:<4}  {c['evidence_score']:>5.3f}  {gr:>3}  {cr:>3}  {mc:>3}  {hi}{c['name'][:44]}{rs}")

    if flagged:
        print(f"\n  {BOLD}{YELLOW}FLAGGED candidates:{RESET}")
        print(f"  {'#':<4}  {'Ev':>5}  {'G':>3}  {'V':>3}  {'M':>3}  Disease  |  Reason")
        print(f"  {'-'*4}  {'-'*5}  {'-'*3}  {'-'*3}  {'-'*3}  {'-'*40}")
        for c in flagged:
            gr, cr, mc = _rank_cells(c)
            # Tolerate a flag_reason that is present but None.
            reason = str(c.get("flag_reason") or "")[:50]
            print(f"  {c['rank']:<4}  {c['evidence_score']:>5.3f}  {gr:>3}  {cr:>3}  {mc:>3}  "
                  f"{c['name'][:30]}  |  {DIM}{reason}{RESET}")

    # -----------------------------------------------------------------------
    # 3. Marfan validation
    # -----------------------------------------------------------------------
    section("[ Validation — Marfan Syndrome (ORPHA:558) ]", GREEN)
    marfan_all    = next((c for c in result["candidates"] if _is_marfan(c)), None)
    marfan_passed = next((c for c in passed if _is_marfan(c)), None)

    if marfan_all:
        print(f"  Overall rank  : #{marfan_all['rank']}  (RRF {marfan_all['rrf_score']:.5f})")
        print(f"  Evidence score: {marfan_all.get('evidence_score', 0):.3f}")
        print(f"  Graph rank    : #{marfan_all['graph_rank']}" if marfan_all.get("graph_rank") else "  Graph rank    : not in graph results")
        print(f"  Chroma rank   : #{marfan_all['chroma_rank']}" if marfan_all.get("chroma_rank") else "  Chroma rank   : not in vector results")
        hpo_matched = marfan_all.get("matched_hpo", [])
        if hpo_matched:
            print(f"  Matched HPO   : {', '.join(h['term'] for h in hpo_matched)}")
        if marfan_all.get("hallucination_flag", False):
            # Interpolate rather than concatenate so a non-string reason cannot raise.
            verdict = f"{RED}FLAGGED — {marfan_all.get('flag_reason', '?')}{RESET}"
        else:
            verdict = f"{GREEN}PASSED{RESET}"
        print(f"  Guard result  : {verdict}")
    else:
        print(f"  {RED}Marfan syndrome not found in any candidates.{RESET}")

    # -----------------------------------------------------------------------
    # Summary
    # -----------------------------------------------------------------------
    top = result["top_diagnosis"]
    print(f"\n{LINE}")
    print(f"{BOLD}Week 3 Part 1 Summary{RESET}")
    print(LINE)

    checks = {
        "Single-word extraction (scoliosis)": scoliosis_found,
        "Hallucination guard active":         bool(flagged) or bool(passed),
        "Marfan in candidates":               marfan_all is not None,
        "Marfan passes guard":                marfan_passed is not None,
    }

    for label, ok in checks.items():
        icon = f"{GREEN}PASS{RESET}" if ok else f"{RED}FAIL{RESET}"
        print(f"  {icon}  {label}")

    print()
    if all(checks.values()):
        print(f"  {BOLD}{GREEN}ALL CHECKS PASSED{RESET}")
    else:
        print(f"  {RED}SOME CHECKS FAILED — review above{RESET}")
        # Non-zero exit status so callers/CI can detect the failure.
        sys.exit(1)

    print(f"\n  Top diagnosis  : {top['name']} (ORPHA:{top['orpha_code']})")
    print(f"  Elapsed        : {result['elapsed_seconds']}s")
    print()

if __name__ == "__main__":
    main()