File size: 10,273 Bytes
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f36481
393ff7f
 
 
 
 
 
 
 
 
 
1f36481
 
 
0fe4d92
 
 
1f36481
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f36481
393ff7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
1f36481
 
 
393ff7f
 
 
 
3d02eb2
393ff7f
 
 
 
 
 
 
 
 
 
 
 
3d02eb2
393ff7f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
"""
Unified validation runner for the Clinical Decision Support Agent.

Runs all three dataset validations (MedQA, MTSamples, PMC Case Reports)
and produces a combined summary report.

Usage:
    # From src/backend directory:
    python -m validation.run_validation --all --max-cases 10
    python -m validation.run_validation --medqa --max-cases 20
    python -m validation.run_validation --mtsamples --max-cases 15
    python -m validation.run_validation --pmc --max-cases 10

    # Fetch data only (no pipeline execution):
    python -m validation.run_validation --fetch-only
"""
from __future__ import annotations

import asyncio
import json
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# Ensure backend is importable: this module lives one level below the backend
# root, so prepend that root to sys.path before the `validation.*` imports below.
BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

# Load .env and export HF_TOKEN so huggingface_hub picks it up
from dotenv import load_dotenv
load_dotenv(BACKEND_DIR / ".env")
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    # Re-export explicitly so child libraries that read os.environ (rather
    # than the dotenv file) see the token as well.
    os.environ["HF_TOKEN"] = hf_token

from validation.base import (
    ValidationSummary,
    print_summary,
    save_results,
)
from validation.harness_medqa import fetch_medqa, validate_medqa
from validation.harness_mtsamples import fetch_mtsamples, validate_mtsamples
from validation.harness_pmc import fetch_pmc_cases, validate_pmc


async def _validate_dataset(
    header: str,
    label: str,
    fetch_fn,
    validate_fn,
    *,
    max_cases: int,
    seed: int,
    include_drug_check: bool,
    include_guidelines: bool,
    delay: float,
    fetch_only: bool,
    resume: bool,
):
    """
    Fetch one dataset and, unless ``fetch_only``, validate it.

    Prints the dataset banner, runs ``fetch_fn``/``validate_fn`` with the
    shared options, and prints/saves the per-dataset summary.

    Returns the ValidationSummary, or None in fetch-only mode.
    """
    print("\n" + "=" * 60)
    print(f"  {header}")
    print("=" * 60)

    cases = await fetch_fn(max_cases=max_cases, seed=seed)

    if fetch_only:
        print(f"  Fetched {len(cases)} {label} cases (fetch-only mode)")
        return None

    summary = await validate_fn(
        cases,
        include_drug_check=include_drug_check,
        include_guidelines=include_guidelines,
        delay_between_cases=delay,
        resume=resume,
    )
    print_summary(summary)
    save_results(summary)
    return summary


async def run_all_validations(
    run_medqa: bool = True,
    run_mtsamples: bool = True,
    run_pmc: bool = True,
    max_cases: int = 10,
    seed: int = 42,
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay: float = 2.0,
    fetch_only: bool = False,
    resume: bool = False,
) -> dict:
    """
    Run validation against selected datasets.

    Each enabled dataset is fetched and (unless ``fetch_only``) validated
    via the shared ``_validate_dataset`` helper; a combined report is
    printed and saved when at least one dataset produced a summary.

    Returns dict of {dataset_name: ValidationSummary}
    """
    results = {}
    start = time.time()

    # (enabled, results key, banner header, fetch-only label, fetch fn, validate fn)
    plan = [
        (run_medqa, "medqa",
         "DATASET 1: MedQA (USMLE-style diagnostic accuracy)",
         "MedQA", fetch_medqa, validate_medqa),
        (run_mtsamples, "mtsamples",
         "DATASET 2: MTSamples (clinical note parsing robustness)",
         "MTSamples", fetch_mtsamples, validate_mtsamples),
        (run_pmc, "pmc",
         "DATASET 3: PMC Case Reports (real-world diagnostic accuracy)",
         "PMC", fetch_pmc_cases, validate_pmc),
    ]

    for enabled, key, header, label, fetch_fn, validate_fn in plan:
        if not enabled:
            continue
        summary = await _validate_dataset(
            header,
            label,
            fetch_fn,
            validate_fn,
            max_cases=max_cases,
            seed=seed,
            include_drug_check=include_drug_check,
            include_guidelines=include_guidelines,
            delay=delay,
            fetch_only=fetch_only,
            resume=resume,
        )
        if summary is not None:
            results[key] = summary

    # ── Combined Summary ──
    total_duration = time.time() - start

    if results and not fetch_only:
        _print_combined_summary(results, total_duration)
        _save_combined_report(results, total_duration)

    return results


def _print_combined_summary(results: dict, total_duration: float):
    """Print a cross-dataset report: one headline row per dataset, then all metrics, then totals."""
    print("\n" + "=" * 70)
    print("  COMBINED VALIDATION REPORT")
    print("=" * 70)

    # Table header
    print(f"\n  {'Dataset':<15} {'Cases':>6} {'Success':>8} {'Key Metric':>25} {'Value':>8}")
    print(f"  {'-'*15} {'-'*6} {'-'*8} {'-'*25} {'-'*8}")

    # Headline metric per known dataset; unknown datasets fall back to
    # their first metric key (or "N/A" when there are none).
    headline = {
        "medqa": "top3_accuracy",
        "mtsamples": "parse_success",
        "pmc": "diagnostic_accuracy",
    }

    for dataset, summary in results.items():
        metric_name = headline.get(dataset) or next(iter(summary.metrics), "N/A")
        metric_value = summary.metrics.get(metric_name, 0.0)
        row = (
            f"  {dataset:<15} {summary.total_cases:>6} "
            f"{summary.successful_cases:>8} "
            f"{metric_name:>25} {metric_value:>7.1%}"
        )
        print(row)

    # Full metric dump, one section per dataset
    print(f"\n  {'-' * 66}")
    for dataset, summary in results.items():
        print(f"\n  {dataset.upper()} metrics:")
        for metric, value in sorted(summary.metrics.items()):
            # Timing metrics render as milliseconds; other floats as percentages.
            if "time" in metric and isinstance(value, (int, float)):
                print(f"    {metric:<35} {value:.0f}ms")
            elif isinstance(value, float):
                print(f"    {metric:<35} {value:.1%}")

    # Grand totals
    total_cases = sum(summary.total_cases for summary in results.values())
    total_success = sum(summary.successful_cases for summary in results.values())
    print(f"\n  Total cases:     {total_cases}")
    print(f"  Total success:   {total_success}")
    print(f"  Total duration:  {total_duration:.1f}s ({total_duration/60:.1f}min)")
    print(f"  Timestamp:       {datetime.now(timezone.utc).isoformat()}")
    print("=" * 70)


def _save_combined_report(results: dict, total_duration: float):
    """Write the combined cross-dataset report to a timestamped JSON file under ./results."""
    # Reports live in a "results" directory alongside this module.
    out_dir = Path(__file__).resolve().parent / "results"
    out_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    report_path = out_dir / f"combined_{stamp}.json"

    payload = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "total_duration_sec": total_duration,
        "datasets": {
            name: {
                "total_cases": summary.total_cases,
                "successful_cases": summary.successful_cases,
                "failed_cases": summary.failed_cases,
                "metrics": summary.metrics,
                "run_duration_sec": summary.run_duration_sec,
            }
            for name, summary in results.items()
        },
    }

    # default=str stringifies any non-JSON-native values (e.g. datetimes) in metrics.
    report_path.write_text(json.dumps(payload, indent=2, default=str))
    print(f"\n  Combined report saved to: {report_path}")


def main():
    """CLI entry point: parse arguments, print the run banner, and launch the suite."""
    import argparse

    parser = argparse.ArgumentParser(
        description="CDS Agent Validation Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m validation.run_validation --all --max-cases 10
  python -m validation.run_validation --medqa --max-cases 50
  python -m validation.run_validation --fetch-only
  python -m validation.run_validation --medqa --pmc --max-cases 20 --no-drugs
        """,
    )

    # Dataset selection flags (all boolean toggles)
    datasets = parser.add_argument_group("Datasets")
    for flag, text in (
        ("--all", "Run all three datasets"),
        ("--medqa", "Run MedQA validation"),
        ("--mtsamples", "Run MTSamples validation"),
        ("--pmc", "Run PMC Case Reports validation"),
    ):
        datasets.add_argument(flag, action="store_true", help=text)

    # Run configuration
    config = parser.add_argument_group("Configuration")
    config.add_argument("--max-cases", type=int, default=10, help="Cases per dataset (default: 10)")
    config.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    config.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
    config.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
    config.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    config.add_argument("--resume", action="store_true", help="Resume from checkpoint (skip already-completed cases)")
    config.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")

    args = parser.parse_args()

    # With no dataset flag at all, behave as if --all was passed.
    if not (args.all or args.medqa or args.mtsamples or args.pmc):
        args.all = True

    selected = {
        "medqa": args.all or args.medqa,
        "mtsamples": args.all or args.mtsamples,
        "pmc": args.all or args.pmc,
    }

    dataset_line = "".join(
        label + " "
        for label, enabled in (
            ("MedQA", selected["medqa"]),
            ("MTSamples", selected["mtsamples"]),
            ("PMC", selected["pmc"]),
        )
        if enabled
    )

    print("=" * 58)
    print("   Clinical Decision Support Agent - Validation Suite")
    print("=" * 58)
    print(f"\n  Datasets:     {dataset_line}")
    print(f"  Cases/dataset: {args.max_cases}")
    print(f"  Drug check:    {'No' if args.no_drugs else 'Yes'}")
    print(f"  Guidelines:    {'No' if args.no_guidelines else 'Yes'}")
    print(f"  Resume:        {'Yes' if args.resume else 'No'}")
    print(f"  Fetch only:    {'Yes' if args.fetch_only else 'No'}")

    asyncio.run(
        run_all_validations(
            run_medqa=selected["medqa"],
            run_mtsamples=selected["mtsamples"],
            run_pmc=selected["pmc"],
            max_cases=args.max_cases,
            seed=args.seed,
            include_drug_check=not args.no_drugs,
            include_guidelines=not args.no_guidelines,
            delay=args.delay,
            fetch_only=args.fetch_only,
            resume=args.resume,
        )
    )


# Script entry point: `python -m validation.run_validation [flags]`
if __name__ == "__main__":
    main()