bshepp committed
Commit 1f36481 · 1 Parent(s): 3d02eb2

MedGemma validation: 50-case MedQA run, TGI endpoint config, prompt improvements


- Fixed TGI token limits (MAX_INPUT_TOKENS=12288, MAX_TOTAL_TOKENS=16384)
- Reduced per-step max_tokens for faster generation
- Improved clinical reasoning prompt (disease-level dx, not symptoms)
- Fixed Unicode encoding issues for Windows console
- Fixed error masking in orchestrator (failed steps now surface errors)
- Fixed endpoint URL in .env
- Added analyze_results.py for question-type categorization
- Results: 94% pipeline success, 38% top3 accuracy, 14% dx-only accuracy
- Paused endpoint to save costs

README.md CHANGED
@@ -333,6 +333,6 @@ curl -X POST http://localhost:8000/api/cases/submit \
 
 Licensed under the [Apache License 2.0](LICENSE).
 
-This project uses the Gemma model, which is subject to the [HAI-DEF Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms).
+This project uses MedGemma and other models from Google's [Health AI Developer Foundations (HAI-DEF)](https://developers.google.com/health-ai-developer-foundations), subject to the [HAI-DEF Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms).
 
 > **Disclaimer:** This is a research / demonstration system. It is NOT a substitute for professional medical judgment. All clinical decisions must be made by qualified healthcare professionals.
docs/deploy_medgemma_hf.md ADDED
@@ -0,0 +1,141 @@
+# Deploying MedGemma 27B on HuggingFace Dedicated Endpoints
+
+This guide walks through deploying `google/medgemma-27b-text-it` as a
+HuggingFace Dedicated Inference Endpoint, which our CDS Agent calls via an
+OpenAI-compatible API.
+
+## Why HuggingFace Endpoints?
+
+| Feature | Details |
+|---|---|
+| **Model** | `google/medgemma-27b-text-it` (HAI-DEF, competition-required) |
+| **Cost** | ~$2.50/hr (1× A100 80 GB on AWS) |
+| **Scale-to-zero** | Yes — no charges while idle |
+| **API format** | OpenAI-compatible (TGI) — zero code changes |
+| **Setup time** | ~10 minutes |
+
+## Prerequisites
+
+1. **HuggingFace account** with a valid payment method.
+2. **MedGemma access** — accept the gated-model terms at
+   <https://huggingface.co/google/medgemma-27b-text-it>. You must agree to
+   Google's Health AI Developer Foundations (HAI-DEF) license.
+3. A **HuggingFace token** with `read` scope (already in `.env` as `HF_TOKEN`).
+
+## Step-by-step Deployment
+
+### 1. Create the endpoint
+
+1. Go to <https://ui.endpoints.huggingface.co/new>.
+2. **Model Repository**: `google/medgemma-27b-text-it`
+3. **Cloud Provider**: AWS (cheapest) or GCP
+4. **Region**: `us-east-1` (AWS) or `us-central1` (GCP)
+5. **Instance type**: GPU — **1× NVIDIA A100 80 GB**
+   - AWS: ~$2.50/hr
+   - GCP: ~$3.60/hr
+6. **Container type**: Text Generation Inference (TGI) — this is the default.
+7. **Advanced Settings**:
+   - **Max Input Length**: `32768`
+   - **Max Total Tokens**: `40960`
+   - **Quantization**: `none` (bfloat16 fits in 80 GB)
+   - **Scale-to-zero**: **Enable** (idle timeout: 15 min recommended)
+8. Click **Create Endpoint**.
+
+### 2. Wait for the endpoint to become ready
+
+The first deployment downloads the model weights (~54 GB) and starts the TGI
+server. This typically takes **5–15 minutes**. The status will change from
+`Initializing` → `Running`.
+
+### 3. Configure the CDS Agent
+
+Edit `src/backend/.env`:
+
+```dotenv
+MEDGEMMA_API_KEY=hf_YOUR_TOKEN_HERE
+MEDGEMMA_BASE_URL=https://YOUR_ENDPOINT_ID.us-east-1.aws.endpoints.huggingface.cloud/v1
+MEDGEMMA_MODEL_ID=tgi
+```
+
+- **`MEDGEMMA_API_KEY`**: Your HuggingFace token (same as `HF_TOKEN`).
+- **`MEDGEMMA_BASE_URL`**: The endpoint URL from the HF dashboard, with `/v1`
+  appended. Example:
+  `https://x1y2z3.us-east-1.aws.endpoints.huggingface.cloud/v1`
+- **`MEDGEMMA_MODEL_ID`**: Use `tgi` — TGI exposes the model under this name
+  by default. Alternatively, you can use the full model name
+  `google/medgemma-27b-text-it`.
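Before starting the backend against a fresh endpoint, it can be worth sanity-checking these three variables. A minimal sketch (the `check_medgemma_env` helper is hypothetical, not part of the repo):

```python
def check_medgemma_env(env: dict) -> list[str]:
    """Return a list of config problems; an empty list means the values look sane."""
    problems = [
        f"{key} is missing"
        for key in ("MEDGEMMA_API_KEY", "MEDGEMMA_BASE_URL", "MEDGEMMA_MODEL_ID")
        if not env.get(key)
    ]
    url = env.get("MEDGEMMA_BASE_URL", "")
    # The OpenAI client needs the /v1 suffix on the endpoint URL
    if url and not url.rstrip("/").endswith("/v1"):
        problems.append("MEDGEMMA_BASE_URL should end with /v1")
    return problems
```

Running it against `os.environ` after loading `.env` catches the most common misconfiguration, a missing `/v1` suffix, before any request is made.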
+
+### 4. Verify the connection
+
+```bash
+cd src/backend
+python -c "
+import asyncio
+from app.services.medgemma import MedGemmaService
+
+async def test():
+    svc = MedGemmaService()
+    r = await svc.generate('What is the differential diagnosis for chest pain?')
+    print(r[:200])
+
+asyncio.run(test())
+"
+```
+
+You should see a clinical response from MedGemma.
+
+### 5. Run validation
+
+```bash
+cd src/backend
+python -m validation.run_validation --medqa --max-cases 50 --seed 42 --delay 2
+```
+
+## Cost Estimation
+
+| Scenario | Hours | Cost |
+|---|---|---|
+| Validation run (120 cases @ ~1 min/case) | ~2 hrs | ~$5 |
+| Development / debugging (4 hrs) | ~4 hrs | ~$10 |
+| Competition demo recording | ~1 hr | ~$2.50 |
+| **Total estimated** | **~7 hrs** | **~$17.50** |
+
+With scale-to-zero enabled, the endpoint automatically shuts down after 15 min
+of inactivity — no overnight charges.
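The totals in the table are straight hourly arithmetic; a quick check (rate from the table above, hour counts are estimates):

```python
A100_RATE = 2.50  # $/hr for 1x A100 80 GB on AWS, per the table above
hours = {"validation": 2, "development": 4, "demo": 1}  # estimated hours per activity
total_hours = sum(hours.values())
total_cost = total_hours * A100_RATE
print(f"{total_hours} hrs x ${A100_RATE:.2f}/hr = ${total_cost:.2f}")  # 7 hrs x $2.50/hr = $17.50
```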
+
+## Troubleshooting
+
+### Cold start latency
+After scaling to zero, the first request takes 5–15 min while the model
+reloads. Send a warm-up request before benchmarking.
+
+### 403 Forbidden
+Your HF token may not have access to the gated model. Verify at
+<https://huggingface.co/google/medgemma-27b-text-it> that your account has been
+granted access.
+
+### Out of memory
+If the endpoint fails to start, ensure you selected the **80 GB** A100, not the
+40 GB variant. MedGemma 27B in bfloat16 requires ~54 GB VRAM.
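The ~54 GB figure is just parameter count times bytes per parameter (weights only; the KV cache and activations need additional headroom on top):

```python
params = 27e9          # MedGemma 27B parameter count
bytes_per_param = 2    # bfloat16 stores each parameter in 2 bytes
weights_gb = params * bytes_per_param / 1e9
print(f"~{weights_gb:.0f} GB of weights")  # ~54 GB, so a 40 GB card cannot hold them
```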
+
+### "model not found" error
+TGI exposes the model as `tgi` by default. If you get a model-not-found error,
+try setting `MEDGEMMA_MODEL_ID=google/medgemma-27b-text-it` or check the
+endpoint's `/v1/models` route.
+
+## Deleting the Endpoint
+
+When you're done, delete the endpoint from the HF dashboard to stop all
+charges:
+
+1. Go to <https://ui.endpoints.huggingface.co/>
+2. Select your endpoint → **Settings** → **Delete**
+
+## Comparison with Alternatives
+
+| Platform | GPU | $/hr | Scale-to-Zero | Code Changes | Setup |
+|---|---|---|---|---|---|
+| **HF Endpoints** | 1× A100 80 GB | **$2.50** | **Yes** | **None** | **Easy** |
+| Vertex AI | a2-ultragpu-1g | $5.78 | No | Medium | Medium |
+| AWS EC2 (g5.12xlarge) | 4× A10G 96 GB | $5.67 | No (manual) | High | Hard |
+| AWS EC2 (p4de.24xlarge) | 8× A100 80 GB | $27.45 | No (manual) | High | Hard |
docs/writeup_draft.md CHANGED
@@ -53,15 +53,16 @@ Estimated reach: There are approximately 140 million ED visits per year in the U
 
 **HAI-DEF models used:**
 
-- **Gemma 3 27B IT** (`gemma-3-27b-it`) — accessed via Google AI Studio's OpenAI-compatible endpoint
+- **MedGemma** (`google/medgemma-27b-text-it`) — Google's medical-domain model from the Health AI Developer Foundations (HAI-DEF) collection
+- Development/validation also performed with **Gemma 3 27B IT** (`gemma-3-27b-it`) via Google AI Studio for rapid iteration
 
-**Why this model:**
+**Why MedGemma:**
 
-Gemma 3 27B IT provides the right balance of capability and accessibility for a clinical decision support application:
-- Large enough to perform complex clinical reasoning with chain-of-thought transparency
+MedGemma is purpose-built for medical applications and is part of Google's HAI-DEF collection:
+- Trained specifically for health and biomedical tasks, providing stronger clinical reasoning than general-purpose models
 - Open-weight model that can be self-hosted for HIPAA compliance in production
-- Available via API for rapid development and demonstration
-- Part of the HAI-DEF family, designed with health AI applications in mind
+- Large enough (27B parameters) for complex chain-of-thought clinical reasoning
+- Designed to be the foundation for healthcare AI applications, exactly what this competition demands
 
 **How the model is used:**

@@ -96,7 +97,7 @@ All inter-step data is strongly typed with Pydantic v2 models. The pipeline stre
 
 **Fine-tuning:**
 
-No fine-tuning was performed in the current version. The base `gemma-3-27b-it` model was used with carefully crafted prompt engineering for each pipeline step. Fine-tuning on clinical reasoning datasets is a planned improvement.
+No fine-tuning was performed in the current version. The base MedGemma model (`medgemma-27b-text-it`) was used with carefully crafted prompt engineering for each pipeline step. Fine-tuning on clinical reasoning datasets is a planned improvement.
 
 **Performance analysis:**

@@ -113,13 +114,13 @@ No fine-tuning was performed in the current version. The base `gemma-3-27b-it` m
 |-------|-----------|
 | Frontend | Next.js 14, React 18, TypeScript, Tailwind CSS |
 | Backend | FastAPI, Python 3.10, Pydantic v2, WebSocket |
-| LLM | Gemma 3 27B IT via Google AI Studio |
+| LLM | MedGemma 27B Text IT (HAI-DEF) + Gemma 3 27B IT for dev |
 | RAG | ChromaDB + sentence-transformers (all-MiniLM-L6-v2) |
 | Drug Data | OpenFDA API, RxNorm / NLM API |
 
 **Deployment considerations:**
 
-- **HIPAA compliance:** Gemma is an open-weight model that can be self-hosted on-premises, eliminating the need to send patient data to external APIs. This is critical for healthcare deployment.
+- **HIPAA compliance:** MedGemma is an open-weight model that can be self-hosted on-premises, eliminating the need to send patient data to external APIs. This is critical for healthcare deployment.
 - **Latency:** Current pipeline takes ~75 s end-to-end. For production, this could be reduced with: smaller/distilled models, parallel LLM calls, or GPU-accelerated inference.
 - **Scalability:** FastAPI + uvicorn supports async request handling. For high-throughput deployment, add worker processes and a task queue (e.g., Celery).
 - **EHR integration:** Current input is manual text paste. A production system would integrate with EHR systems via FHIR APIs for automatic patient data extraction.

@@ -163,4 +164,4 @@ The system is explicitly designed as a **decision support** tool, not a decision
 - Video: [To be recorded]
 - Code Repository: [github.com/bshepp/clinical-decision-support-agent](https://github.com/bshepp/clinical-decision-support-agent)
 - Live Demo: [To be deployed]
-- Hugging Face Model: N/A (using base Gemma 3 27B IT)
+- Hugging Face Model: [google/medgemma-27b-text-it](https://huggingface.co/google/medgemma-27b-text-it)
src/backend/analyze_results.py ADDED
@@ -0,0 +1,189 @@
+"""
+Post-analysis of MedQA validation results.
+
+Categorizes questions by type and reports accuracy for each category.
+This is important because the CDS pipeline focuses on DIAGNOSIS, while
+MedQA includes many non-diagnostic questions (pharmacology, management,
+biostatistics, pathophysiology).
+"""
+import json
+import re
+from collections import defaultdict
+from pathlib import Path
+
+CHECKPOINT = Path("validation/results/medqa_checkpoint.jsonl")
+DATA_FILE = Path("validation/data/medqa_test.jsonl")
+
+
+def classify_answer(correct_answer: str, full_question: str = "") -> str:
+    """Classify the MedQA answer type.
+
+    Categories:
+    - diagnosis: Answer is a disease, condition, or syndrome
+    - treatment: Answer is a drug, procedure, or intervention
+    - management: Answer is a management strategy (reassurance, referral, etc.)
+    - pathophysiology: Answer is a mechanism, pathway, or biochemical entity
+    - statistics: Answer is about study design, statistics, or epidemiology
+    - anatomy: Answer is about anatomy/location
+    - other: Everything else
+    """
+    answer = correct_answer.lower().strip()
+    question = full_question.lower()
+
+    # Statistics / study design
+    stats_patterns = [
+        r"type [12] error", r"null hypothesis", r"p.value", r"confidence interval",
+        r"odds ratio", r"relative risk", r"sensitivity", r"specificity",
+        r"positive predictive", r"negative predictive", r"number needed",
+        r"standard deviation", r"study design", r"randomized", r"case.control",
+        r"cohort study", r"cross.sectional", r"meta.analysis", r"selection bias",
+        r"recall bias", r"confounding", r"blinding", r"power of",
+    ]
+    for p in stats_patterns:
+        if re.search(p, answer) or re.search(p, question):
+            return "statistics"
+
+    # Treatment / pharmacology (drugs, procedures, interventions)
+    treatment_patterns = [
+        r"^start\b", r"^administer\b", r"^give\b", r"^prescribe\b",
+        r"^begin\b", r"^initiate\b", r"surgery", r"laparotomy",
+        r"laparoscop", r"analgesia", r"^reassurance", r"^observation",
+        r"^follow.up", r"^refer", r"^discharge",
+        r"corticosteroid", r"hydrocortisone", r"fludrocortisone",
+        r"prednisone", r"methylprednisolone", r"dexamethasone",
+        r"amitriptyline", r"fluoxetine", r"sertraline", r"metformin",
+        r"insulin", r"heparin", r"warfarin", r"aspirin",
+        r"amoxicillin", r"azithromycin", r"ceftriaxone",
+        r"exploratory", r"endoscop",
+    ]
+    for p in treatment_patterns:
+        if re.search(p, answer):
+            return "treatment"
+
+    # Management strategies
+    management_patterns = [
+        r"reassurance", r"watchful waiting", r"follow.up", r"counseling",
+        r"lifestyle", r"observation", r"monitor", r"admit",
+        r"discharge", r"consult",
+    ]
+    for p in management_patterns:
+        if re.search(p, answer):
+            return "management"
+
+    # Pathophysiology / biochemistry
+    patho_patterns = [
+        r"prostaglandin", r"acetaldehyde", r"histamine", r"serotonin",
+        r"dopamine", r"cytokine", r"interleukin", r"antibod",
+        r"complement", r"release of", r"synthesis of", r"inhibition of",
+        r"degradation of", r"mutation in", r"deficiency of",
+        r"mechanism", r"pathway", r"receptor", r"kinase",
+        r"affective symptoms", r"diagnosis of exclusion",
+    ]
+    for p in patho_patterns:
+        if re.search(p, answer):
+            return "pathophysiology"
+
+    # Anatomy
+    anatomy_patterns = [
+        r"lytic lesions", r"fracture", r"artery", r"vein",
+        r"nerve", r"muscle", r"bone", r"ligament",
+        r"right.sided", r"left.sided", r"posterior", r"anterior",
+    ]
+    for p in anatomy_patterns:
+        if re.search(p, answer):
+            return "anatomy"
+
+    # Default: assume it's a diagnosis
+    return "diagnosis"
+
+
+def analyze():
+    if not CHECKPOINT.exists():
+        print("No checkpoint file found. Run validation first.")
+        return
+
+    # Load results
+    results = []
+    for line in CHECKPOINT.read_text(encoding="utf-8").strip().split("\n"):
+        if line.strip():
+            results.append(json.loads(line))
+
+    # Load original questions for classification
+    questions = {}
+    if DATA_FILE.exists():
+        raw = DATA_FILE.read_text(encoding="utf-8").strip().split("\n")
+        for item_str in raw:
+            item = json.loads(item_str)
+            questions[item.get("question", "")] = item
+
+    # Classify and categorize
+    categories = defaultdict(list)
+
+    for r in results:
+        det = r.get("details", {})
+        correct = det.get("correct_answer", "")
+        full_q = det.get("full_question", "")
+
+        # Try to get the full question from the ground truth
+        if not full_q:
+            for case_key in r.get("ground_truth", {}).keys():
+                pass  # Fallback
+
+        cat = classify_answer(correct, full_q)
+        categories[cat].append(r)
+
+    # Print summary
+    print("=" * 70)
+    print(" MedQA RESULTS BY QUESTION CATEGORY")
+    print("=" * 70)
+
+    total_cases = len(results)
+    total_mentioned = sum(1 for r in results if r.get("details", {}).get("match_location", "not_found") != "not_found")
+    total_diff = sum(1 for r in results if r.get("details", {}).get("match_location") == "differential")
+
+    print(f"\n OVERALL: {total_cases} cases | Mentioned: {total_mentioned}/{total_cases} ({100*total_mentioned/total_cases:.0f}%) | Differential: {total_diff}/{total_cases} ({100*total_diff/total_cases:.0f}%)")
+
+    print(f"\n {'Category':<20} {'Count':>6} {'Mentioned':>10} {'Differential':>13} {'Pipeline OK':>12}")
+    print(f" {'-'*20} {'-'*6} {'-'*10} {'-'*13} {'-'*12}")
+
+    for cat in sorted(categories.keys()):
+        items = categories[cat]
+        n = len(items)
+        mentioned = sum(1 for r in items if r.get("details", {}).get("match_location", "not_found") != "not_found")
+        differential = sum(1 for r in items if r.get("details", {}).get("match_location") == "differential")
+        success = sum(1 for r in items if r.get("success"))
+
+        mentioned_pct = f"{100*mentioned/n:.0f}%" if n > 0 else "N/A"
+        diff_pct = f"{100*differential/n:.0f}%" if n > 0 else "N/A"
+        success_pct = f"{100*success/n:.0f}%" if n > 0 else "N/A"
+
+        print(f" {cat:<20} {n:>6} {mentioned:>5} ({mentioned_pct:>4}) {differential:>7} ({diff_pct:>4}) {success:>6} ({success_pct:>4})")
+
+    # Detailed per-case
+    print(f"\n DETAILED PER-CASE RESULTS:")
+    print(f" {'Case':<14} {'Cat':<15} {'Location':<14} {'Correct':<35} {'Top Dx':<35}")
+    print(f" {'-'*14} {'-'*15} {'-'*14} {'-'*35} {'-'*35}")
+
+    for r in results:
+        det = r.get("details", {})
+        correct = det.get("correct_answer", "?")[:34]
+        top = det.get("top_diagnosis", "?")[:34]
+        loc = det.get("match_location", "not_found")
+        cat = classify_answer(det.get("correct_answer", ""))
+
+        print(f" {r['case_id']:<14} {cat:<15} {loc:<14} {correct:<35} {top:<35}")
+
+    # Key insight
+    diag_items = categories.get("diagnosis", [])
+    if diag_items:
+        d_mentioned = sum(1 for r in diag_items if r.get("details", {}).get("match_location", "not_found") != "not_found")
+        d_diff = sum(1 for r in diag_items if r.get("details", {}).get("match_location") == "differential")
+        d_n = len(diag_items)
+        print(f"\n KEY INSIGHT:")
+        print(f" On DIAGNOSTIC questions only: Mentioned {d_mentioned}/{d_n} ({100*d_mentioned/d_n:.0f}%), Differential {d_diff}/{d_n} ({100*d_diff/d_n:.0f}%)")
+        print(f" The CDS pipeline is designed for diagnosis support; non-diagnostic questions")
+        print(f" (treatment, stats, pathophysiology) are outside its intended scope.")
+
+
+if __name__ == "__main__":
+    analyze()
src/backend/app/config.py CHANGED
@@ -27,6 +27,7 @@ class Settings(BaseSettings):
     rxnorm_base_url: str = "https://rxnav.nlm.nih.gov/REST"
     pubmed_base_url: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
     pubmed_api_key: str = ""  # Optional, increases rate limits
+    hf_token: str = ""  # HuggingFace token for dataset downloads
 
     # RAG
     chroma_persist_dir: str = "./data/chroma"
src/backend/app/services/medgemma.py CHANGED
@@ -125,25 +125,44 @@ class MedGemmaService:
     async def _generate_api(
         self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
     ) -> str:
-        """Generate via OpenAI-compatible API."""
+        """Generate via OpenAI-compatible API.
+
+        MedGemma (served by TGI on HuggingFace Endpoints) natively supports the
+        system role, so we send system/user messages properly. If the backend
+        happens to be plain Gemma on Google AI Studio (which rejects the system
+        role), we automatically fall back to folding the system prompt into the
+        user message.
+        """
         client = await self._get_client()
 
         messages = []
-        # Some models (e.g. Gemma via Google AI Studio) don't support system role.
-        # Try with system prompt first, fall back to folding it into the user message.
         if system_prompt:
-            user_content = f"{system_prompt}\n\n{prompt}"
-        else:
-            user_content = prompt
-        messages.append({"role": "user", "content": user_content})
-
-        response = await client.chat.completions.create(
-            model=settings.medgemma_model_id,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-        )
-        return response.choices[0].message.content
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        try:
+            response = await client.chat.completions.create(
+                model=settings.medgemma_model_id,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            # Fallback: fold system prompt into user message (Google AI Studio compat)
+            if system_prompt and "system" in str(e).lower():
+                logger.warning("Backend rejected system role — folding into user message.")
+                fallback_messages = [
+                    {"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
+                ]
+                response = await client.chat.completions.create(
+                    model=settings.medgemma_model_id,
+                    messages=fallback_messages,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                )
+                return response.choices[0].message.content
+            raise
 
     async def _generate_local(
         self, prompt: str, system_prompt: Optional[str], max_tokens: int, temperature: float
src/backend/app/tools/clinical_reasoning.py CHANGED
@@ -16,16 +16,21 @@ from app.services.medgemma import MedGemmaService
 
 logger = logging.getLogger(__name__)
 
-SYSTEM_PROMPT = """You are an expert clinical reasoning assistant. Given a structured
-patient profile, perform systematic clinical reasoning to generate a differential
-diagnosis, risk assessment, and recommended workup.
-
-IMPORTANT GUIDELINES:
-- Think step-by-step through the clinical reasoning process
+SYSTEM_PROMPT = """You are an expert clinical reasoning assistant trained in USMLE-level
+diagnostic reasoning. Given a structured patient profile, perform systematic clinical
+reasoning to generate a differential diagnosis, risk assessment, and recommended workup.
+
+CRITICAL GUIDELINES:
+- Each diagnosis MUST be a specific DISEASE or PATHOLOGICAL CONDITION (the root cause),
+  NOT a symptom, sign, lab finding, or descriptive term.
+  GOOD: "Primary hyperaldosteronism (Conn syndrome)", "Chikungunya fever",
+        "Clear cell adenocarcinoma of the cervix"
+  BAD: "Hypokalemia", "Fatigue", "Metabolic alkalosis", "Muscle cramps"
+- Think step-by-step: symptoms -> pathophysiology -> ETIOLOGICAL diagnosis
 - Consider the most likely diagnoses first, then less common but important ones
 - Always consider dangerous "can't miss" diagnoses
-- Base your reasoning on the available evidence (symptoms, labs, history)
-- Be explicit about your reasoning chain
+- Include at least 5 differential diagnoses when clinically reasonable
+- For each diagnosis, cite the specific findings that support or argue against it
 - Rate likelihood as "low", "moderate", or "high"
 - Rate priority of actions as "low", "moderate", "high", or "critical"
 - This is a decision SUPPORT tool — always recommend clinician judgment"""

@@ -86,7 +91,7 @@ class ClinicalReasoningTool:
             response_model=ClinicalReasoningResult,
             system_prompt=SYSTEM_PROMPT,
             temperature=0.3,
-            max_tokens=4096,
+            max_tokens=3072,
         )
 
         logger.info(
src/backend/app/tools/conflict_detection.py CHANGED
@@ -125,7 +125,7 @@ class ConflictDetectionTool:
             response_model=ConflictDetectionResult,
             system_prompt=SYSTEM_PROMPT,
             temperature=0.1,  # Low temp for safety-critical analysis
-            max_tokens=4096,
+            max_tokens=2000,
         )
 
         # Fill in metadata
src/backend/app/tools/patient_parser.py CHANGED
@@ -61,6 +61,7 @@ class PatientParserTool:
             response_model=PatientProfile,
             system_prompt=SYSTEM_PROMPT,
             temperature=0.1,  # Low temp for factual extraction
+            max_tokens=1500,
         )
         logger.info(f"Parsed patient profile: {profile.chief_complaint}")
         return profile
src/backend/app/tools/synthesis.py CHANGED
@@ -22,47 +22,75 @@ from app.services.medgemma import MedGemmaService
 
 logger = logging.getLogger(__name__)
 
-SYSTEM_PROMPT = """You are a clinical decision support synthesis engine. Your job is to
-combine outputs from multiple clinical tools into a single, cohesive report for a clinician.
-
-CRITICAL RULES:
-1. Be concise and clinically precise
-2. Prioritize safety — drug interactions and critical findings go first
-3. Clearly distinguish between tool-verified facts and model-generated reasoning
-4. Always include caveats and limitations
-5. Cite sources when available
-6. This report SUPPORTS clinical decision-making — it does NOT replace clinician judgment
-7. Include a standard disclaimer about AI-generated content"""
-
-SYNTHESIS_PROMPT = """Synthesize the following clinical tool outputs into a cohesive
-Clinical Decision Support report.
+SYSTEM_PROMPT = """You are an expert clinical arbiter and decision support engine. You receive
+an initial differential diagnosis from a clinical reasoning agent, PLUS independent evidence
+from drug-interaction checks, clinical guideline retrieval, and conflict detection.
+
+Your job is NOT merely to format these outputs. You are the FINAL DECISION MAKER:
+1. CRITICALLY RE-EVALUATE the initial differential using ALL available evidence.
+2. RE-RANK diagnoses: promote diagnoses that gain guideline/drug/conflict support;
+   demote diagnoses that lose support or are contradicted.
+3. ADD any diagnosis that the evidence strongly suggests but was MISSING from the initial list.
+4. REMOVE or deprioritize diagnoses that are inconsistent with guideline-based evidence.
+5. For the top diagnosis, explicitly state which evidence (guideline excerpts, drug signals,
+   conflict findings) supports or contradicts it.
+6. Prioritize safety — drug interactions and critical conflicts go first.
+7. This report SUPPORTS clinical decision-making — it does NOT replace clinician judgment.
+8. Be concise and clinically precise. Cite sources.
+
+You are an independent reviewer, not a rubber stamp. If the initial reasoning is wrong,
+override it with evidence-based conclusions."""
+
+SYNTHESIS_PROMPT = """You are given outputs from multiple independent clinical analysis tools.
+Your task is to act as an ARBITER: critically evaluate all evidence and produce a final,
+evidence-based Clinical Decision Support report.
 
 ═══ PATIENT PROFILE ═══
 {patient_profile}
 
-═══ CLINICAL REASONING (MedGemma) ═══
+═══ INITIAL CLINICAL REASONING (from reasoning agent) ═══
 {clinical_reasoning}
 
-═══ DRUG INTERACTION CHECK ═══
+═══ DRUG INTERACTION CHECK (independent tool) ═══
 {drug_interactions}
 
-═══ CLINICAL GUIDELINES ═══
+═══ CLINICAL GUIDELINES (RAG retrieval — independent evidence) ═══
 {guidelines}
 
-═══ CONFLICTS & GAPS DETECTED ═══
+═══ CONFLICTS & GAPS DETECTED (independent analysis) ═══
 {conflicts}
 
-Create a comprehensive CDS report including:
+══════════════════════════════════════
+ARBITRATION INSTRUCTIONS — Follow these steps:
+══════════════════════════════════════
+
+STEP 1 — CHALLENGE THE INITIAL DIFFERENTIAL:
+For each diagnosis in the initial reasoning, ask:
+• Does the guideline evidence SUPPORT or CONTRADICT this diagnosis?
+• Do the drug interactions or conflict findings change the likelihood?
+• Is there a diagnosis NOT in the initial list that the guidelines strongly suggest?
+
+STEP 2 — RE-RANK AND REVISE:
+Produce a REVISED differential diagnosis list. This may differ from the initial one.
+• Promote diagnoses with strong guideline concordance.
+• Demote diagnoses contradicted by evidence.
+• Add new diagnoses suggested by guideline/conflict evidence.
+• For each diagnosis, state the supporting AND contradicting evidence.
+
+STEP 3 — PRODUCE THE FINAL REPORT:
 1. Patient Summary — concise summary of the case
-2. Differential Diagnosis — ranked with reasoning, integrating guideline concordance
+2. Differential Diagnosis — YOUR REVISED ranking (not just a copy of the initial one),
+   with explicit evidence citations for each diagnosis
 3. Drug Interaction Warnings — any flagged interactions with clinical significance
 4. Guideline-Concordant Recommendations — actionable steps aligned with guidelines
-5. Conflicts & Gaps — PROMINENTLY include every detected conflict. For each conflict,
-   state what the guideline recommends, what the patient's current state is, and the
-   suggested resolution. This section is CRITICAL for patient safety.
-6. Suggested Next Steps — prioritized actions for the clinician, incorporating conflict resolutions
-7. Caveats — limitations, uncertainties, and important disclaimers
-8. Sources — cited guidelines and data sources used"""
 
 
 class SynthesisTool:
@@ -104,7 +132,7 @@ class SynthesisTool:
             response_model=CDSReport,
             system_prompt=SYSTEM_PROMPT,
             temperature=0.2,
-            max_tokens=4096,
         )
 
         # Add standard disclaimer to caveats
84
  3. Drug Interaction Warnings — any flagged interactions with clinical significance
85
  4. Guideline-Concordant Recommendations — actionable steps aligned with guidelines
86
+ 5. Conflicts & Gaps — PROMINENTLY include every detected conflict. For each:
87
+ state what the guideline recommends vs. patient's current state, and the resolution.
88
+ 6. Suggested Next Steps prioritized actions incorporating ALL evidence
89
+ 7. Caveatslimitations, uncertainties, disclaimers
90
+ 8. Sourcescited guidelines and data sources
91
+
92
+ IMPORTANT: Your differential diagnosis MUST reflect your independent arbiter judgment,
93
+ not merely repeat the initial reasoning. If evidence changes the ranking, CHANGE IT."""
94
 
95
 
96
  class SynthesisTool:
 
132
  response_model=CDSReport,
133
  system_prompt=SYSTEM_PROMPT,
134
  temperature=0.2,
135
+ max_tokens=3000,
136
  )
137
 
138
  # Add standard disclaimer to caveats
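The SYNTHESIS_PROMPT above is a plain str.format template with five placeholders. A trimmed sketch of how it would render (the actual call site is not in this diff, and the clinical values below are invented for illustration):

```python
# Sketch only: a trimmed SYNTHESIS_PROMPT rendered the way a
# str.format-based pipeline would fill it. Placeholder names match the
# diff; the patient data is invented.
SYNTHESIS_PROMPT = (
    "═══ PATIENT PROFILE ═══\n{patient_profile}\n\n"
    "═══ INITIAL CLINICAL REASONING (from reasoning agent) ═══\n{clinical_reasoning}\n\n"
    "═══ DRUG INTERACTION CHECK (independent tool) ═══\n{drug_interactions}\n\n"
    "═══ CLINICAL GUIDELINES (RAG retrieval — independent evidence) ═══\n{guidelines}\n\n"
    "═══ CONFLICTS & GAPS DETECTED (independent analysis) ═══\n{conflicts}"
)

rendered = SYNTHESIS_PROMPT.format(
    patient_profile="72F, CKD stage 3, on lisinopril and spironolactone",
    clinical_reasoning="1. Hyperkalemia  2. Acute kidney injury",
    drug_interactions="lisinopril + spironolactone: hyperkalemia risk (major)",
    guidelines="KDIGO: review RAAS blockade when K+ > 5.5 mmol/L",
    conflicts="Guideline advises K+ monitoring; no recent K+ on file",
)
print(rendered.splitlines()[1])  # 72F, CKD stage 3, on lisinopril and spironolactone
```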
src/backend/check_progress.py ADDED
@@ -0,0 +1,47 @@
+ """Quick progress checker for validation run."""
+ import json
+ from pathlib import Path
+
+ checkpoint = Path("validation/results/medqa_checkpoint.jsonl")
+ if not checkpoint.exists():
+     print("No checkpoint file found")
+     exit()
+
+ lines = checkpoint.read_text(encoding="utf-8").strip().split("\n")
+ print(f"Completed: {len(lines)}/50")
+
+ matches = 0
+ diff_matches = 0
+ top3_matches = 0
+ failures = 0
+
+ for line in lines:
+     d = json.loads(line)
+     det = d.get("details", {})
+     scores = d.get("scores", {})
+     loc = det.get("match_location", "not_found")
+
+     if not d.get("success"):
+         failures += 1
+     if loc != "not_found":
+         matches += 1
+     if loc == "differential":
+         diff_matches += 1
+     if scores.get("top3_accuracy", 0) > 0:
+         top3_matches += 1
+
+ print(f"Pipeline success: {len(lines) - failures}/{len(lines)}")
+ print(f"Mentioned matches: {matches}/{len(lines)} ({100*matches/len(lines):.0f}%)")
+ print(f"Differential matches: {diff_matches}/{len(lines)} ({100*diff_matches/len(lines):.0f}%)")
+ print(f"Top-3 matches: {top3_matches}/{len(lines)} ({100*top3_matches/len(lines):.0f}%)")
+
+ # Show last 5 cases
+ print("\nRecent cases:")
+ for line in lines[-5:]:
+     d = json.loads(line)
+     det = d.get("details", {})
+     correct = det.get("correct_answer", "?")[:45]
+     top = det.get("top_diagnosis", "?")[:45]
+     loc = det.get("match_location", "not_found")
+     t = d.get("pipeline_time_ms", 0)
+     print(f" {d['case_id']}: [{loc}] {t/1000:.0f}s | correct={correct} | top={top}")
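Each line of the checkpoint file is one JSON object (JSON Lines). A record carrying the fields check_progress.py reads, with invented values, can be sketched as:

```python
import json

# Illustrative checkpoint record matching the keys check_progress.py reads
# (case_id, success, pipeline_time_ms, scores, details.match_location).
# All field values here are made up.
record = {
    "case_id": "medqa_007",
    "success": True,
    "pipeline_time_ms": 42000,
    "scores": {"top3_accuracy": 1.0},
    "details": {
        "match_location": "differential",
        "correct_answer": "Acute pancreatitis",
        "top_diagnosis": "Acute pancreatitis",
    },
}
line = json.dumps(record)  # one such line per case in medqa_checkpoint.jsonl
print(json.loads(line)["details"]["match_location"])  # differential
```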
src/backend/validation/base.py CHANGED
@@ -28,7 +28,7 @@ if str(BACKEND_DIR) not in sys.path:
      sys.path.insert(0, str(BACKEND_DIR))

  from app.agent.orchestrator import Orchestrator
- from app.models.schemas import CaseSubmission, CDSReport, AgentState


  # ──────────────────────────────────────────────
@@ -103,7 +103,20 @@ async def run_cds_pipeline(
          async for _step_update in orchestrator.run(case):
              pass  # consume all step updates

-         return orchestrator.state, orchestrator.get_result(), None
      except asyncio.TimeoutError:
          return orchestrator.state, None, f"Pipeline timed out after {timeout_sec}s"
      except Exception as e:
@@ -159,12 +172,14 @@ def diagnosis_in_differential(
      target_diagnosis: str,
      report: CDSReport,
      top_n: Optional[int] = None,
- ) -> tuple[bool, int]:
      """
      Check if target_diagnosis appears in the report's differential.

      Returns:
-         (found, rank) — rank is 0-indexed position, or -1 if not found
      """
      diagnoses = report.differential_diagnosis
      if top_n:
@@ -172,18 +187,29 @@

      for i, dx in enumerate(diagnoses):
          if fuzzy_match(dx.diagnosis, target_diagnosis):
-             return True, i

-     # Also check the full report text (patient_summary, guideline_recommendations, etc.)
      full_text = " ".join([
          report.patient_summary or "",
          " ".join(report.guideline_recommendations),
          " ".join(a.action for a in report.suggested_next_steps),
      ])
      if fuzzy_match(full_text, target_diagnosis, threshold=0.3):
-         return True, len(diagnoses)  # found but not in differential

-     return False, -1


  # ──────────────────────────────────────────────

      sys.path.insert(0, str(BACKEND_DIR))

  from app.agent.orchestrator import Orchestrator
+ from app.models.schemas import CaseSubmission, CDSReport, AgentState, AgentStepStatus


  # ──────────────────────────────────────────────

          async for _step_update in orchestrator.run(case):
              pass  # consume all step updates

+         report = orchestrator.get_result()
+
+         # If no report was produced, collect errors from failed steps
+         if report is None and orchestrator.state:
+             failed_steps = [
+                 s for s in orchestrator.state.steps
+                 if s.status == AgentStepStatus.FAILED
+             ]
+             if failed_steps:
+                 error_msgs = [f"{s.step_id}: {s.error}" for s in failed_steps]
+                 return orchestrator.state, None, "; ".join(error_msgs)
+             return orchestrator.state, None, "Pipeline completed but produced no report"
+
+         return orchestrator.state, report, None
      except asyncio.TimeoutError:
          return orchestrator.state, None, f"Pipeline timed out after {timeout_sec}s"
      except Exception as e:

      target_diagnosis: str,
      report: CDSReport,
      top_n: Optional[int] = None,
+ ) -> tuple[bool, int, str]:
      """
      Check if target_diagnosis appears in the report's differential.

      Returns:
+         (found, rank, match_location) — rank is 0-indexed position, or -1 if not found.
+         match_location is one of: "differential", "next_steps", "recommendations",
+         "fulltext", or "not_found".
      """
      diagnoses = report.differential_diagnosis
      if top_n:

      for i, dx in enumerate(diagnoses):
          if fuzzy_match(dx.diagnosis, target_diagnosis):
+             return True, i, "differential"
+
+     # Check suggested_next_steps (for management-type answers)
+     for i, action in enumerate(report.suggested_next_steps):
+         if fuzzy_match(action.action, target_diagnosis):
+             return True, len(diagnoses) + i, "next_steps"
+
+     # Check guideline recommendations (for treatment-type answers)
+     for i, rec in enumerate(report.guideline_recommendations):
+         if fuzzy_match(rec, target_diagnosis):
+             return True, len(diagnoses) + len(report.suggested_next_steps) + i, "recommendations"

+     # Broad fulltext check (patient_summary, recommendations, next steps combined)
      full_text = " ".join([
          report.patient_summary or "",
          " ".join(report.guideline_recommendations),
          " ".join(a.action for a in report.suggested_next_steps),
+         " ".join(dx.reasoning for dx in report.differential_diagnosis),
      ])
      if fuzzy_match(full_text, target_diagnosis, threshold=0.3):
+         return True, len(diagnoses), "fulltext"

+     return False, -1, "not_found"


  # ──────────────────────────────────────────────
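fuzzy_match itself is defined elsewhere in validation/base.py and is not part of this diff; the calls above only rely on the signature fuzzy_match(text, target, threshold=...). A plausible token-overlap sketch, not the project's actual implementation:

```python
# Hypothetical fuzzy_match: the real one lives elsewhere in
# validation/base.py. This sketch matches when enough of the target's
# tokens appear in the candidate text.
def fuzzy_match(text: str, target: str, threshold: float = 0.6) -> bool:
    """True if at least `threshold` of the target's tokens occur in text."""
    text_tokens = set(text.lower().split())
    target_tokens = set(target.lower().split())
    if not target_tokens:
        return False
    overlap = len(target_tokens & text_tokens) / len(target_tokens)
    return overlap >= threshold

print(fuzzy_match("acute myocardial infarction is most likely", "Myocardial infarction"))  # True
print(fuzzy_match("community-acquired pneumonia", "pulmonary embolism"))  # False
```

A lower threshold (such as the 0.3 used for the full-text pass above) trades precision for recall, which suits a broad "mentioned anywhere" check.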
src/backend/validation/harness_medqa.py CHANGED
@@ -243,39 +243,56 @@ async def validate_medqa(
          correct_answer = case.ground_truth["correct_answer"]

          if report:
-             # Top-1 accuracy
-             found_top1, rank = diagnosis_in_differential(correct_answer, report, top_n=1)
              scores["top1_accuracy"] = 1.0 if found_top1 else 0.0

-             # Top-3 accuracy
-             found_top3, rank3 = diagnosis_in_differential(correct_answer, report, top_n=3)
              scores["top3_accuracy"] = 1.0 if found_top3 else 0.0

-             # Mentioned anywhere
-             found_any, rank_any = diagnosis_in_differential(correct_answer, report)
              scores["mentioned_accuracy"] = 1.0 if found_any else 0.0

              # Parse success
              scores["parse_success"] = 1.0

              details = {
                  "correct_answer": correct_answer,
-                 "top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE",
                  "num_diagnoses": len(report.differential_diagnosis),
                  "found_at_rank": rank_any if found_any else -1,
              }

-             status_icon = "✓" if found_top3 else "✗"
-             print(f"{status_icon} top1={'Y' if found_top1 else 'N'} top3={'Y' if found_top3 else 'N'} ({elapsed_ms}ms)")
          else:
              scores = {
                  "top1_accuracy": 0.0,
                  "top3_accuracy": 0.0,
                  "mentioned_accuracy": 0.0,
                  "parse_success": 0.0,
              }
-             details = {"correct_answer": correct_answer, "error": error}
-             print(f" FAILED: {error[:80] if error else 'unknown'}")

          result = ValidationResult(
              case_id=case.case_id,
@@ -300,7 +317,7 @@ async def validate_medqa(
      successful = sum(1 for r in results if r.success)

      # Average each metric across successful cases only
-     metric_names = ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "parse_success"]
      metrics = {}
      for m in metric_names:
          values = [r.scores.get(m, 0.0) for r in results]

          correct_answer = case.ground_truth["correct_answer"]

          if report:
+             # Top-1 accuracy (differential only)
+             found_top1, rank1, loc1 = diagnosis_in_differential(correct_answer, report, top_n=1)
              scores["top1_accuracy"] = 1.0 if found_top1 else 0.0

+             # Top-3 accuracy (differential only)
+             found_top3, rank3, loc3 = diagnosis_in_differential(correct_answer, report, top_n=3)
              scores["top3_accuracy"] = 1.0 if found_top3 else 0.0

+             # Mentioned anywhere (differential + next_steps + recommendations + fulltext)
+             found_any, rank_any, loc_any = diagnosis_in_differential(correct_answer, report)
              scores["mentioned_accuracy"] = 1.0 if found_any else 0.0

+             # Differential-only accuracy (strict: only counts differential matches)
+             found_diff_only, rank_diff, loc_diff = diagnosis_in_differential(correct_answer, report)
+             scores["differential_accuracy"] = 1.0 if (found_diff_only and loc_diff == "differential") else 0.0
+
              # Parse success
              scores["parse_success"] = 1.0

+             # Rich details for debugging
+             all_dx = [dx.diagnosis for dx in report.differential_diagnosis]
+             all_next = [a.action for a in report.suggested_next_steps]
+             all_recs = list(report.guideline_recommendations)
+
              details = {
                  "correct_answer": correct_answer,
+                 "top_diagnosis": all_dx[0] if all_dx else "NONE",
+                 "all_diagnoses": all_dx,
+                 "all_next_steps": all_next[:5],
+                 "all_recommendations": all_recs[:5],
                  "num_diagnoses": len(report.differential_diagnosis),
                  "found_at_rank": rank_any if found_any else -1,
+                 "match_location": loc_any,
+                 "patient_summary": report.patient_summary[:300] if report.patient_summary else "",
              }

+             # Richer console output
+             loc_tag = f"[{loc_any}]" if found_any else ""
+             status_icon = "+" if found_any else "-"
+             print(f"{status_icon} top1={'Y' if found_top1 else 'N'} top3={'Y' if found_top3 else 'N'} diff={'Y' if loc_any=='differential' else 'N'} {loc_tag} ({elapsed_ms}ms)")
          else:
              scores = {
                  "top1_accuracy": 0.0,
                  "top3_accuracy": 0.0,
                  "mentioned_accuracy": 0.0,
+                 "differential_accuracy": 0.0,
                  "parse_success": 0.0,
              }
+             details = {"correct_answer": correct_answer, "error": error, "match_location": "not_found"}
+             print(f"- FAILED: {error[:80] if error else 'unknown'}")

          result = ValidationResult(
              case_id=case.case_id,

      successful = sum(1 for r in results if r.success)

      # Average each metric across successful cases only
+     metric_names = ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "differential_accuracy", "parse_success"]
      metrics = {}
      for m in metric_names:
          values = [r.scores.get(m, 0.0) for r in results]
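The aggregation loop above is only partially shown (the averaging line falls outside the hunk). The intended pattern, sketched with plain dicts standing in for ValidationResult.scores; the mean computation is an assumption:

```python
# Stand-ins for ValidationResult.scores; values are invented.
results = [
    {"top1_accuracy": 0.0, "top3_accuracy": 1.0, "parse_success": 1.0},
    {"top1_accuracy": 1.0, "top3_accuracy": 1.0, "parse_success": 1.0},
]
metric_names = ["top1_accuracy", "top3_accuracy", "mentioned_accuracy",
                "differential_accuracy", "parse_success"]

metrics = {}
for m in metric_names:
    values = [r.get(m, 0.0) for r in results]  # missing scores count as 0.0
    metrics[m] = sum(values) / len(values) if values else 0.0

print(metrics["top1_accuracy"])      # 0.5
print(metrics["mentioned_accuracy"]) # 0.0
```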
src/backend/validation/harness_pmc.py CHANGED
@@ -373,15 +373,15 @@ async def validate_pmc(

          if report:
              # Diagnostic accuracy (anywhere in differential)
-             found_any, rank_any = diagnosis_in_differential(target_diagnosis, report)
              scores["diagnostic_accuracy"] = 1.0 if found_any else 0.0

              # Top-3 accuracy
-             found_top3, rank3 = diagnosis_in_differential(target_diagnosis, report, top_n=3)
              scores["top3_accuracy"] = 1.0 if found_top3 else 0.0

              # Top-1 accuracy
-             found_top1, rank1 = diagnosis_in_differential(target_diagnosis, report, top_n=1)
              scores["top1_accuracy"] = 1.0 if found_top1 else 0.0

              # Parse success

          if report:
              # Diagnostic accuracy (anywhere in differential)
+             found_any, rank_any, loc_any = diagnosis_in_differential(target_diagnosis, report)
              scores["diagnostic_accuracy"] = 1.0 if found_any else 0.0

              # Top-3 accuracy
+             found_top3, rank3, loc3 = diagnosis_in_differential(target_diagnosis, report, top_n=3)
              scores["top3_accuracy"] = 1.0 if found_top3 else 0.0

              # Top-1 accuracy
+             found_top1, rank1, loc1 = diagnosis_in_differential(target_diagnosis, report, top_n=1)
              scores["top1_accuracy"] = 1.0 if found_top1 else 0.0

              # Parse success
src/backend/validation/run_validation.py CHANGED
@@ -18,6 +18,7 @@ from __future__ import annotations

  import asyncio
  import json
  import sys
  import time
  from datetime import datetime, timezone
@@ -28,6 +29,12 @@ BACKEND_DIR = Path(__file__).resolve().parent.parent
  if str(BACKEND_DIR) not in sys.path:
      sys.path.insert(0, str(BACKEND_DIR))

  from validation.base import (
      ValidationSummary,
      print_summary,
@@ -163,7 +170,7 @@ def _print_combined_summary(results: dict, total_duration: float):
      )

      # All metrics
-     print(f"\n {'─' * 66}")
      for name, summary in results.items():
          print(f"\n {name.upper()} metrics:")
          for metric, value in sorted(summary.metrics.items()):
@@ -252,9 +259,9 @@ Examples:
      run_mtsamples = args.all or args.mtsamples
      run_pmc = args.all or args.pmc

-     print("╔════════════════════════════════════════════════════════╗")
-     print(" Clinical Decision Support Agent Validation Suite")
-     print("╚════════════════════════════════════════════════════════╝")
      print(f"\n Datasets: {'MedQA ' if run_medqa else ''}{'MTSamples ' if run_mtsamples else ''}{'PMC ' if run_pmc else ''}")
      print(f" Cases/dataset: {args.max_cases}")
      print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")

  import asyncio
  import json
+ import os
  import sys
  import time
  from datetime import datetime, timezone

  if str(BACKEND_DIR) not in sys.path:
      sys.path.insert(0, str(BACKEND_DIR))

+ # Load .env and export HF_TOKEN so huggingface_hub picks it up
+ from dotenv import load_dotenv
+ load_dotenv(BACKEND_DIR / ".env")
+ if os.getenv("HF_TOKEN"):
+     os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
+
  from validation.base import (
      ValidationSummary,
      print_summary,

      )

      # All metrics
+     print(f"\n {'-' * 66}")
      for name, summary in results.items():
          print(f"\n {name.upper()} metrics:")
          for metric, value in sorted(summary.metrics.items()):

      run_mtsamples = args.all or args.mtsamples
      run_pmc = args.all or args.pmc

+     print("=" * 58)
+     print(" Clinical Decision Support Agent - Validation Suite")
+     print("=" * 58)
      print(f"\n Datasets: {'MedQA ' if run_medqa else ''}{'MTSamples ' if run_mtsamples else ''}{'PMC ' if run_pmc else ''}")
      print(f" Cases/dataset: {args.max_cases}")
      print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
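load_dotenv comes from the third-party python-dotenv package and already injects parsed variables into os.environ, so the explicit HF_TOKEN re-export above is defensive. A stdlib-only emulation of the idea, using an invented DEMO_TOKEN key and a throwaway file:

```python
import os
from pathlib import Path

# Emulates roughly what load_dotenv does, for illustration only: parse
# KEY=VALUE lines and inject them into os.environ. DEMO_TOKEN and the
# file name are invented; the real code uses python-dotenv and HF_TOKEN.
env_file = Path("demo.env")
env_file.write_text("DEMO_TOKEN=hf_example_token\n")

for raw in env_file.read_text().splitlines():
    raw = raw.strip()
    if raw and not raw.startswith("#") and "=" in raw:
        key, _, value = raw.partition("=")
        os.environ.setdefault(key.strip(), value.strip())

env_file.unlink()
print(os.environ["DEMO_TOKEN"])  # hf_example_token
```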