File size: 4,520 Bytes
5d53fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Quick analysis of MedQA checkpoint data."""
import json

path = "validation/results/medqa_checkpoint.jsonl"
with open(path) as f:
    results = [json.loads(l) for l in f]

print(f"Cases completed: {len(results)}\n")

# ── Table view ──
# One fixed-width row per case: accuracy flags, timing, rank, and the
# (truncated) expected answer vs. the pipeline's top diagnosis.
row_fmt = "{:<12} {:>3} {:>3} {:>4} {:>7} {:>3} {:>4}  {:<15} {:<42} {}"
header = ("ID", "t1", "t3", "diff", "ms", "#dx", "rnk", "match_loc", "correct_answer", "top_diagnosis")
print(row_fmt.format(*header))
print("-" * 145)

for entry in results:
    detail = entry["details"]
    scores = entry["scores"]
    top1_flag = "Y" if scores["top1_accuracy"] else "N"
    top3_flag = "Y" if scores["top3_accuracy"] else "N"
    diff_flag = "Y" if scores.get("differential_accuracy") else "N"
    print(row_fmt.format(
        entry["case_id"],
        top1_flag,
        top3_flag,
        diff_flag,
        entry["pipeline_time_ms"],
        detail.get("num_diagnoses", 0),
        detail.get("found_at_rank", -1),      # -1 = expected answer never ranked
        detail.get("match_location", "?"),
        detail["correct_answer"][:42],
        detail.get("top_diagnosis", "?")[:45],
    ))

print()

# ── Timing analysis ──
def _avg_time_line(label, subset):
    """Print avg pipeline time and hit-rate for *subset*; no-op when empty.

    label  -- left column text, pre-padded so the ms values line up.
    subset -- list of result dicts carrying "pipeline_time_ms".
    """
    if not subset:
        return
    avg = sum(r["pipeline_time_ms"] for r in subset) / len(subset)
    print(f"{label} {avg:.0f}ms  ({len(subset)}/{len(results)} = {len(subset)/len(results)*100:.0f}%)")

correct = [r for r in results if r["scores"]["top1_accuracy"]]
wrong = [r for r in results if not r["scores"]["top1_accuracy"]]
mentioned = [r for r in results if r["scores"].get("mentioned_accuracy")]
top3 = [r for r in results if r["scores"]["top3_accuracy"]]
diff_only = [r for r in results if r["scores"].get("differential_accuracy")]

# Labels padded to the same width so output is byte-identical to before.
_avg_time_line("Correct (top1) avg time:", correct)
_avg_time_line("Correct (top3) avg time:", top3)
_avg_time_line("Differential only:      ", diff_only)
_avg_time_line("Wrong   (top1) avg time:", wrong)
if mentioned:
    print(f"Mentioned anywhere:      {len(mentioned)}/{len(results)}")

# ── Match location breakdown ──
# Where in the pipeline output the correct answer was found (or "not_found").
from collections import Counter  # stdlib; imported here to keep this section self-contained

print("\n=== MATCH LOCATION BREAKDOWN ===")
loc_counts = Counter(r["details"].get("match_location", "not_found") for r in results)
for loc, count in sorted(loc_counts.items()):
    print(f"  {loc:<20} {count:>3} ({count/len(results)*100:.0f}%)")

# ── Detailed per-case (new fields if available) ──
# Dump each case's expected answer, full differential, and (when present)
# the pipeline's next steps / recommendations.
print("\n=== PER-CASE DETAIL ===")
for r in results:
    d = r["details"]
    cid = r["case_id"]
    loc = d.get("match_location", "?")
    ca = d["correct_answer"]
    td = d.get("top_diagnosis", "?")
    all_dx = d.get("all_diagnoses", [td])  # fall back to top dx for old checkpoints
    all_next = d.get("all_next_steps", [])
    all_recs = d.get("all_recommendations", [])
    t1 = "Y" if r["scores"]["top1_accuracy"] else "N"

    print(f"\n  {cid} [t1={t1}, loc={loc}]")
    print(f"    Expected: {ca}")
    print(f"    Differential: {', '.join(all_dx)}")
    if all_next:
        print(f"    Next steps: {'; '.join(all_next[:3])}")
    if all_recs:
        # `rec`, not `r`: avoid shadowing the outer loop variable inside the genexp.
        print(f"    Recommendations: {'; '.join(str(rec)[:60] for rec in all_recs[:3])}")

# ── Answer type vs accuracy ──
# Heuristic split: answers containing a management/action marker are counted
# as management/concept questions; everything else as diagnosis questions.
print("\n=== ANSWER TYPE vs ACCURACY ===")
dx_correct = dx_total = mgmt_correct = mgmt_total = 0
# NOTE: every marker is already lowercase, so matching against a lowercased
# answer needs no per-word .lower() call.
action_words = ["start", "stop", "give", "prescribe", "perform", "order", "refer",
                "increase", "decrease", "switch", "add", "monitor", "observation",
                "reassure", "discharge", "admit", "excess", "adaptation", "exclusion",
                "it is", "right-sided", "affective", "exploratory", "lytic"]
for r in results:
    # Hoist the lowercase conversion out of the any(...) scan (was recomputed
    # once per marker word per case).
    ca_lower = r["details"]["correct_answer"].lower()
    is_dx = not any(w in ca_lower for w in action_words)
    if is_dx:
        dx_total += 1
        if r["scores"]["top1_accuracy"]:
            dx_correct += 1
    else:
        mgmt_total += 1
        if r["scores"]["top1_accuracy"]:
            mgmt_correct += 1

if dx_total:
    print(f"  Diagnosis questions:    {dx_correct}/{dx_total} = {dx_correct/dx_total*100:.0f}%")
if mgmt_total:
    print(f"  Mgmt/concept questions: {mgmt_correct}/{mgmt_total} = {mgmt_correct/mgmt_total*100:.0f}%")

# ── Diagnosis-count summary ──
dx_counts = [r["details"].get("num_diagnoses", 0) for r in results]
if dx_counts:
    print(f"\nDiagnoses generated: min={min(dx_counts)}, max={max(dx_counts)}, avg={sum(dx_counts)/len(dx_counts):.1f}")
else:
    # Guard: min()/max() and the average division all fail on an empty run.
    print("\nDiagnoses generated: no results")