Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .env.template +3 -0
- .gitattributes +5 -0
- README.md +19 -6
- agent.py +1170 -0
- api_client.py +707 -0
- app.py +1000 -0
- baselines.py +694 -0
- calibration.py +519 -0
- config.py +276 -0
- datasets/__init__.py +17 -0
- datasets/base.py +146 -0
- datasets/midas.py +444 -0
- datasets/nejm.py +440 -0
- datasets/olives.py +470 -0
- demo_cases/chest_xray_ipf.png +3 -0
- demo_cases/ct_pulmonary_pe.png +3 -0
- demo_cases/fundus_dme.png +3 -0
- demo_cases/oct_bscan_dme.png +3 -0
- demo_cases/skin_lesion_dermoscopy.png +3 -0
- evaluation/__init__.py +455 -0
- evaluation/analysis.py +546 -0
- information_gain.py +441 -0
- policy.py +608 -0
- prompts.py +228 -0
- reasoning_analysis.py +612 -0
- requirements.txt +8 -0
- tools.py +185 -0
- trajectory.py +338 -0
.env.template
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Set your API keys as HF Space secrets
|
| 2 |
+
# OPENAI_API_KEY=sk-...
|
| 3 |
+
# ANTHROPIC_API_KEY=sk-ant-...
|
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
demo_cases/chest_xray_ipf.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
demo_cases/ct_pulmonary_pe.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
demo_cases/fundus_dme.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
demo_cases/oct_bscan_dme.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
demo_cases/skin_lesion_dermoscopy.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,25 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ActiveMedAgent Demo
|
| 3 |
+
emoji: 🏥
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.10.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis
|
| 14 |
+
|
| 15 |
+
Interactive demo for the ActiveMedAgent framework. Watch the agent reason step-by-step,
|
| 16 |
+
acquire information channels strategically, and track entropy reduction in real time.
|
| 17 |
+
|
| 18 |
+
**No budget constraint** — the agent decides when to stop based on information-theoretic criteria.
|
| 19 |
+
|
| 20 |
+
## Features
|
| 21 |
+
- Pre-built demo cases (NEJM, MIDAS, OLIVES)
|
| 22 |
+
- Custom case builder with image upload
|
| 23 |
+
- Step-by-step reasoning trace with probability bars
|
| 24 |
+
- Entropy trajectory and information gain tracking
|
| 25 |
+
- Simulated mode (no API key needed) + real VLM backends
|
agent.py
ADDED
|
@@ -0,0 +1,1170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core ActiveMedAgent — Tool-Use Architecture with Adaptive Context Management.
|
| 3 |
+
|
| 4 |
+
Two modes of operation:
|
| 5 |
+
|
| 6 |
+
1. FULL MODE (for capable models: GPT-4o, Claude, Qwen-72B):
|
| 7 |
+
Multi-turn conversation with full history. The VLM sees its own prior
|
| 8 |
+
reasoning and can build on it.
|
| 9 |
+
|
| 10 |
+
2. CONDENSED MODE (for weaker models: GPT-4o-mini, etc.):
|
| 11 |
+
Each acquisition step is a fresh single-turn call. The VLM receives:
|
| 12 |
+
- Initial image(s)
|
| 13 |
+
- A compact structured acquisition log of all prior steps
|
| 14 |
+
- The latest channel data
|
| 15 |
+
- Available channels + tools
|
| 16 |
+
This keeps context size O(1) per step instead of O(n), preventing
|
| 17 |
+
weaker models from losing track of their own reasoning.
|
| 18 |
+
|
| 19 |
+
In both modes:
|
| 20 |
+
- There is NO fixed budget. The agent acquires as many channels as it
|
| 21 |
+
needs (0 to all available). If a case needs all 5 NEJM vignettes, it
|
| 22 |
+
gets all 5. If the image alone is sufficient, it commits immediately.
|
| 23 |
+
- The agent decides when to stop via the commit_diagnosis tool.
|
| 24 |
+
- Probability distributions from tool calls are tracked for information-
|
| 25 |
+
theoretic analysis (entropy, IG, KL divergence).
|
| 26 |
+
"""
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
|
| 31 |
+
from api_client import BaseVLMClient, VLMResponse
|
| 32 |
+
from datasets.base import MedicalCase, ChannelData
|
| 33 |
+
from tools import (
|
| 34 |
+
ToolCall, ToolResult, AGENT_TOOLS,
|
| 35 |
+
to_openai_tools, to_anthropic_tools,
|
| 36 |
+
constrain_tools_for_step,
|
| 37 |
+
)
|
| 38 |
+
from information_gain import (
|
| 39 |
+
BeliefState, BeliefTrajectory,
|
| 40 |
+
compute_entropy, compute_kl_divergence,
|
| 41 |
+
estimate_expected_information_gain,
|
| 42 |
+
should_commit, compute_value_of_information,
|
| 43 |
+
)
|
| 44 |
+
from prompts import format_available_channels, format_acquired_info
|
| 45 |
+
import config
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ============================================================
|
| 51 |
+
# Data Structures
|
| 52 |
+
# ============================================================
|
| 53 |
+
|
| 54 |
+
@dataclass
class AcquisitionStep:
    """Record of a single acquisition decision made via tool call."""
    # 0-based index of this step in the acquisition loop
    step: int
    # Raw tool call returned by the VLM; None when no tool call was produced
    tool_call: ToolCall | None
    requested_channel: str | None  # None if agent committed
    # Model-stated rationale for the request (or for the commit)
    reasoning: str
    differential: list[dict]  # [{name, confidence}, ...]
    # True when this step is a terminal commit_diagnosis step
    committed: bool
    # Full text of the model response for this step
    raw_response: str
    # Wall-clock latency of the API call for this step, in milliseconds
    latency_ms: float
    # Shannon entropy (bits) of the step's belief distribution
    entropy: float = 0.0
    # Entropy reduction vs. the previous step's distribution (bits)
    information_gain: float = 0.0
    # KL divergence between this step's and the previous step's distributions
    kl_divergence: float = 0.0
    # Model's predicted if_positive / if_negative outcomes for the request
    expected_impact: dict = field(default_factory=dict)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class AgentResult:
    """Complete result of an agent's diagnostic process on one case."""
    # Identifiers for the evaluated case and experiment configuration
    case_id: str
    dataset: str
    prompt_variant: str
    backend: str
    budget: int  # max channels available (not a hard limit)
    # Per-step acquisition records, including the final commit step
    steps: list[AcquisitionStep] = field(default_factory=list)
    # Final ranked differential as [{name, confidence, ...}, ...]
    final_ranking: list[dict] = field(default_factory=list)
    # Channel names acquired, in acquisition order
    acquired_channels: list[str] = field(default_factory=list)
    # Aggregate API usage across all steps
    total_latency_ms: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    # True when the agent called commit_diagnosis before exhausting channels
    committed_early: bool = False
    final_raw_response: str = ""
    # Belief distributions over time, for information-theoretic analysis
    belief_trajectory: BeliefTrajectory | None = None
    # Cost accounting — units defined by the channel configuration
    total_case_cost: float = 0.0
    acquisition_cost: float = 0.0
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ============================================================
|
| 93 |
+
# System Prompts
|
| 94 |
+
# ============================================================
|
| 95 |
+
|
| 96 |
+
# NOTE(review): these prompt strings were recovered from a whitespace-mangled
# source; leading indentation of the bullet/list lines inside the triple-quoted
# strings may have been lost — TODO confirm against the original file.

# Prompt for FULL context mode: multi-turn conversation with complete history.
SYSTEM_PROMPT_FULL = """\
You are a medical diagnostic agent. You may receive free clinical context, exam data, \
labs, and images in a tiered pathway. Your job is to determine the correct diagnosis \
from a set of candidates while avoiding unnecessary resource use.

You have two tools:

1. request_information — Request one additional data channel when its expected \
clinical value justifies its cost. You MUST provide:
- channel_name: exactly one of the available channel names
- reasoning: why this channel best resolves your current uncertainty
- current_differential: your FULL ranked differential with calibrated \
probabilities that sum to 1.0 across ALL candidates
- expected_impact: what you expect (if_positive and if_negative)

2. commit_diagnosis — Submit your final ranked diagnosis. Provide:
- ranked_diagnoses: ALL candidates with calibrated probabilities summing to 1.0, \
each with key_evidence
- reasoning: your complete diagnostic reasoning chain

Strategy:
- Start with whatever information is already available for free at presentation.
- Escalate only when the currently available information cannot safely distinguish the top diagnoses.
- If demographics, chief complaint, history, exam, or existing evidence are sufficient, commit without requesting imaging.
- Commit when your top diagnosis has high probability and is well-separated from \
alternatives, OR when no remaining channel would meaningfully change your differential.
- Your probability estimates MUST sum to 1.0 and reflect genuine calibrated uncertainty."""

# Prompt for CONDENSED context mode: fresh single-turn call per step with a
# compressed acquisition log instead of the full conversation history.
SYSTEM_PROMPT_CONDENSED = """\
You are a medical diagnostic agent. Examine all currently available clinical \
information below, then decide your next action.

You have two tools:

1. request_information — Request one more data channel only if its expected value \
justifies its cost. Provide:
- channel_name: one of the available channels listed below
- reasoning: why this channel would help
- current_differential: ranked diagnoses with probabilities summing to 1.0
- expected_impact: if_positive and if_negative predictions

2. commit_diagnosis — Submit final diagnosis. Provide:
- ranked_diagnoses: ALL candidates with probabilities summing to 1.0 and key_evidence
- reasoning: complete reasoning chain

Decide: if remaining channels would meaningfully change your differential enough to \
justify their cost, request the best one. Otherwise, commit your diagnosis."""

# Prompt used to force a final commit_diagnosis once acquisition has ended.
SYSTEM_PROMPT_FINAL = """\
You are a medical diagnostic agent. You have gathered information. Now provide \
your final ranked diagnosis using the commit_diagnosis tool.

You MUST:
- Include ALL candidate diagnoses in ranked_diagnoses
- Assign calibrated probabilities summing to 1.0
- Provide specific key_evidence for EACH diagnosis
- Write a thorough reasoning chain synthesizing all acquired evidence
- Favor the least resource-intensive pathway that still supports the diagnosis"""
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ============================================================
|
| 157 |
+
# Context Mode Detection
|
| 158 |
+
# ============================================================
|
| 159 |
+
|
| 160 |
+
def _should_use_condensed(model_name: str) -> bool:
    """Return True when *model_name* should run in condensed context mode.

    An explicit global setting ("full"/"condensed") overrides everything;
    otherwise ("adaptive") the model name is matched against the configured
    substring patterns for weaker models.
    """
    mode = config.CONTEXT_MODE
    if mode == "full":
        return False
    if mode == "condensed":
        return True
    # adaptive — condensed if any configured pattern occurs in the model name
    return any(pattern in model_name for pattern in config.CONDENSED_MODELS)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ============================================================
|
| 174 |
+
# Condensed Acquisition Log Builder
|
| 175 |
+
# ============================================================
|
| 176 |
+
|
| 177 |
+
def _build_acquisition_log(
|
| 178 |
+
steps: list[AcquisitionStep],
|
| 179 |
+
acquired_data: dict[str, str],
|
| 180 |
+
) -> str:
|
| 181 |
+
"""
|
| 182 |
+
Build a compact structured summary of all prior acquisition steps.
|
| 183 |
+
|
| 184 |
+
This replaces the full multi-turn conversation for condensed mode.
|
| 185 |
+
Each step is ~50-80 tokens instead of 300-500 for the full tool call
|
| 186 |
+
response, keeping context manageable for weaker models.
|
| 187 |
+
|
| 188 |
+
Example output:
|
| 189 |
+
=== ACQUISITION LOG (2 channels acquired) ===
|
| 190 |
+
Step 1: Acquired [dermoscopy]
|
| 191 |
+
Reasoning: Need subsurface structures to distinguish melanoma from BCC
|
| 192 |
+
Data received: [dermoscopy]: (image — see above)
|
| 193 |
+
Updated differential: Melanoma(0.55), BCC(0.30), SCC(0.15)
|
| 194 |
+
Entropy: 1.37 bits | Information gain: 0.19 bits
|
| 195 |
+
|
| 196 |
+
Step 2: Acquired [patient_demographics]
|
| 197 |
+
Reasoning: Age and skin type are critical for melanoma risk
|
| 198 |
+
Data received: [demographics]: 34M, Fitzpatrick II
|
| 199 |
+
Updated differential: Melanoma(0.75), BCC(0.15), SCC(0.10)
|
| 200 |
+
Entropy: 1.06 bits | Information gain: 0.31 bits
|
| 201 |
+
=== END LOG ===
|
| 202 |
+
"""
|
| 203 |
+
if not steps:
|
| 204 |
+
return "(No information acquired yet.)"
|
| 205 |
+
|
| 206 |
+
lines = [f"=== ACQUISITION LOG ({len(steps)} channel(s) acquired) ==="]
|
| 207 |
+
for step in steps:
|
| 208 |
+
if step.committed:
|
| 209 |
+
continue
|
| 210 |
+
ch = step.requested_channel or "unknown"
|
| 211 |
+
lines.append(f"Step {step.step + 1}: Acquired [{ch}]")
|
| 212 |
+
if step.reasoning:
|
| 213 |
+
# Truncate reasoning to key point
|
| 214 |
+
reasoning = step.reasoning
|
| 215 |
+
if len(reasoning) > 200:
|
| 216 |
+
reasoning = reasoning[:197] + "..."
|
| 217 |
+
lines.append(f" Reasoning: {reasoning}")
|
| 218 |
+
|
| 219 |
+
# Include the actual data received
|
| 220 |
+
data = acquired_data.get(ch, "")
|
| 221 |
+
if data:
|
| 222 |
+
if len(data) > 300:
|
| 223 |
+
data = data[:297] + "..."
|
| 224 |
+
lines.append(f" Data received: {data}")
|
| 225 |
+
|
| 226 |
+
# Compact differential
|
| 227 |
+
if step.differential:
|
| 228 |
+
diff_str = ", ".join(
|
| 229 |
+
f"{d['name']}({d['confidence']:.2f})"
|
| 230 |
+
for d in step.differential[:5]
|
| 231 |
+
)
|
| 232 |
+
lines.append(f" Updated differential: {diff_str}")
|
| 233 |
+
|
| 234 |
+
lines.append(
|
| 235 |
+
f" Entropy: {step.entropy:.2f} bits"
|
| 236 |
+
+ (f" | IG: {step.information_gain:.2f} bits" if step.information_gain else "")
|
| 237 |
+
)
|
| 238 |
+
lines.append("")
|
| 239 |
+
|
| 240 |
+
lines.append("=== END LOG ===")
|
| 241 |
+
return "\n".join(lines)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ============================================================
|
| 245 |
+
# Main Agent Class
|
| 246 |
+
# ============================================================
|
| 247 |
+
|
| 248 |
+
class ActiveMedAgent:
|
| 249 |
+
"""
|
| 250 |
+
Tool-use active acquisition agent with adaptive context management.
|
| 251 |
+
|
| 252 |
+
No fixed budget — the agent acquires as many or as few channels as it
|
| 253 |
+
needs. Supports two context modes:
|
| 254 |
+
- full: multi-turn conversation (capable models)
|
| 255 |
+
- condensed: single-turn with compressed state (weaker models)
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
def __init__(
|
| 259 |
+
self,
|
| 260 |
+
client: BaseVLMClient,
|
| 261 |
+
prompt_variant: str = "A",
|
| 262 |
+
budget: int = None,
|
| 263 |
+
context_mode: str = None,
|
| 264 |
+
):
|
| 265 |
+
"""
|
| 266 |
+
Args:
|
| 267 |
+
client: VLM API client.
|
| 268 |
+
prompt_variant: Prompt variant ID (for tracking, not used in tool mode).
|
| 269 |
+
budget: Max acquisitions. None = unlimited (use all channels if needed).
|
| 270 |
+
context_mode: "full", "condensed", or None (auto-detect from model).
|
| 271 |
+
"""
|
| 272 |
+
self.client = client
|
| 273 |
+
self.prompt_variant = prompt_variant
|
| 274 |
+
self.budget = budget # None means unlimited
|
| 275 |
+
self._commit_hint = ""
|
| 276 |
+
if context_mode is not None:
|
| 277 |
+
self.condensed = (context_mode == "condensed")
|
| 278 |
+
else:
|
| 279 |
+
self.condensed = _should_use_condensed(client.model)
|
| 280 |
+
|
| 281 |
+
if self.condensed:
|
| 282 |
+
logger.info(
|
| 283 |
+
f"Using CONDENSED context mode for {client.model} "
|
| 284 |
+
f"(single-turn with compressed state)"
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# ============================================================
|
| 288 |
+
# Main Acquisition Loop
|
| 289 |
+
# ============================================================
|
| 290 |
+
|
| 291 |
+
def diagnose(self, case: MedicalCase) -> AgentResult:
|
| 292 |
+
"""
|
| 293 |
+
Run the full tool-use acquisition loop.
|
| 294 |
+
|
| 295 |
+
The agent has NO fixed budget. It keeps requesting channels until:
|
| 296 |
+
1. It calls commit_diagnosis (confident enough), or
|
| 297 |
+
2. All available channels are exhausted, or
|
| 298 |
+
3. The safety limit is hit (max_steps = number of requestable channels)
|
| 299 |
+
|
| 300 |
+
Context mode determines how conversation history is managed:
|
| 301 |
+
- full: growing multi-turn conversation
|
| 302 |
+
- condensed: fresh single-turn call each step with compressed state
|
| 303 |
+
"""
|
| 304 |
+
max_steps = len(case.requestable_names)
|
| 305 |
+
if self.budget is not None:
|
| 306 |
+
max_steps = min(max_steps, self.budget)
|
| 307 |
+
|
| 308 |
+
result = AgentResult(
|
| 309 |
+
case_id=case.case_id,
|
| 310 |
+
dataset=case.dataset,
|
| 311 |
+
prompt_variant=self.prompt_variant,
|
| 312 |
+
backend=self.client.model,
|
| 313 |
+
budget=max_steps,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
acquired = []
|
| 317 |
+
acquired_data = {} # channel_name -> data string (for condensed log)
|
| 318 |
+
dataset_channel_config = config.CHANNEL_CONFIGS.get(case.dataset, {})
|
| 319 |
+
channel_config = {
|
| 320 |
+
name: info for name, info in dataset_channel_config.items()
|
| 321 |
+
if name in case.initial_channels or name in case.requestable_channels
|
| 322 |
+
}
|
| 323 |
+
conversation = [] # only used in full mode
|
| 324 |
+
trajectory = BeliefTrajectory(case_id=case.case_id)
|
| 325 |
+
prev_distribution = None
|
| 326 |
+
initial_images = case.get_initial_images()
|
| 327 |
+
initial_context_str = format_acquired_info(case.get_text_context([]))
|
| 328 |
+
|
| 329 |
+
# ---- Build initial message (shared by both modes) ----
|
| 330 |
+
available_str = format_available_channels(channel_config, acquired)
|
| 331 |
+
candidates_str = "\n".join(
|
| 332 |
+
f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if not self.condensed:
|
| 336 |
+
# FULL MODE: build initial user message for multi-turn
|
| 337 |
+
initial_content = self._build_image_content(initial_images)
|
| 338 |
+
initial_text = (
|
| 339 |
+
f"Review the currently available clinical information and determine the diagnosis.\n\n"
|
| 340 |
+
f"Information already available at presentation:\n{initial_context_str}\n\n"
|
| 341 |
+
f"Candidate diagnoses (you must rank ALL of these):\n"
|
| 342 |
+
f"{candidates_str}\n\n"
|
| 343 |
+
f"Prefer the least costly pathway that still supports a safe diagnosis.\n"
|
| 344 |
+
f"If the current information is sufficient, commit immediately without requesting more.\n"
|
| 345 |
+
f"You can request as many additional channels as you need "
|
| 346 |
+
f"(or none if already confident).\n"
|
| 347 |
+
f"Available information channels:\n{available_str}\n\n"
|
| 348 |
+
f"Use request_information to acquire the most informative "
|
| 349 |
+
f"channel for the cost, or commit_diagnosis if already confident."
|
| 350 |
+
)
|
| 351 |
+
initial_content.append({"type": "text", "text": initial_text})
|
| 352 |
+
conversation.append({"role": "user", "content": initial_content})
|
| 353 |
+
|
| 354 |
+
# ---- Acquisition Loop ----
|
| 355 |
+
for step_idx in range(max_steps):
|
| 356 |
+
available = [
|
| 357 |
+
n for n in case.requestable_names if n not in acquired
|
| 358 |
+
]
|
| 359 |
+
if not available:
|
| 360 |
+
logger.debug(
|
| 361 |
+
f"[{case.case_id}] All channels exhausted at step {step_idx}"
|
| 362 |
+
)
|
| 363 |
+
break
|
| 364 |
+
|
| 365 |
+
# Force acquisition on first step; allow commit after that
|
| 366 |
+
step_tools = constrain_tools_for_step(
|
| 367 |
+
budget_remaining=max_steps - step_idx,
|
| 368 |
+
allow_commit=(step_idx > 0),
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
if self.condensed:
|
| 372 |
+
# CONDENSED MODE: build a fresh single-turn call each step
|
| 373 |
+
response = self._call_condensed(
|
| 374 |
+
case=case,
|
| 375 |
+
initial_images=initial_images,
|
| 376 |
+
acquired=acquired,
|
| 377 |
+
acquired_data=acquired_data,
|
| 378 |
+
steps=result.steps,
|
| 379 |
+
available=available,
|
| 380 |
+
candidates_str=candidates_str,
|
| 381 |
+
channel_config=channel_config,
|
| 382 |
+
step_tools=step_tools,
|
| 383 |
+
)
|
| 384 |
+
else:
|
| 385 |
+
# FULL MODE: multi-turn call with complete history
|
| 386 |
+
response = self.client.call_with_retry(
|
| 387 |
+
system_prompt=SYSTEM_PROMPT_FULL,
|
| 388 |
+
messages=conversation,
|
| 389 |
+
temperature=config.TEMPERATURE,
|
| 390 |
+
max_tokens=config.MAX_TOKENS,
|
| 391 |
+
tools=step_tools,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
result.total_latency_ms += response.latency_ms
|
| 395 |
+
result.total_input_tokens += response.input_tokens
|
| 396 |
+
result.total_output_tokens += response.output_tokens
|
| 397 |
+
|
| 398 |
+
# ---- Process tool call ----
|
| 399 |
+
tool_call = response.tool_call
|
| 400 |
+
|
| 401 |
+
if tool_call is None:
|
| 402 |
+
# No tool call — fallback
|
| 403 |
+
logger.warning(
|
| 404 |
+
f"[{case.case_id}] Step {step_idx}: no tool call returned"
|
| 405 |
+
)
|
| 406 |
+
if not available:
|
| 407 |
+
break
|
| 408 |
+
fallback = available[0]
|
| 409 |
+
acquired.append(fallback)
|
| 410 |
+
result.acquired_channels.append(fallback)
|
| 411 |
+
|
| 412 |
+
ch = case.get_channel(fallback)
|
| 413 |
+
if ch and ch.channel_type == "text":
|
| 414 |
+
acquired_data[fallback] = f"[{fallback}]: {ch.value}"
|
| 415 |
+
else:
|
| 416 |
+
acquired_data[fallback] = f"[{fallback}]: (image)"
|
| 417 |
+
|
| 418 |
+
step = AcquisitionStep(
|
| 419 |
+
step=step_idx, tool_call=None,
|
| 420 |
+
requested_channel=fallback,
|
| 421 |
+
reasoning="(fallback — no tool call produced)",
|
| 422 |
+
differential=[], committed=False,
|
| 423 |
+
raw_response=response.text,
|
| 424 |
+
latency_ms=response.latency_ms,
|
| 425 |
+
)
|
| 426 |
+
result.steps.append(step)
|
| 427 |
+
|
| 428 |
+
if not self.condensed:
|
| 429 |
+
self._deliver_channel_data_as_user_message(
|
| 430 |
+
case, fallback, conversation, available, acquired,
|
| 431 |
+
channel_config,
|
| 432 |
+
)
|
| 433 |
+
continue
|
| 434 |
+
|
| 435 |
+
# Add assistant message to conversation (full mode only)
|
| 436 |
+
if not self.condensed:
|
| 437 |
+
conversation.append({
|
| 438 |
+
"role": "assistant",
|
| 439 |
+
"content": response.text,
|
| 440 |
+
"tool_calls": [tool_call],
|
| 441 |
+
})
|
| 442 |
+
|
| 443 |
+
# ---- Handle commit_diagnosis ----
|
| 444 |
+
if tool_call.tool_name == "commit_diagnosis":
|
| 445 |
+
args = tool_call.arguments
|
| 446 |
+
ranking = self._extract_ranking_from_commit(args)
|
| 447 |
+
distribution = {d["name"]: d["confidence"] for d in ranking}
|
| 448 |
+
|
| 449 |
+
belief = BeliefState(
|
| 450 |
+
step=step_idx,
|
| 451 |
+
distribution=distribution,
|
| 452 |
+
channel_acquired=None,
|
| 453 |
+
)
|
| 454 |
+
trajectory.states.append(belief)
|
| 455 |
+
|
| 456 |
+
ig = 0.0
|
| 457 |
+
kl = 0.0
|
| 458 |
+
if prev_distribution is not None:
|
| 459 |
+
ig = compute_entropy(prev_distribution) - compute_entropy(distribution)
|
| 460 |
+
kl = compute_kl_divergence(distribution, prev_distribution)
|
| 461 |
+
|
| 462 |
+
step = AcquisitionStep(
|
| 463 |
+
step=step_idx, tool_call=tool_call,
|
| 464 |
+
requested_channel=None,
|
| 465 |
+
reasoning=args.get("reasoning", ""),
|
| 466 |
+
differential=ranking, committed=True,
|
| 467 |
+
raw_response=response.text,
|
| 468 |
+
latency_ms=response.latency_ms,
|
| 469 |
+
entropy=belief.entropy,
|
| 470 |
+
information_gain=ig, kl_divergence=kl,
|
| 471 |
+
)
|
| 472 |
+
result.steps.append(step)
|
| 473 |
+
result.committed_early = True
|
| 474 |
+
result.final_ranking = ranking
|
| 475 |
+
logger.debug(
|
| 476 |
+
f"[{case.case_id}] Committed at step {step_idx} "
|
| 477 |
+
f"after acquiring {len(acquired)} channels "
|
| 478 |
+
f"(entropy={belief.entropy:.3f} bits)"
|
| 479 |
+
)
|
| 480 |
+
break
|
| 481 |
+
|
| 482 |
+
# ---- Handle request_information ----
|
| 483 |
+
elif tool_call.tool_name == "request_information":
|
| 484 |
+
args = tool_call.arguments
|
| 485 |
+
requested = args.get("channel_name", "")
|
| 486 |
+
differential = args.get("current_differential", [])
|
| 487 |
+
expected_impact = args.get("expected_impact", {})
|
| 488 |
+
reasoning = args.get("reasoning", "")
|
| 489 |
+
|
| 490 |
+
matched = self._match_channel(requested, available)
|
| 491 |
+
if matched is None:
|
| 492 |
+
matched = available[0]
|
| 493 |
+
logger.warning(
|
| 494 |
+
f"[{case.case_id}] Step {step_idx}: '{requested}' "
|
| 495 |
+
f"not in {available}, falling back to '{matched}'"
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
# Build distribution from tool call
|
| 499 |
+
distribution = {}
|
| 500 |
+
for d in differential:
|
| 501 |
+
distribution[d.get("name", "")] = d.get("probability", 0.0)
|
| 502 |
+
|
| 503 |
+
# Information-theoretic metrics
|
| 504 |
+
ig = 0.0
|
| 505 |
+
kl = 0.0
|
| 506 |
+
if prev_distribution is not None:
|
| 507 |
+
ig = compute_entropy(prev_distribution) - compute_entropy(distribution)
|
| 508 |
+
kl = compute_kl_divergence(distribution, prev_distribution)
|
| 509 |
+
|
| 510 |
+
belief = BeliefState(
|
| 511 |
+
step=step_idx,
|
| 512 |
+
distribution=distribution,
|
| 513 |
+
channel_acquired=matched,
|
| 514 |
+
)
|
| 515 |
+
trajectory.states.append(belief)
|
| 516 |
+
|
| 517 |
+
eig = estimate_expected_information_gain(
|
| 518 |
+
distribution, matched, expected_impact, case.candidates,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
logger.debug(
|
| 522 |
+
f"[{case.case_id}] Step {step_idx}: requesting '{matched}' "
|
| 523 |
+
f"(H={belief.entropy:.3f}, IG={ig:.3f}, EIG={eig:.3f})"
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
step = AcquisitionStep(
|
| 527 |
+
step=step_idx, tool_call=tool_call,
|
| 528 |
+
requested_channel=matched, reasoning=reasoning,
|
| 529 |
+
differential=[
|
| 530 |
+
{"name": d.get("name", ""),
|
| 531 |
+
"confidence": d.get("probability", 0.0),
|
| 532 |
+
"rank": i + 1}
|
| 533 |
+
for i, d in enumerate(differential)
|
| 534 |
+
],
|
| 535 |
+
committed=False,
|
| 536 |
+
raw_response=response.text,
|
| 537 |
+
latency_ms=response.latency_ms,
|
| 538 |
+
entropy=belief.entropy,
|
| 539 |
+
information_gain=ig, kl_divergence=kl,
|
| 540 |
+
expected_impact=expected_impact,
|
| 541 |
+
)
|
| 542 |
+
result.steps.append(step)
|
| 543 |
+
prev_distribution = distribution
|
| 544 |
+
|
| 545 |
+
acquired.append(matched)
|
| 546 |
+
result.acquired_channels.append(matched)
|
| 547 |
+
|
| 548 |
+
# Store acquired data for condensed log
|
| 549 |
+
ch = case.get_channel(matched)
|
| 550 |
+
if ch and ch.channel_type == "text":
|
| 551 |
+
acquired_data[matched] = f"[{matched}]: {ch.value}"
|
| 552 |
+
elif ch and ch.channel_type == "image":
|
| 553 |
+
acquired_data[matched] = f"[{matched}]: (image provided)"
|
| 554 |
+
else:
|
| 555 |
+
acquired_data[matched] = f"[{matched}]: No data available."
|
| 556 |
+
|
| 557 |
+
# ---- Check stopping criterion ----
|
| 558 |
+
# After recording the new belief state, evaluate whether
|
| 559 |
+
# the agent should stop acquiring. This is a principled
|
| 560 |
+
# information-theoretic check, not just a prompt heuristic.
|
| 561 |
+
remaining_channels = [
|
| 562 |
+
n for n in case.requestable_names if n not in acquired
|
| 563 |
+
]
|
| 564 |
+
commit_recommended, commit_reason = should_commit(
|
| 565 |
+
trajectory=trajectory,
|
| 566 |
+
available_channels=remaining_channels,
|
| 567 |
+
min_steps=0, # agent decides — no forced minimum
|
| 568 |
+
)
|
| 569 |
+
voi = compute_value_of_information(
|
| 570 |
+
trajectory, len(remaining_channels),
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
if commit_recommended and remaining_channels:
|
| 574 |
+
logger.info(
|
| 575 |
+
f"[{case.case_id}] Stopping criterion triggered at "
|
| 576 |
+
f"step {step_idx}: {commit_reason} (VoI={voi:.3f})"
|
| 577 |
+
)
|
| 578 |
+
# Don't break yet — let the VLM make the decision on
|
| 579 |
+
# the next iteration. But inject a hint into the
|
| 580 |
+
# follow-up context.
|
| 581 |
+
self._commit_hint = (
|
| 582 |
+
f"\n\nNote: Based on your belief trajectory, additional "
|
| 583 |
+
f"acquisition has low expected value (VoI={voi:.2f}). "
|
| 584 |
+
f"The last channel provided only {ig:.3f} bits of "
|
| 585 |
+
f"information gain. Consider committing your diagnosis."
|
| 586 |
+
)
|
| 587 |
+
else:
|
| 588 |
+
self._commit_hint = ""
|
| 589 |
+
|
| 590 |
+
# Deliver tool result (full mode only — condensed rebuilds
|
| 591 |
+
# the full state each call)
|
| 592 |
+
if not self.condensed:
|
| 593 |
+
self._deliver_tool_result(
|
| 594 |
+
case=case, channel_name=matched,
|
| 595 |
+
tool_call=tool_call,
|
| 596 |
+
conversation=conversation,
|
| 597 |
+
acquired=acquired,
|
| 598 |
+
channel_config=channel_config,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# ---- Final Diagnosis ----
|
| 602 |
+
if not result.committed_early or not result.final_ranking:
|
| 603 |
+
if self.condensed:
|
| 604 |
+
final_ranking, final_response, final_belief = (
|
| 605 |
+
self._get_final_diagnosis_condensed(
|
| 606 |
+
case, acquired, acquired_data, result.steps,
|
| 607 |
+
)
|
| 608 |
+
)
|
| 609 |
+
else:
|
| 610 |
+
final_ranking, final_response, final_belief = (
|
| 611 |
+
self._get_final_diagnosis_tooluse(
|
| 612 |
+
case, acquired, conversation,
|
| 613 |
+
)
|
| 614 |
+
)
|
| 615 |
+
result.final_ranking = final_ranking
|
| 616 |
+
result.final_raw_response = final_response.text
|
| 617 |
+
result.total_latency_ms += final_response.latency_ms
|
| 618 |
+
result.total_input_tokens += final_response.input_tokens
|
| 619 |
+
result.total_output_tokens += final_response.output_tokens
|
| 620 |
+
if final_belief:
|
| 621 |
+
trajectory.states.append(final_belief)
|
| 622 |
+
|
| 623 |
+
result.acquired_channels = acquired
|
| 624 |
+
result.belief_trajectory = trajectory
|
| 625 |
+
result.acquisition_cost = case.get_acquisition_cost(acquired)
|
| 626 |
+
result.total_case_cost = case.get_total_cost(acquired)
|
| 627 |
+
return result
|
| 628 |
+
|
| 629 |
+
# ============================================================
|
| 630 |
+
# Condensed Mode: Single-Turn Call Builder
|
| 631 |
+
# ============================================================
|
| 632 |
+
|
| 633 |
+
def _call_condensed(
    self,
    case: MedicalCase,
    initial_images: list[str],
    acquired: list[str],
    acquired_data: dict[str, str],
    steps: list[AcquisitionStep],
    available: list[str],
    candidates_str: str,
    channel_config: dict,
    step_tools: list[dict],
) -> VLMResponse:
    """
    Build and execute a single-turn call for condensed mode.

    Each call gets a complete, self-contained context:
    1. Initial image(s)
    2. Any acquired images
    3. Structured acquisition log (compact summary of all prior steps)
    4. All acquired text data
    5. Available channels
    6. Tools

    This keeps context size predictable and prevents weaker models
    from losing track of their reasoning in long multi-turn histories.

    Args:
        case: Case providing channels, candidates, and text context.
        initial_images: Base64 images always shown first.
        acquired: Channel names acquired so far, in acquisition order.
        acquired_data: Per-channel text summaries for the acquisition log.
        steps: Prior acquisition steps (empty on the first call).
        available: Channels still requestable; when empty the prompt
            forces a commit_diagnosis.
        candidates_str: Pre-formatted candidate list for the prompt.
        channel_config: Channel metadata used to format available channels.
        step_tools: Tool schemas offered for this step.

    Returns:
        The normalized VLMResponse from the backend.
    """
    content = []

    # 1. Initial image(s) — always included
    content.extend(self._build_image_content(initial_images))

    # 2. Acquired images — include all visual channels
    # NOTE(review): unlike _build_image_content, these blocks carry no
    # "detail" field — presumably intentional; confirm before unifying.
    for ch_name in acquired:
        ch = case.get_channel(ch_name)
        if ch and ch.channel_type == "image" and ch.value:
            # A channel value may be one base64 string or a list of them.
            if isinstance(ch.value, list):
                for img_b64 in ch.value:
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{img_b64}",
                        },
                    })
            else:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{ch.value}",
                    },
                })

    # 3. Build the text prompt
    available_str = format_available_channels(channel_config, acquired)
    log_str = _build_acquisition_log(steps, acquired_data)

    # 4. Collect all currently available context (initial + acquired)
    current_context = format_acquired_info(case.get_text_context(acquired))

    prompt = (
        f"Review all currently available clinical information below.\n\n"
        f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
        f"Current available evidence:\n{current_context}\n\n"
    )

    # Only show an acquisition history once at least one step exists.
    if steps:
        prompt += (
            f"Your prior acquisition history:\n{log_str}\n\n"
        )

    # Hint injected by the stopping criterion in the main loop; default
    # to "" if the attribute was never set (getattr guards first call).
    commit_hint = getattr(self, '_commit_hint', '')

    if available:
        prompt += (
            f"Remaining channels you can request:\n{available_str}\n\n"
            f"Decide: Would any remaining channel meaningfully change your "
            f"differential enough to justify its cost? If yes, use "
            f"request_information. If no, use commit_diagnosis with your final ranking."
            f"{commit_hint}"
        )
    else:
        # Nothing left to request — the only valid action is to commit.
        prompt += (
            f"All channels have been acquired. Use commit_diagnosis to "
            f"submit your final ranked diagnosis."
        )

    content.append({"type": "text", "text": prompt})

    # Single-turn call: the entire state is packed into one user message,
    # so user_text/images are passed as None and messages carries content.
    return self.client.call_with_retry(
        system_prompt=SYSTEM_PROMPT_CONDENSED,
        user_text=None,
        images=None,
        temperature=config.TEMPERATURE,
        max_tokens=config.MAX_TOKENS,
        tools=step_tools,
        messages=[{"role": "user", "content": content}],
    )
|
| 729 |
+
|
| 730 |
+
# ============================================================
|
| 731 |
+
# Full Mode: Tool Result Delivery
|
| 732 |
+
# ============================================================
|
| 733 |
+
|
| 734 |
+
def _deliver_tool_result(
    self,
    case: MedicalCase,
    channel_name: str,
    tool_call: ToolCall,
    conversation: list[dict],
    acquired: list[str],
    channel_config: dict,
):
    """Deliver requested channel data as a tool_result message (full mode).

    Appends one "tool_result" entry to *conversation* (mutated in place)
    containing the channel's data, any channel images, and a follow-up
    instruction that either lists remaining channels or forces a commit.

    Args:
        case: Case providing the requested channel.
        channel_name: Resolved channel to deliver.
        tool_call: The originating tool call; its call_id links the result.
        conversation: Running message list — appended to, not returned.
        acquired: Channels acquired so far (assumed to already include
            channel_name, since "remaining" is computed by exclusion).
        channel_config: Channel metadata used to format remaining channels.
    """
    ch = case.get_channel(channel_name)

    # Collect channel images; a value may be a single base64 string
    # or a list of them.
    result_images = []
    if ch and ch.channel_type == "image" and ch.value:
        if isinstance(ch.value, list):
            result_images.extend(ch.value)
        else:
            result_images.append(ch.value)

    # Text payload describing the delivered data; images are referenced
    # rather than inlined here.
    if ch and ch.channel_type == "text":
        data_str = f"[{channel_name}]: {ch.value}"
    elif ch and ch.channel_type == "image":
        data_str = f"[{channel_name}]: (image provided — see attached)"
    else:
        data_str = f"[{channel_name}]: No data available for this channel."

    available_after = [
        n for n in case.requestable_names if n not in acquired
    ]
    available_after_str = format_available_channels(channel_config, acquired)

    # Include commit hint from stopping criterion (if triggered)
    commit_hint = getattr(self, '_commit_hint', '')

    if available_after:
        follow_up = (
            f"Here is the information you requested:\n{data_str}\n\n"
            f"Integrate this evidence with your prior observations.\n\n"
            f"Remaining channels you can request:\n{available_after_str}\n\n"
            f"Use request_information if another channel would meaningfully "
            f"change your differential enough to justify its cost, or "
            f"commit_diagnosis if confident."
            f"{commit_hint}"
        )
    else:
        # No channels left — steer the model to commit.
        follow_up = (
            f"Here is the information you requested:\n{data_str}\n\n"
            f"All channels have been acquired. Use commit_diagnosis to "
            f"submit your final ranked diagnosis."
        )

    conversation.append({
        "role": "tool_result",
        "tool_call_id": tool_call.call_id,
        "content": data_str,
        "images": result_images,
        "follow_up": follow_up,
    })
|
| 792 |
+
|
| 793 |
+
def _deliver_channel_data_as_user_message(
    self,
    case: MedicalCase,
    channel_name: str,
    conversation: list[dict],
    available_before: list[str],
    acquired: list[str],
    channel_config: dict,
):
    """Deliver channel data as a plain user message (fallback, full mode).

    Used instead of a structured tool_result when the backend does not
    support tool-result turns. Appends one "user" entry (image blocks
    first, then an instruction text block) to *conversation* in place.

    Args:
        case: Case providing the requested channel.
        channel_name: Resolved channel to deliver.
        conversation: Running message list — appended to, not returned.
        available_before: Channels available before this delivery.
            NOTE(review): currently unused in this body; kept for
            interface compatibility with _deliver_tool_result callers.
        acquired: Channels acquired so far (remaining = by exclusion).
        channel_config: Channel metadata used to format remaining channels.
    """
    ch = case.get_channel(channel_name)
    content = []

    # Image channels are delivered as image_url blocks ahead of the text.
    if ch and ch.channel_type == "image" and ch.value:
        if isinstance(ch.value, list):
            for img_b64 in ch.value:
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
                })
        else:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{ch.value}"},
            })

    if ch and ch.channel_type == "text":
        data_str = f"[{channel_name}]: {ch.value}"
    elif ch and ch.channel_type == "image":
        data_str = f"[{channel_name}]: (image provided above)"
    else:
        data_str = f"[{channel_name}]: No data available."

    available_after = [n for n in case.requestable_names if n not in acquired]
    available_after_str = format_available_channels(channel_config, acquired)

    if available_after:
        text = (
            f"Data received:\n{data_str}\n\n"
            f"Remaining channels:\n{available_after_str}\n\n"
            f"Use request_information only if another channel is worth its cost, or commit_diagnosis."
        )
    else:
        text = (
            f"Data received:\n{data_str}\n\n"
            f"All channels acquired. Use commit_diagnosis."
        )

    content.append({"type": "text", "text": text})
    conversation.append({"role": "user", "content": content})
|
| 843 |
+
|
| 844 |
+
# ============================================================
|
| 845 |
+
# Final Diagnosis
|
| 846 |
+
# ============================================================
|
| 847 |
+
|
| 848 |
+
def _get_final_diagnosis_tooluse(
    self,
    case: MedicalCase,
    acquired: list[str],
    conversation: list[dict],
) -> tuple[list[dict], VLMResponse, BeliefState | None]:
    """Get final diagnosis via tool call (full mode).

    Appends a closing user turn to the existing multi-turn conversation
    and forces a commit by offering only the commit tool
    (constrain_tools_for_step with budget_remaining=0).

    Returns:
        (ranking, raw response, belief state or None) — see
        _parse_final_response.
    """
    text_context = case.get_text_context(acquired)
    acquired_str = format_acquired_info(text_context)
    candidates_str = "\n".join(
        f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
    )

    final_prompt = (
        f"All information has been gathered. Submit your final diagnosis.\n\n"
        f"Information acquired:\n{acquired_str}\n\n"
        f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
        f"Use commit_diagnosis with calibrated probabilities summing to 1.0 "
        f"and key_evidence for each diagnosis. Favor the least resource-intensive "
        f"pathway supported by the evidence."
    )
    # Mutates the caller's conversation: the final turn becomes part of it.
    conversation.append({"role": "user", "content": final_prompt})

    # budget_remaining=0 leaves commit_diagnosis as the only usable tool.
    commit_tools = constrain_tools_for_step(budget_remaining=0)

    response = self.client.call_with_retry(
        system_prompt=SYSTEM_PROMPT_FINAL,
        messages=conversation,
        temperature=config.TEMPERATURE,
        max_tokens=config.MAX_TOKENS,
        tools=commit_tools,
    )

    return self._parse_final_response(response, case, acquired)
|
| 882 |
+
|
| 883 |
+
def _get_final_diagnosis_condensed(
    self,
    case: MedicalCase,
    acquired: list[str],
    acquired_data: dict[str, str],
    steps: list[AcquisitionStep],
) -> tuple[list[dict], VLMResponse, BeliefState | None]:
    """Get final diagnosis via single-turn call (condensed mode).

    Rebuilds the entire case state (all images, acquisition log, all
    text evidence) into one self-contained user message and forces a
    commit by offering only the commit tool.

    Returns:
        (ranking, raw response, belief state or None) — see
        _parse_final_response.
    """
    content = []

    # Include all images
    content.extend(self._build_image_content(case.get_initial_images()))
    for ch_name in acquired:
        ch = case.get_channel(ch_name)
        if ch and ch.channel_type == "image" and ch.value:
            # A channel value may be one base64 string or a list of them.
            if isinstance(ch.value, list):
                for img_b64 in ch.value:
                    content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
                    })
            else:
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{ch.value}"},
                })

    # Build text
    candidates_str = "\n".join(
        f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
    )
    log_str = _build_acquisition_log(steps, acquired_data)
    current_context = format_acquired_info(case.get_text_context(acquired))

    prompt = (
        f"Submit your final diagnosis based on all gathered information.\n\n"
        f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
        f"Acquisition history:\n{log_str}\n\n"
        f"All currently available evidence:\n{current_context}\n\n"
        f"Use commit_diagnosis with calibrated probabilities summing to 1.0 "
        f"and key_evidence for each diagnosis. Favor the least resource-intensive "
        f"pathway supported by the evidence."
    )
    content.append({"type": "text", "text": prompt})

    # budget_remaining=0 leaves commit_diagnosis as the only usable tool.
    commit_tools = constrain_tools_for_step(budget_remaining=0)

    response = self.client.call_with_retry(
        system_prompt=SYSTEM_PROMPT_FINAL,
        messages=[{"role": "user", "content": content}],
        temperature=config.TEMPERATURE,
        max_tokens=config.MAX_TOKENS,
        tools=commit_tools,
    )

    return self._parse_final_response(response, case, acquired)
|
| 939 |
+
|
| 940 |
+
def _parse_final_response(
    self,
    response: VLMResponse,
    case: MedicalCase,
    acquired: list[str],
) -> tuple[list[dict], VLMResponse, BeliefState | None]:
    """Parse the final diagnosis response (shared by both modes).

    Prefers a structured commit_diagnosis tool call; when none is
    present, falls back to scraping a ranking from the response text
    (with no belief state in that case).
    """
    call = response.tool_call
    if call and call.tool_name == "commit_diagnosis":
        # Structured path: build the ranking and a matching belief state.
        ranking = self._extract_ranking_from_commit(call.arguments)
        belief = BeliefState(
            step=len(acquired),
            distribution={entry["name"]: entry["confidence"] for entry in ranking},
            channel_acquired=None,
        )
        return ranking, response, belief

    # Unstructured path: no tool call came back.
    logger.warning(
        f"[{case.case_id}] Final diagnosis: no tool call, "
        f"falling back to text extraction"
    )
    fallback = self._extract_ranking_from_text(response.text, case.candidates)
    return fallback, response, None
|
| 964 |
+
|
| 965 |
+
# ============================================================
|
| 966 |
+
# Baseline Conditions
|
| 967 |
+
# ============================================================
|
| 968 |
+
|
| 969 |
+
def get_diagnosis_at_state(
    self, case: MedicalCase, acquired: list[str]
) -> tuple[list[dict], VLMResponse]:
    """
    Public helper: get a diagnosis given a set of acquired channels.

    Used by TrajectoryCollector to evaluate intermediate states.
    Returns (ranking, response).
    """
    # Pure delegation — one single-turn diagnosis call over this state.
    ranking_and_response = self._get_final_diagnosis_single(case, acquired)
    return ranking_and_response
|
| 979 |
+
|
| 980 |
+
def diagnose_passive(self, case: MedicalCase) -> AgentResult:
    """Passive baseline: initial available context only, no acquisition."""
    outcome = AgentResult(
        case_id=case.case_id, dataset=case.dataset,
        prompt_variant=self.prompt_variant,
        backend=self.client.model, budget=0,
    )
    # One single-turn call with zero acquired channels.
    ranking, response = self._get_final_diagnosis_single(
        case, acquired=[],
    )
    outcome.final_ranking = ranking
    outcome.final_raw_response = response.text
    outcome.total_latency_ms = response.latency_ms
    outcome.total_input_tokens = response.input_tokens
    outcome.total_output_tokens = response.output_tokens
    # Cost reflects the case with nothing acquired.
    outcome.total_case_cost = case.get_total_cost([])
    return outcome
|
| 997 |
+
|
| 998 |
+
def diagnose_oracle(self, case: MedicalCase) -> AgentResult:
    """Oracle baseline: ALL information given upfront.

    Acquires every requestable channel, makes one single-turn
    diagnosis call, and records full acquisition/case costs. This
    upper-bounds what active acquisition could achieve.
    """
    all_channels = list(case.requestable_channels.keys())
    result = AgentResult(
        case_id=case.case_id, dataset=case.dataset,
        prompt_variant=self.prompt_variant,
        backend=self.client.model,
        budget=len(all_channels),
        acquired_channels=all_channels,
    )
    final_ranking, final_response = self._get_final_diagnosis_single(
        case, acquired=all_channels,
    )
    result.final_ranking = final_ranking
    result.final_raw_response = final_response.text
    result.total_latency_ms = final_response.latency_ms
    result.total_input_tokens = final_response.input_tokens
    result.total_output_tokens = final_response.output_tokens
    result.acquisition_cost = case.get_acquisition_cost(all_channels)
    result.total_case_cost = case.get_total_cost(all_channels)
    return result
|
| 1019 |
+
|
| 1020 |
+
def diagnose_fixed_order(
    self, case: MedicalCase, order: list[str] | None = None
) -> AgentResult:
    """Fixed-order baseline: acquire channels in predetermined order.

    Args:
        case: The case to diagnose.
        order: Channel names to acquire, in order. Defaults to the
            case's requestable channels in declaration order.

    Returns:
        AgentResult built from the first ``budget`` channels of
        *order* (all of them when self.budget is None).
    """
    if order is None:
        order = list(case.requestable_channels.keys())
    # budget=None means uncapped: take every channel in the order.
    max_acq = self.budget if self.budget is not None else len(order)
    acquired = order[:max_acq]
    result = AgentResult(
        case_id=case.case_id, dataset=case.dataset,
        prompt_variant=self.prompt_variant,
        backend=self.client.model,
        budget=max_acq,
        acquired_channels=acquired,
    )
    final_ranking, final_response = self._get_final_diagnosis_single(
        case, acquired=acquired,
    )
    result.final_ranking = final_ranking
    result.final_raw_response = final_response.text
    result.total_latency_ms = final_response.latency_ms
    result.total_input_tokens = final_response.input_tokens
    result.total_output_tokens = final_response.output_tokens
    result.acquisition_cost = case.get_acquisition_cost(acquired)
    result.total_case_cost = case.get_total_cost(acquired)
    return result
|
| 1046 |
+
|
| 1047 |
+
def _get_final_diagnosis_single(
    self, case: MedicalCase, acquired: list[str]
) -> tuple[list[dict], VLMResponse]:
    """Single-turn final diagnosis (for baselines).

    Issues one call with all images/text implied by *acquired*, offering
    only the commit tool. Falls back to free-text ranking extraction
    when the model does not return a commit_diagnosis tool call.

    Returns:
        (ranking, raw VLMResponse).
    """
    images = case.get_all_images_up_to(acquired)
    text_context = case.get_text_context(acquired)
    acquired_str = format_acquired_info(text_context)
    candidates_str = "\n".join(
        f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
    )

    user_text = (
        f"Provide your diagnosis using the currently available clinical information.\n\n"
        f"Available information:\n{acquired_str}\n\n"
        f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
        f"Use commit_diagnosis with calibrated probabilities summing "
        f"to 1.0 and key_evidence for each diagnosis. Prefer the least costly "
        f"explanation supported by the evidence."
    )

    # budget_remaining=0 leaves commit_diagnosis as the only usable tool.
    commit_tools = constrain_tools_for_step(budget_remaining=0)

    response = self.client.call_with_retry(
        system_prompt=SYSTEM_PROMPT_FINAL,
        user_text=user_text,
        images=images,
        temperature=config.TEMPERATURE,
        max_tokens=config.MAX_TOKENS,
        tools=commit_tools,
    )

    tool_call = response.tool_call
    if tool_call and tool_call.tool_name == "commit_diagnosis":
        ranking = self._extract_ranking_from_commit(tool_call.arguments)
        return ranking, response

    # No structured commit — scrape a ranking from the free text.
    ranking = self._extract_ranking_from_text(response.text, case.candidates)
    return ranking, response
|
| 1085 |
+
|
| 1086 |
+
# ============================================================
|
| 1087 |
+
# Helpers
|
| 1088 |
+
# ============================================================
|
| 1089 |
+
|
| 1090 |
+
def _build_image_content(self, images: list[str]) -> list[dict]:
|
| 1091 |
+
"""Build image content blocks for API messages."""
|
| 1092 |
+
content = []
|
| 1093 |
+
for img_b64 in images:
|
| 1094 |
+
content.append({
|
| 1095 |
+
"type": "image_url",
|
| 1096 |
+
"image_url": {
|
| 1097 |
+
"url": f"data:image/jpeg;base64,{img_b64}",
|
| 1098 |
+
"detail": "high",
|
| 1099 |
+
},
|
| 1100 |
+
})
|
| 1101 |
+
return content
|
| 1102 |
+
|
| 1103 |
+
def _extract_ranking_from_commit(self, args: dict) -> list[dict]:
|
| 1104 |
+
"""Extract ranking from commit_diagnosis tool call (structured JSON)."""
|
| 1105 |
+
ranked = args.get("ranked_diagnoses", [])
|
| 1106 |
+
ranking = []
|
| 1107 |
+
for i, entry in enumerate(ranked):
|
| 1108 |
+
ranking.append({
|
| 1109 |
+
"name": entry.get("name", ""),
|
| 1110 |
+
"confidence": entry.get("confidence", 0.0),
|
| 1111 |
+
"rank": i + 1,
|
| 1112 |
+
"key_evidence": entry.get("key_evidence", ""),
|
| 1113 |
+
})
|
| 1114 |
+
ranking.sort(key=lambda x: x["confidence"], reverse=True)
|
| 1115 |
+
for i, entry in enumerate(ranking):
|
| 1116 |
+
entry["rank"] = i + 1
|
| 1117 |
+
return ranking
|
| 1118 |
+
|
| 1119 |
+
def _extract_ranking_from_text(
|
| 1120 |
+
self, text: str, candidates: list[str]
|
| 1121 |
+
) -> list[dict]:
|
| 1122 |
+
"""Last-resort fallback: extract ranking from free text."""
|
| 1123 |
+
import re
|
| 1124 |
+
ranking = []
|
| 1125 |
+
pattern = (
|
| 1126 |
+
r"(\d+)\.\s*(.+?)\s*"
|
| 1127 |
+
r"\((?:confidence|probability|prob|conf):\s*([\d.]+)\)"
|
| 1128 |
+
)
|
| 1129 |
+
matches = re.findall(pattern, text, re.IGNORECASE)
|
| 1130 |
+
if matches:
|
| 1131 |
+
for rank_str, name, conf_str in matches:
|
| 1132 |
+
try:
|
| 1133 |
+
ranking.append({
|
| 1134 |
+
"name": name.strip(),
|
| 1135 |
+
"confidence": float(conf_str),
|
| 1136 |
+
"rank": int(rank_str),
|
| 1137 |
+
})
|
| 1138 |
+
except ValueError:
|
| 1139 |
+
continue
|
| 1140 |
+
if not ranking and candidates:
|
| 1141 |
+
for i, candidate in enumerate(candidates):
|
| 1142 |
+
if candidate.lower() in text.lower():
|
| 1143 |
+
ranking.append({
|
| 1144 |
+
"name": candidate,
|
| 1145 |
+
"confidence": max(0.1, 1.0 - i * 0.2),
|
| 1146 |
+
"rank": len(ranking) + 1,
|
| 1147 |
+
})
|
| 1148 |
+
ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
|
| 1149 |
+
for i, entry in enumerate(ranking):
|
| 1150 |
+
entry["rank"] = i + 1
|
| 1151 |
+
return ranking
|
| 1152 |
+
|
| 1153 |
+
def _match_channel(
|
| 1154 |
+
self, requested: str, available: list[str]
|
| 1155 |
+
) -> str | None:
|
| 1156 |
+
"""Match requested channel name to available channels."""
|
| 1157 |
+
requested = requested.lower().strip().replace(" ", "_")
|
| 1158 |
+
if requested in available:
|
| 1159 |
+
return requested
|
| 1160 |
+
for ch in available:
|
| 1161 |
+
if requested in ch or ch in requested:
|
| 1162 |
+
return ch
|
| 1163 |
+
req_words = set(requested.split("_"))
|
| 1164 |
+
best_match, best_overlap = None, 0
|
| 1165 |
+
for ch in available:
|
| 1166 |
+
overlap = len(req_words & set(ch.split("_")))
|
| 1167 |
+
if overlap > best_overlap:
|
| 1168 |
+
best_overlap = overlap
|
| 1169 |
+
best_match = ch
|
| 1170 |
+
return best_match if best_overlap > 0 else None
|
api_client.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified multi-backend VLM API client with tool-use support.
|
| 3 |
+
|
| 4 |
+
Supports OpenAI (GPT-4o), Anthropic (Claude), and Together (Qwen2.5-VL).
|
| 5 |
+
Handles image encoding, rate limiting, retries, response normalization,
|
| 6 |
+
and native function/tool calling across all backends.
|
| 7 |
+
"""
|
| 8 |
+
import base64
|
| 9 |
+
import io
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
from collections import deque
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
import config
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class VLMResponse:
    """Normalized response from any VLM backend, including tool calls."""
    # Concatenated text content of the reply (may be "" when the model only
    # returned a tool call).
    text: str
    # Model identifier this client was configured with.
    model: str
    # Backend name: "openai", "anthropic", or "together".
    backend: str
    # Token usage as reported by the backend's usage object.
    input_tokens: int
    output_tokens: int
    # Wall-clock latency of the API request in milliseconds.
    latency_ms: float
    tool_call: object | None = None  # tools.ToolCall if a tool was called
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _normalize_image_mode(img: Image.Image) -> Image.Image:
    """Normalize medical image modes to RGB-compatible formats for JPEG encoding.

    Handling by mode:
    - "RGB": returned unchanged.
    - "RGBA": alpha-composited onto a white background so transparent
      regions do not render as black after JPEG conversion.
    - "L" (8-bit grayscale): converted directly to RGB.
    - "I", "I;16" variants, and "F" (high-bit-depth integer / float data,
      common in medical imaging): min-max rescaled to 0-255 before the RGB
      conversion, because a plain convert() would clip rather than rescale.
    - Anything else: best-effort convert() to RGB.

    Args:
        img: Source PIL image in any mode.

    Returns:
        A PIL image safe to save as JPEG (mode "RGB", or the original if
        already "RGB").
    """
    if img.mode == "RGB":
        return img
    if img.mode == "RGBA":
        background = Image.new("RGB", img.size, (255, 255, 255))
        # Paste with the alpha band as mask to composite onto white.
        background.paste(img, mask=img.split()[3])
        return background
    if img.mode == "L":
        return img.convert("RGB")
    # The original code had two identical rescale branches for the integer
    # ("I*") and float ("F") modes; they are merged here.
    if img.mode in ("I", "I;16", "I;16B", "I;16L", "F"):
        import numpy as np
        arr = np.array(img, dtype=np.float64)
        if arr.max() > arr.min():
            # Stretch the dynamic range to the full 8-bit interval.
            arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255.0
        else:
            # Constant image: avoid division by zero, map to all-black.
            arr = np.zeros_like(arr)
        return Image.fromarray(arr.astype(np.uint8)).convert("RGB")
    return img.convert("RGB")
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def encode_image_to_base64(image_path: str | Path, max_size: int = 1024) -> str:
    """Load an image from disk and encode it as a base64 JPEG string.

    Resizes so the longest side is at most ``max_size`` (aspect ratio
    preserved) and normalizes the color mode for JPEG compatibility.

    Args:
        image_path: Path to the image file on disk.
        max_size: Maximum length of the longest side, in pixels.

    Returns:
        Base64-encoded JPEG bytes as an ASCII string (no data-URL prefix).
    """
    # Delegate to the PIL-object encoder so the resize/normalize/encode
    # pipeline lives in exactly one place (was previously duplicated).
    return encode_pil_image_to_base64(Image.open(image_path), max_size=max_size)
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def encode_pil_image_to_base64(img: Image.Image, max_size: int = 1024) -> str:
    """Encode a PIL Image object as a base64 JPEG string.

    Downscales so the longest side is at most ``max_size`` (aspect ratio
    preserved), normalizes the mode for JPEG, and encodes at quality 90.
    """
    longest_side = max(img.size)
    if longest_side > max_size:
        scale = max_size / longest_side
        resized_dims = (int(img.size[0] * scale), int(img.size[1] * scale))
        img = img.resize(resized_dims, Image.LANCZOS)
    jpeg_buffer = io.BytesIO()
    _normalize_image_mode(img).save(jpeg_buffer, format="JPEG", quality=90)
    return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BaseVLMClient(ABC):
    """Abstract base class for VLM API clients with tool-use support.

    Provides the shared machinery every backend needs:
    - sliding-window rate limiting (`_rate_limit_wait`)
    - retry with exponential backoff (`call_with_retry`)
    - a default `call_multiturn` fallback that collapses the conversation
      to its most recent user turn, for backends without a native
      multi-turn implementation.
    """

    def __init__(self, model: str, api_key: str, rate_limit: int = 30):
        """
        Args:
            model: Backend-specific model identifier.
            api_key: API key for the backend.
            rate_limit: Maximum calls permitted per rolling 60-second window.
        """
        self.model = model
        self.api_key = api_key
        self.rate_limit = rate_limit
        # time.time() stamps of recent calls, oldest first.
        self._call_timestamps: deque[float] = deque()

    def _rate_limit_wait(self):
        """Enforce rate limiting using a sliding window over the last 60 seconds."""
        now = time.time()
        # Evict timestamps that have aged out of the 60s window.
        while self._call_timestamps and now - self._call_timestamps[0] >= 60.0:
            self._call_timestamps.popleft()
        if len(self._call_timestamps) >= self.rate_limit:
            # Window is full: sleep until the oldest call expires, then drop it.
            sleep_time = 60.0 - (now - self._call_timestamps[0])
            if sleep_time > 0:
                time.sleep(sleep_time)
            self._call_timestamps.popleft()
        self._call_timestamps.append(time.time())

    @abstractmethod
    def call(
        self,
        system_prompt: str,
        user_text: str,
        images: list[str] | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Make a single-turn VLM API call, optionally with tools.

        Args:
            system_prompt: System instruction for the model.
            user_text: User message text.
            images: Optional base64-encoded JPEG images.
            temperature: Sampling temperature.
            max_tokens: Maximum output tokens.
            tools: Optional tool schemas (internal format; each backend
                converts them to its own wire format).
        """
        pass

    def call_multiturn(
        self,
        system_prompt: str,
        messages: list[dict],
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Multi-turn conversation call with tool support. Override in subclasses.

        Fallback behavior: scan the history backwards for the most recent
        user message and issue a single-turn `call` with its text and
        images. Earlier turns are dropped, so backends with native
        multi-turn protocols should override this.
        """
        last_user = ""
        last_images = []
        for msg in reversed(messages):
            if msg["role"] == "user":
                if isinstance(msg["content"], str):
                    last_user = msg["content"]
                elif isinstance(msg["content"], list):
                    for block in msg["content"]:
                        if block.get("type") == "text":
                            last_user = block["text"]
                        elif block.get("type") == "image_url":
                            # Strip a "data:image/jpeg;base64," prefix if present.
                            last_images.append(block["image_url"]["url"].split(",", 1)[-1])
                break
        return self.call(system_prompt, last_user, last_images or None, temperature, max_tokens, tools)

    def call_with_retry(
        self,
        system_prompt: str,
        user_text: str | None = None,
        images: list[str] | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        max_retries: int = 3,
        messages: list[dict] | None = None,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Call with exponential backoff retry. Supports single-turn, multi-turn, and tools.

        If `messages` is provided the multi-turn path is used; otherwise
        `user_text`/`images` drive a single-turn call. Backoff schedule is
        5s, 10s, 20s, ... (``2 ** attempt * 5``).

        Raises:
            Exception: re-raises the last API error once `max_retries`
                attempts have all failed.
        """
        for attempt in range(max_retries):
            try:
                self._rate_limit_wait()
                if messages is not None:
                    return self.call_multiturn(system_prompt, messages, temperature, max_tokens, tools)
                return self.call(system_prompt, user_text, images, temperature, max_tokens, tools)
            except Exception as e:
                if attempt == max_retries - 1:
                    # Out of retries: the old code logged "Retrying in Xs..."
                    # here too, which was misleading. Log the terminal failure
                    # and surface the original error to the caller.
                    logger.error(f"API call failed after {max_retries} attempts: {e}")
                    raise
                wait_time = 2 ** attempt * 5
                logger.warning(
                    f"API call failed (attempt {attempt + 1}/{max_retries}): {e}. "
                    f"Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _parse_tool_call_openai(response_message) -> object | None:
    """Extract a ToolCall from an OpenAI response message.

    Returns None when the message carries no tool calls; malformed JSON
    arguments degrade to an empty argument dict rather than raising.
    """
    from tools import ToolCall

    calls = getattr(response_message, "tool_calls", None)
    if not calls:
        return None

    # Single-tool protocol: only the first tool call is honored.
    first = calls[0]
    try:
        parsed_args = json.loads(first.function.arguments)
    except (json.JSONDecodeError, AttributeError):
        parsed_args = {}

    return ToolCall(
        tool_name=first.function.name,
        arguments=parsed_args,
        call_id=first.id,
    )
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _parse_tool_call_anthropic(response) -> object | None:
    """Extract a ToolCall from an Anthropic response.

    Anthropic responses may interleave text and tool_use content blocks;
    the first tool_use block wins. Returns None when no tool was called.
    """
    from tools import ToolCall

    tool_block = next(
        (block for block in response.content if block.type == "tool_use"),
        None,
    )
    if tool_block is None:
        return None
    return ToolCall(
        tool_name=tool_block.name,
        arguments=tool_block.input,
        call_id=tool_block.id,
    )
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ============================================================
|
| 214 |
+
# OpenAI Backend (GPT-4o) — with tool calling
|
| 215 |
+
# ============================================================
|
| 216 |
+
|
| 217 |
+
class OpenAIClient(BaseVLMClient):
    """OpenAI GPT-4o API client with native function calling."""

    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
        """Resolve model/key/rate-limit from config defaults and build the SDK client."""
        super().__init__(
            model=model or config.MODELS["openai"],
            api_key=api_key or config.OPENAI_API_KEY,
            rate_limit=rate_limit or config.RATE_LIMITS["openai"],
        )
        # Imported lazily so the openai package is only required when this
        # backend is actually used.
        from openai import OpenAI
        self.client = OpenAI(api_key=self.api_key)

    def call(
        self,
        system_prompt: str,
        user_text: str,
        images: list[str] | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Single-turn chat-completions call; images precede the text block."""
        content = []
        if images:
            for img_b64 in images:
                content.append({
                    "type": "image_url",
                    # "detail": "high" requests full-resolution image analysis.
                    "image_url": {"url": f"data:image/jpeg;base64,{img_b64}", "detail": "high"},
                })
        content.append({"type": "text", "text": user_text})

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": content},
        ]

        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            from tools import to_openai_tools
            kwargs["tools"] = to_openai_tools(tools)
            # "required" forces the model to call some tool on this turn.
            kwargs["tool_choice"] = "required"

        t0 = time.time()
        response = self.client.chat.completions.create(**kwargs)
        latency = (time.time() - t0) * 1000

        msg = response.choices[0].message
        # Only parse tool calls when tools were offered on this request.
        tool_call = _parse_tool_call_openai(msg) if tools else None

        return VLMResponse(
            # content is None when the reply is purely a tool call.
            text=msg.content or "",
            model=self.model,
            backend="openai",
            input_tokens=response.usage.prompt_tokens,
            output_tokens=response.usage.completion_tokens,
            latency_ms=latency,
            tool_call=tool_call,
        )

    def call_multiturn(
        self,
        system_prompt: str,
        messages: list[dict],
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """
        Multi-turn OpenAI call with full tool-calling protocol.

        Translates our internal message format to OpenAI's API format:
        - "user" → role:"user" (passed through)
        - "assistant" → role:"assistant" with tool_calls array
        - "tool_result" → role:"tool" (text result) + role:"user" (images + follow-up)

        OpenAI requires: after an assistant message with tool_calls, the next
        message MUST be role:"tool" with the matching tool_call_id. Images
        cannot go in tool messages, so we send them in a separate user message.
        """
        api_messages = [{"role": "system", "content": system_prompt}]

        for msg in messages:
            role = msg["role"]

            if role == "user":
                # User turns pass through unchanged (OpenAI-style content blocks).
                api_messages.append({
                    "role": "user",
                    "content": msg["content"],
                })

            elif role == "assistant":
                api_msg = {"role": "assistant"}
                if msg.get("tool_calls"):
                    # Single-tool protocol: replay only the first recorded call.
                    tc = msg["tool_calls"][0]
                    api_msg["tool_calls"] = [{
                        "id": tc.call_id,
                        "type": "function",
                        "function": {
                            "name": tc.tool_name,
                            # OpenAI expects arguments as a JSON string.
                            "arguments": json.dumps(tc.arguments),
                        },
                    }]
                    # OpenAI requires content to be null when tool_calls present
                    api_msg["content"] = msg.get("content") or None
                else:
                    api_msg["content"] = msg.get("content", "")
                api_messages.append(api_msg)

            elif role == "tool_result":
                # Step 1: Send the tool result as role:"tool"
                api_messages.append({
                    "role": "tool",
                    "tool_call_id": msg["tool_call_id"],
                    "content": msg.get("content", ""),
                })

                # Step 2: Send images + follow-up as a user message
                # (OpenAI tool messages don't support image content blocks)
                follow_up_content = []
                for img_b64 in msg.get("images", []):
                    follow_up_content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{img_b64}",
                        },
                    })
                follow_up = msg.get("follow_up", "")
                if follow_up:
                    follow_up_content.append({
                        "type": "text",
                        "text": follow_up,
                    })
                # Skip the extra user message entirely when there is nothing to send.
                if follow_up_content:
                    api_messages.append({
                        "role": "user",
                        "content": follow_up_content,
                    })

        kwargs = {
            "model": self.model,
            "messages": api_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            from tools import to_openai_tools
            kwargs["tools"] = to_openai_tools(tools)
            kwargs["tool_choice"] = "required"

        t0 = time.time()
        response = self.client.chat.completions.create(**kwargs)
        latency = (time.time() - t0) * 1000

        # NOTE: `msg` rebinds the loop variable from above; safe because the
        # loop has finished, but keep in mind when editing.
        msg = response.choices[0].message
        tool_call = _parse_tool_call_openai(msg) if tools else None

        return VLMResponse(
            text=msg.content or "",
            model=self.model,
            backend="openai",
            input_tokens=response.usage.prompt_tokens,
            output_tokens=response.usage.completion_tokens,
            latency_ms=latency,
            tool_call=tool_call,
        )
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ============================================================
|
| 389 |
+
# Anthropic Backend (Claude) — with tool use
|
| 390 |
+
# ============================================================
|
| 391 |
+
|
| 392 |
+
class AnthropicClient(BaseVLMClient):
    """Anthropic Claude API client with native tool use."""

    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
        """Resolve model/key/rate-limit from config defaults and build the SDK client."""
        super().__init__(
            model=model or config.MODELS["anthropic"],
            api_key=api_key or config.ANTHROPIC_API_KEY,
            rate_limit=rate_limit or config.RATE_LIMITS["anthropic"],
        )
        # Imported lazily so the anthropic package is only required when this
        # backend is actually used.
        from anthropic import Anthropic
        self.client = Anthropic(api_key=self.api_key)

    def call(
        self,
        system_prompt: str,
        user_text: str,
        images: list[str] | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Single-turn Messages API call; images precede the text block."""
        content = []
        if images:
            for img_b64 in images:
                # Anthropic takes raw base64 in a "source" object, not a data URL.
                content.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": img_b64,
                    },
                })
        content.append({"type": "text", "text": user_text})

        kwargs = {
            "model": self.model,
            # Anthropic takes the system prompt as a top-level parameter,
            # not as a message.
            "system": system_prompt,
            "messages": [{"role": "user", "content": content}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            from tools import to_anthropic_tools
            kwargs["tools"] = to_anthropic_tools(tools)
            # {"type": "any"} forces the model to use one of the provided tools.
            kwargs["tool_choice"] = {"type": "any"}

        t0 = time.time()
        response = self.client.messages.create(**kwargs)
        latency = (time.time() - t0) * 1000

        # Extract text from response (may have both text and tool_use blocks)
        text_parts = []
        for block in response.content:
            if hasattr(block, "text"):
                text_parts.append(block.text)

        tool_call = _parse_tool_call_anthropic(response) if tools else None

        return VLMResponse(
            text="\n".join(text_parts),
            model=self.model,
            backend="anthropic",
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
            latency_ms=latency,
            tool_call=tool_call,
        )

    def call_multiturn(
        self,
        system_prompt: str,
        messages: list[dict],
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """
        Multi-turn Anthropic call with full tool-use protocol.

        Translates our internal message format to Anthropic's API format:
        - "user" → role:"user" (passed through)
        - "assistant" → role:"assistant" with tool_use content blocks
        - "tool_result" → role:"user" with tool_result block + image blocks

        Anthropic's protocol: after an assistant message with a tool_use block,
        the next message MUST be role:"user" containing a tool_result block
        with the matching tool_use_id. Images and follow-up text can be
        included in the same user message as additional content blocks.
        """
        api_messages = []

        for msg in messages:
            role = msg["role"]

            if role == "user":
                content = msg["content"]
                # Convert image_url format to Anthropic's image format
                if isinstance(content, list):
                    anthropic_content = []
                    for block in content:
                        if block.get("type") == "image_url":
                            url = block["image_url"]["url"]
                            # Extract base64 data from data URL
                            if url.startswith("data:"):
                                b64_data = url.split(",", 1)[-1]
                            else:
                                b64_data = url
                            anthropic_content.append({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    "data": b64_data,
                                },
                            })
                        elif block.get("type") == "text":
                            anthropic_content.append(block)
                        else:
                            # Unknown block types are passed through untouched.
                            anthropic_content.append(block)
                    api_messages.append({
                        "role": "user",
                        "content": anthropic_content,
                    })
                else:
                    # Plain string content passes through as-is.
                    api_messages.append({
                        "role": "user",
                        "content": content,
                    })

            elif role == "assistant":
                content_blocks = []
                if msg.get("content"):
                    content_blocks.append({
                        "type": "text",
                        "text": msg["content"],
                    })
                if msg.get("tool_calls"):
                    # Single-tool protocol: replay only the first recorded call.
                    tc = msg["tool_calls"][0]
                    content_blocks.append({
                        "type": "tool_use",
                        "id": tc.call_id,
                        "name": tc.tool_name,
                        "input": tc.arguments,
                    })
                api_messages.append({
                    "role": "assistant",
                    "content": content_blocks,
                })

            elif role == "tool_result":
                # Anthropic: tool_result goes in a user message alongside
                # any images and follow-up text
                user_content = []

                # The tool_result block
                user_content.append({
                    "type": "tool_result",
                    "tool_use_id": msg["tool_call_id"],
                    "content": msg.get("content", ""),
                })

                # Images from the channel data
                for img_b64 in msg.get("images", []):
                    user_content.append({
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": img_b64,
                        },
                    })

                # Follow-up text (next-step instructions)
                follow_up = msg.get("follow_up", "")
                if follow_up:
                    user_content.append({
                        "type": "text",
                        "text": follow_up,
                    })

                api_messages.append({
                    "role": "user",
                    "content": user_content,
                })

        kwargs = {
            "model": self.model,
            "system": system_prompt,
            "messages": api_messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            from tools import to_anthropic_tools
            kwargs["tools"] = to_anthropic_tools(tools)
            kwargs["tool_choice"] = {"type": "any"}

        t0 = time.time()
        response = self.client.messages.create(**kwargs)
        latency = (time.time() - t0) * 1000

        # Collect only the text blocks; tool_use blocks are handled below.
        text_parts = []
        for block in response.content:
            if hasattr(block, "text"):
                text_parts.append(block.text)

        tool_call = _parse_tool_call_anthropic(response) if tools else None

        return VLMResponse(
            text="\n".join(text_parts),
            model=self.model,
            backend="anthropic",
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
            latency_ms=latency,
            tool_call=tool_call,
        )
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# ============================================================
|
| 612 |
+
# Together Backend (Qwen2.5-VL) — with tool calling
|
| 613 |
+
# ============================================================
|
| 614 |
+
|
| 615 |
+
class TogetherClient(BaseVLMClient):
    """Together AI client with function calling support.

    Together exposes an OpenAI-compatible chat-completions API, so requests
    use OpenAI-style content blocks and responses are parsed with the
    OpenAI tool-call parser.
    """

    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
        super().__init__(
            model=model or config.MODELS["together"],
            api_key=api_key or config.TOGETHER_API_KEY,
            rate_limit=rate_limit or config.RATE_LIMITS["together"],
        )
        # Lazy import: the together package is only needed for this backend.
        from together import Together
        self.client = Together(api_key=self.api_key)

    def call(
        self,
        system_prompt: str,
        user_text: str,
        images: list[str] | None = None,
        temperature: float = 0.1,
        max_tokens: int = 2048,
        tools: list[dict] | None = None,
    ) -> VLMResponse:
        """Single-turn call; image blocks precede the text block."""
        content = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
            }
            for img_b64 in (images or [])
        ]
        content.append({"type": "text", "text": user_text})

        request = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            from tools import to_openai_tools
            request["tools"] = to_openai_tools(tools)

        started = time.time()
        response = self.client.chat.completions.create(**request)
        elapsed_ms = (time.time() - started) * 1000

        reply = response.choices[0].message
        usage = response.usage

        return VLMResponse(
            text=reply.content or "",
            model=self.model,
            backend="together",
            # Some responses may omit usage fields; default to 0.
            input_tokens=getattr(usage, "prompt_tokens", 0),
            output_tokens=getattr(usage, "completion_tokens", 0),
            latency_ms=elapsed_ms,
            tool_call=_parse_tool_call_openai(reply) if tools else None,
        )
| 677 |
+
|
| 678 |
+
|
| 679 |
+
# ============================================================
|
| 680 |
+
# Client Factory
|
| 681 |
+
# ============================================================
|
| 682 |
+
|
| 683 |
+
class OpenAIMiniClient(OpenAIClient):
    """OpenAI GPT-4o-mini client.

    Identical to OpenAIClient except it resolves its defaults from the
    "openai_mini" entries in config.
    """

    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
        # Previously this bypassed OpenAIClient.__init__ (calling
        # BaseVLMClient.__init__ directly) and duplicated the OpenAI SDK
        # client construction. Delegating with the mini defaults already
        # resolved (non-None) means the parent's "openai" defaults never
        # apply and the SDK setup lives in one place.
        super().__init__(
            model=model or config.MODELS["openai_mini"],
            api_key=api_key,  # parent falls back to config.OPENAI_API_KEY
            rate_limit=rate_limit or config.RATE_LIMITS["openai_mini"],
        )
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def create_client(backend: str, **kwargs) -> BaseVLMClient:
    """Factory function to create a VLM client by backend name.

    Args:
        backend: One of "openai", "openai_mini", "anthropic", "together".
        **kwargs: Forwarded to the selected client's constructor.

    Raises:
        ValueError: If `backend` is not a known backend name.
    """
    clients = {
        "openai": OpenAIClient,
        "openai_mini": OpenAIMiniClient,
        "anthropic": AnthropicClient,
        "together": TogetherClient,
    }
    client_cls = clients.get(backend)
    if client_cls is None:
        raise ValueError(f"Unknown backend: {backend}. Choose from {list(clients.keys())}")
    return client_cls(**kwargs)
app.py
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Interactive Demo for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
A Gradio-based UI that lets users:
|
| 5 |
+
- Select from pre-built demo cases OR enter a custom clinical scenario
|
| 6 |
+
- Upload medical images (optional)
|
| 7 |
+
- Watch the agent's step-by-step reasoning, information acquisition, and
|
| 8 |
+
entropy reduction in real time
|
| 9 |
+
- No budget constraint — the agent acquires as many channels as it needs
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python app.py
|
| 13 |
+
python app.py --backend openai
|
| 14 |
+
python app.py --backend anthropic --port 7861
|
| 15 |
+
"""
|
| 16 |
+
import argparse
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
import math
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import gradio as gr
|
| 27 |
+
from PIL import Image
|
| 28 |
+
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
| 30 |
+
|
| 31 |
+
import config
|
| 32 |
+
from api_client import create_client, encode_image_to_base64, encode_pil_image_to_base64
|
| 33 |
+
from agent import ActiveMedAgent, AgentResult, AcquisitionStep, SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_CONDENSED, SYSTEM_PROMPT_FINAL
|
| 34 |
+
from datasets.base import MedicalCase, ChannelData
|
| 35 |
+
from tools import AGENT_TOOLS, constrain_tools_for_step, ToolCall
|
| 36 |
+
from information_gain import (
|
| 37 |
+
BeliefState, BeliefTrajectory,
|
| 38 |
+
compute_entropy, compute_kl_divergence,
|
| 39 |
+
estimate_expected_information_gain,
|
| 40 |
+
should_commit, compute_value_of_information,
|
| 41 |
+
)
|
| 42 |
+
from prompts import format_available_channels, format_acquired_info
|
| 43 |
+
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 47 |
+
)
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ============================================================
|
| 52 |
+
# Backend Availability Detection
|
| 53 |
+
# ============================================================
|
| 54 |
+
|
| 55 |
+
def _detect_available_backends() -> list[str]:
    """Return backend names whose API keys look configured.

    A key counts as configured when it is non-empty and is not the
    placeholder value shipped in the `.env` template.
    """
    # (backend name, key value, template placeholder to reject)
    key_checks = (
        ("openai", config.OPENAI_API_KEY, "sk-..."),
        ("anthropic", config.ANTHROPIC_API_KEY, "sk-ant-..."),
        ("together", config.TOGETHER_API_KEY, None),
    )
    return [
        backend
        for backend, key, placeholder in key_checks
        if key and key != placeholder
    ]
| 65 |
+
|
| 66 |
+
|
| 67 |
+
AVAILABLE_BACKENDS = _detect_available_backends()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ============================================================
|
| 71 |
+
# Simulation Mode — works without API keys
|
| 72 |
+
# ============================================================
|
| 73 |
+
|
| 74 |
+
def _simulate_agent_on_case(case: MedicalCase) -> AgentResult:
    """
    Run a simulated agent that demonstrates the full pipeline
    with realistic-looking reasoning traces. No API keys needed.

    The simulation acquires every requestable channel in order, gradually
    concentrating a synthetic belief distribution on the ground-truth
    candidate, then appends a final commit step.

    Args:
        case: The case whose channels/candidates drive the synthetic trace.

    Returns:
        A fully populated AgentResult with steps, belief trajectory, and
        cost/latency accounting (token counts are zero — nothing was called).
    """
    import random
    random.seed(42)
    # BUGFIX: the Dirichlet draw below uses NumPy's global RNG, which was
    # previously unseeded — repeated demo runs produced different traces
    # despite random.seed(42). Seed it so the simulation is reproducible.
    np.random.seed(42)

    result = AgentResult(
        case_id=case.case_id,
        dataset=case.dataset,
        prompt_variant="A",
        backend="simulated (no API key)",
        budget=len(case.requestable_channels),
    )
    trajectory = BeliefTrajectory(case_id=case.case_id)
    acquired = []
    n_candidates = len(case.candidates)

    # Generate initial uniform-ish distribution (sorted so index 0 is top).
    probs = np.random.dirichlet(np.ones(n_candidates) * 2.0).tolist()
    probs.sort(reverse=True)
    # Make ground truth likely to end up on top by the end.
    gt_idx = case.ground_truth_rank

    requestable_names = list(case.requestable_channels.keys())

    for step_idx, ch_name in enumerate(requestable_names):
        # Evolve the distribution — gradually concentrate on correct answer.
        progress = (step_idx + 1) / len(requestable_names)
        new_probs = []
        for i in range(n_candidates):
            if i == gt_idx:
                new_probs.append(probs[i] + 0.15 * progress + random.uniform(0, 0.05))
            else:
                new_probs.append(max(0.01, probs[i] - 0.04 * progress + random.uniform(-0.02, 0.02)))
        total = sum(new_probs)
        probs = [p / total for p in new_probs]

        distribution = {case.candidates[i]: probs[i] for i in range(n_candidates)}
        sorted_dist = sorted(distribution.items(), key=lambda x: -x[1])

        # Entropy before this acquisition; pad the very first step by 0.3 bits
        # so the first information gain reads as positive.
        prev_entropy = trajectory.states[-1].entropy if trajectory.states else compute_entropy(distribution) + 0.3
        belief = BeliefState(
            step=step_idx,
            distribution=distribution,
            channel_acquired=ch_name,
        )
        trajectory.states.append(belief)

        ig = prev_entropy - belief.entropy
        # Synthetic but plausible KL magnitude, loosely tied to the gain.
        kl = abs(ig) * 1.2 + random.uniform(0, 0.1)

        top_two = sorted_dist[:2]
        reasoning_templates = [
            f"Need to distinguish between {top_two[0][0]} ({top_two[0][1]:.0%}) and {top_two[1][0]} ({top_two[1][1]:.0%}). "
            f"Requesting {ch_name} to resolve this uncertainty.",
            f"Current top diagnosis is {top_two[0][0]} at {top_two[0][1]:.0%} but {top_two[1][0]} cannot be ruled out. "
            f"The {ch_name} channel should provide discriminating evidence.",
            f"Diagnostic uncertainty remains high (H={belief.entropy:.2f} bits). "
            f"The {ch_name} data is expected to significantly narrow the differential.",
        ]

        step = AcquisitionStep(
            step=step_idx,
            tool_call=ToolCall(tool_name="request_information", arguments={
                "channel_name": ch_name,
                "reasoning": reasoning_templates[step_idx % len(reasoning_templates)],
            }),
            requested_channel=ch_name,
            reasoning=reasoning_templates[step_idx % len(reasoning_templates)],
            differential=[
                {"name": name, "confidence": prob, "rank": i + 1}
                for i, (name, prob) in enumerate(sorted_dist)
            ],
            committed=False,
            raw_response="(simulated)",
            latency_ms=random.uniform(800, 3000),
            entropy=belief.entropy,
            information_gain=ig,
            kl_divergence=kl,
            expected_impact={
                "if_positive": sorted_dist[0][0],
                "if_negative": sorted_dist[1][0],
            },
        )
        result.steps.append(step)
        acquired.append(ch_name)

    # Final commit step: collapse most of the mass onto the ground truth.
    final_probs = []
    for i in range(n_candidates):
        if i == gt_idx:
            final_probs.append(0.65 + random.uniform(0, 0.15))
        else:
            final_probs.append(random.uniform(0.02, 0.12))
    total = sum(final_probs)
    final_probs = [p / total for p in final_probs]
    final_dist = {case.candidates[i]: final_probs[i] for i in range(n_candidates)}
    sorted_final = sorted(final_dist.items(), key=lambda x: -x[1])

    final_belief = BeliefState(
        step=len(requestable_names),
        distribution=final_dist,
        channel_acquired=None,
    )
    trajectory.states.append(final_belief)

    final_ranking = [
        {
            "name": name,
            "confidence": prob,
            "rank": i + 1,
            "key_evidence": "Supported by evidence from acquired channels" if i == 0 else "Less consistent with findings",
        }
        for i, (name, prob) in enumerate(sorted_final)
    ]

    commit_step = AcquisitionStep(
        step=len(requestable_names),
        tool_call=ToolCall(tool_name="commit_diagnosis", arguments={}),
        requested_channel=None,
        reasoning=f"After acquiring all available channels, the evidence strongly supports {sorted_final[0][0]}. "
        f"Entropy reduced to {final_belief.entropy:.2f} bits. Committing diagnosis.",
        differential=final_ranking,
        committed=True,
        raw_response="(simulated)",
        latency_ms=random.uniform(500, 2000),
        entropy=final_belief.entropy,
        information_gain=trajectory.states[-2].entropy - final_belief.entropy if len(trajectory.states) >= 2 else 0,
        kl_divergence=0.0,
    )
    result.steps.append(commit_step)
    result.committed_early = False
    result.final_ranking = final_ranking
    result.acquired_channels = acquired
    result.belief_trajectory = trajectory
    result.acquisition_cost = case.get_acquisition_cost(acquired)
    result.total_case_cost = case.get_total_cost(acquired)
    result.total_latency_ms = sum(s.latency_ms for s in result.steps)
    # Simulated run — no model calls, so no tokens were consumed.
    result.total_input_tokens = 0
    result.total_output_tokens = 0

    return result
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ============================================================
|
| 224 |
+
# Synthetic Demo Cases
|
| 225 |
+
# ============================================================
|
| 226 |
+
|
| 227 |
+
def _make_dummy_image(width=224, height=224, color=(180, 60, 60)) -> str:
    """Build a solid-color placeholder image with mild pixel noise.

    Returns the image base64-encoded (via the project's PIL encoder) so it
    can be stored directly as an image channel value.
    """
    base = np.full((height, width, 3), color, dtype=np.int16)
    noise = np.random.randint(-20, 20, base.shape, dtype=np.int16)
    pixels = np.clip(base + noise, 0, 255).astype(np.uint8)
    return encode_pil_image_to_base64(Image.fromarray(pixels))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
DEMO_CASES = {
|
| 237 |
+
"NEJM: Pulmonary Fibrosis": {
|
| 238 |
+
"description": (
|
| 239 |
+
"A 58-year-old man with progressive dyspnea and dry cough over 3 months. "
|
| 240 |
+
"30-pack-year smoking history, takes lisinopril for hypertension."
|
| 241 |
+
),
|
| 242 |
+
"case": lambda: MedicalCase(
|
| 243 |
+
case_id="demo_nejm_ipf",
|
| 244 |
+
dataset="nejm",
|
| 245 |
+
initial_channels={
|
| 246 |
+
"demographics": ChannelData(
|
| 247 |
+
name="demographics", channel_type="text",
|
| 248 |
+
description="Patient age, sex, and ethnicity",
|
| 249 |
+
value="A 58-year-old man", always_given=True, cost=0.0, tier="free",
|
| 250 |
+
),
|
| 251 |
+
"chief_complaint": ChannelData(
|
| 252 |
+
name="chief_complaint", channel_type="text",
|
| 253 |
+
description="Presenting symptoms and duration",
|
| 254 |
+
value="Progressive dyspnea and dry cough over the past 3 months.",
|
| 255 |
+
always_given=True, cost=0.0, tier="free",
|
| 256 |
+
),
|
| 257 |
+
"medical_history": ChannelData(
|
| 258 |
+
name="medical_history", channel_type="text",
|
| 259 |
+
description="Past medical conditions, medications, family and social history",
|
| 260 |
+
value="30-pack-year smoking history. No prior lung disease. Takes lisinopril for hypertension.",
|
| 261 |
+
always_given=True, cost=0.0, tier="free",
|
| 262 |
+
),
|
| 263 |
+
},
|
| 264 |
+
requestable_channels={
|
| 265 |
+
"exam_findings": ChannelData(
|
| 266 |
+
name="exam_findings", channel_type="text",
|
| 267 |
+
description="Physical examination results and observations",
|
| 268 |
+
value="Bibasilar crackles on auscultation. No clubbing. Oxygen saturation 92% on room air.",
|
| 269 |
+
cost=75.0, tier="cheap",
|
| 270 |
+
),
|
| 271 |
+
"investigations": ChannelData(
|
| 272 |
+
name="investigations", channel_type="text",
|
| 273 |
+
description="Laboratory values, prior imaging results, and test outcomes",
|
| 274 |
+
value="PFTs show restrictive pattern with reduced DLCO. CT chest shows bilateral ground-glass opacities with honeycombing in the lower lobes.",
|
| 275 |
+
cost=250.0, tier="moderate",
|
| 276 |
+
),
|
| 277 |
+
"image": ChannelData(
|
| 278 |
+
name="image", channel_type="image",
|
| 279 |
+
description="The primary diagnostic image (chest CT)",
|
| 280 |
+
value=_make_dummy_image(300, 300, (200, 200, 210)),
|
| 281 |
+
cost=800.0, tier="expensive",
|
| 282 |
+
),
|
| 283 |
+
},
|
| 284 |
+
candidates=[
|
| 285 |
+
"A. Idiopathic pulmonary fibrosis",
|
| 286 |
+
"B. Hypersensitivity pneumonitis",
|
| 287 |
+
"C. Sarcoidosis",
|
| 288 |
+
"D. Lung adenocarcinoma",
|
| 289 |
+
"E. ACE-inhibitor induced cough with incidental CT findings",
|
| 290 |
+
],
|
| 291 |
+
ground_truth="A. Idiopathic pulmonary fibrosis",
|
| 292 |
+
ground_truth_rank=0,
|
| 293 |
+
),
|
| 294 |
+
},
|
| 295 |
+
"Dermatology: Pigmented Lesion": {
|
| 296 |
+
"description": (
|
| 297 |
+
"A 62-year-old woman presents with a pigmented lesion on her left forearm. "
|
| 298 |
+
"The lesion is 8mm x 6mm. Clinical photograph provided."
|
| 299 |
+
),
|
| 300 |
+
"case": lambda: MedicalCase(
|
| 301 |
+
case_id="demo_midas_001",
|
| 302 |
+
dataset="midas",
|
| 303 |
+
initial_channels={
|
| 304 |
+
"clinical_30cm": ChannelData(
|
| 305 |
+
name="clinical_30cm", channel_type="image",
|
| 306 |
+
description="Clinical photograph at 30cm distance",
|
| 307 |
+
value=_make_dummy_image(224, 224, (180, 120, 100)),
|
| 308 |
+
always_given=True, cost=0.0, tier="free",
|
| 309 |
+
),
|
| 310 |
+
},
|
| 311 |
+
requestable_channels={
|
| 312 |
+
"patient_demographics": ChannelData(
|
| 313 |
+
name="patient_demographics", channel_type="text",
|
| 314 |
+
description="Patient age, sex, and Fitzpatrick skin type",
|
| 315 |
+
value="Age: 62; Sex: Female; Fitzpatrick skin type: III",
|
| 316 |
+
cost=0.0, tier="free",
|
| 317 |
+
),
|
| 318 |
+
"lesion_metadata": ChannelData(
|
| 319 |
+
name="lesion_metadata", channel_type="text",
|
| 320 |
+
description="Anatomic location, lesion length and width",
|
| 321 |
+
value="Anatomic location: Left forearm; Lesion length: 8mm; Lesion width: 6mm",
|
| 322 |
+
cost=25.0, tier="cheap",
|
| 323 |
+
),
|
| 324 |
+
"clinical_15cm": ChannelData(
|
| 325 |
+
name="clinical_15cm", channel_type="image",
|
| 326 |
+
description="Clinical photograph at 15cm distance (closer view)",
|
| 327 |
+
value=_make_dummy_image(224, 224, (170, 110, 90)),
|
| 328 |
+
cost=50.0, tier="moderate",
|
| 329 |
+
),
|
| 330 |
+
"dermoscopy": ChannelData(
|
| 331 |
+
name="dermoscopy", channel_type="image",
|
| 332 |
+
description="Dermoscopic image showing subsurface skin structures",
|
| 333 |
+
value=_make_dummy_image(224, 224, (100, 80, 60)),
|
| 334 |
+
cost=250.0, tier="expensive",
|
| 335 |
+
),
|
| 336 |
+
},
|
| 337 |
+
candidates=[
|
| 338 |
+
"Melanoma in situ",
|
| 339 |
+
"Dysplastic nevus",
|
| 340 |
+
"Basal cell carcinoma",
|
| 341 |
+
"Seborrheic keratosis",
|
| 342 |
+
"Solar lentigo",
|
| 343 |
+
],
|
| 344 |
+
ground_truth="Dysplastic nevus",
|
| 345 |
+
ground_truth_rank=1,
|
| 346 |
+
),
|
| 347 |
+
},
|
| 348 |
+
"Ophthalmology: Retinal Biomarkers (OLIVES)": {
|
| 349 |
+
"description": (
|
| 350 |
+
"A patient with diabetic macular edema (DME), 4 prior anti-VEGF injections, "
|
| 351 |
+
"32 weeks in treatment. Fundus photograph provided."
|
| 352 |
+
),
|
| 353 |
+
"case": lambda: MedicalCase(
|
| 354 |
+
case_id="demo_olives_P01",
|
| 355 |
+
dataset="olives",
|
| 356 |
+
initial_channels={
|
| 357 |
+
"disease_context": ChannelData(
|
| 358 |
+
name="disease_context", channel_type="text",
|
| 359 |
+
description="Disease type and treatment context",
|
| 360 |
+
value="Disease: Diabetic Macular Edema (DME). Prior anti-VEGF injections: 4. Weeks in treatment: 32.",
|
| 361 |
+
always_given=True, cost=0.0, tier="free",
|
| 362 |
+
),
|
| 363 |
+
},
|
| 364 |
+
requestable_channels={
|
| 365 |
+
"clinical_measurements": ChannelData(
|
| 366 |
+
name="clinical_measurements", channel_type="text",
|
| 367 |
+
description="Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
|
| 368 |
+
value="BCVA: 20/60 (logMAR 0.48); CST: 385 um",
|
| 369 |
+
cost=20.0, tier="cheap",
|
| 370 |
+
),
|
| 371 |
+
"biomarker_hints": ChannelData(
|
| 372 |
+
name="biomarker_hints", channel_type="text",
|
| 373 |
+
description="Expert-graded presence of fundus-visible retinal biomarkers",
|
| 374 |
+
value="Hard Exudates: Present; Hemorrhage: Present; Microaneurysms: Present; Cotton Wool Spots: Not detected",
|
| 375 |
+
cost=100.0, tier="moderate",
|
| 376 |
+
),
|
| 377 |
+
"oct_scan": ChannelData(
|
| 378 |
+
name="oct_scan", channel_type="image",
|
| 379 |
+
description="OCT B-scan showing retinal cross-section",
|
| 380 |
+
value=_make_dummy_image(512, 128, (60, 60, 60)),
|
| 381 |
+
cost=300.0, tier="expensive",
|
| 382 |
+
),
|
| 383 |
+
"additional_oct": ChannelData(
|
| 384 |
+
name="additional_oct", channel_type="image",
|
| 385 |
+
description="Additional OCT B-scans from different retinal locations",
|
| 386 |
+
value=_make_dummy_image(512, 128, (50, 50, 55)),
|
| 387 |
+
cost=150.0, tier="very_expensive",
|
| 388 |
+
),
|
| 389 |
+
},
|
| 390 |
+
candidates=[
|
| 391 |
+
"Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
|
| 392 |
+
"Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Fluid Srf, Hard Exudates, Hemorrhage, Microaneurysms",
|
| 393 |
+
"Present biomarkers: Hard Exudates, Hemorrhage, Microaneurysms",
|
| 394 |
+
"Present biomarkers: Dril, Ez Disruption, Fluid Irf, Shrm",
|
| 395 |
+
"No biomarkers detected",
|
| 396 |
+
],
|
| 397 |
+
ground_truth="Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
|
| 398 |
+
ground_truth_rank=0,
|
| 399 |
+
),
|
| 400 |
+
},
|
| 401 |
+
"NEJM: Cardiac Case": {
|
| 402 |
+
"description": (
|
| 403 |
+
"A 45-year-old woman presents with sudden onset chest pain and shortness "
|
| 404 |
+
"of breath. She recently completed a long international flight."
|
| 405 |
+
),
|
| 406 |
+
"case": lambda: MedicalCase(
|
| 407 |
+
case_id="demo_nejm_pe",
|
| 408 |
+
dataset="nejm",
|
| 409 |
+
initial_channels={
|
| 410 |
+
"demographics": ChannelData(
|
| 411 |
+
name="demographics", channel_type="text",
|
| 412 |
+
description="Patient age, sex, and ethnicity",
|
| 413 |
+
value="A 45-year-old woman", always_given=True, cost=0.0, tier="free",
|
| 414 |
+
),
|
| 415 |
+
"chief_complaint": ChannelData(
|
| 416 |
+
name="chief_complaint", channel_type="text",
|
| 417 |
+
description="Presenting symptoms and duration",
|
| 418 |
+
value="Sudden onset chest pain and shortness of breath, started 2 hours ago after returning from a 14-hour international flight.",
|
| 419 |
+
always_given=True, cost=0.0, tier="free",
|
| 420 |
+
),
|
| 421 |
+
"medical_history": ChannelData(
|
| 422 |
+
name="medical_history", channel_type="text",
|
| 423 |
+
description="Past medical conditions, medications, family and social history",
|
| 424 |
+
value="On oral contraceptives for 5 years. BMI 32. No prior VTE. Mother had DVT at age 50.",
|
| 425 |
+
always_given=True, cost=0.0, tier="free",
|
| 426 |
+
),
|
| 427 |
+
},
|
| 428 |
+
requestable_channels={
|
| 429 |
+
"exam_findings": ChannelData(
|
| 430 |
+
name="exam_findings", channel_type="text",
|
| 431 |
+
description="Physical examination results and observations",
|
| 432 |
+
value="Tachycardic (HR 110), tachypneic (RR 24), SpO2 89% on room air. Right calf swollen and tender. JVP elevated. Loud P2 on cardiac auscultation.",
|
| 433 |
+
cost=75.0, tier="cheap",
|
| 434 |
+
),
|
| 435 |
+
"investigations": ChannelData(
|
| 436 |
+
name="investigations", channel_type="text",
|
| 437 |
+
description="Laboratory values, imaging results, and test outcomes",
|
| 438 |
+
value="D-dimer: 4200 ng/mL (markedly elevated). Troponin I: 0.15 ng/mL (mildly elevated). ABG: pH 7.48, PaO2 62 mmHg, PaCO2 28 mmHg. ECG: S1Q3T3 pattern, right axis deviation. CT pulmonary angiography: bilateral pulmonary emboli with right heart strain.",
|
| 439 |
+
cost=250.0, tier="moderate",
|
| 440 |
+
),
|
| 441 |
+
"image": ChannelData(
|
| 442 |
+
name="image", channel_type="image",
|
| 443 |
+
description="CT Pulmonary Angiography image",
|
| 444 |
+
value=_make_dummy_image(300, 300, (100, 100, 120)),
|
| 445 |
+
cost=800.0, tier="expensive",
|
| 446 |
+
),
|
| 447 |
+
},
|
| 448 |
+
candidates=[
|
| 449 |
+
"A. Pulmonary embolism",
|
| 450 |
+
"B. Acute myocardial infarction",
|
| 451 |
+
"C. Tension pneumothorax",
|
| 452 |
+
"D. Aortic dissection",
|
| 453 |
+
"E. Acute pericarditis",
|
| 454 |
+
],
|
| 455 |
+
ground_truth="A. Pulmonary embolism",
|
| 456 |
+
ground_truth_rank=0,
|
| 457 |
+
),
|
| 458 |
+
},
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# ============================================================
|
| 463 |
+
# Custom Case Builder
|
| 464 |
+
# ============================================================
|
| 465 |
+
|
| 466 |
+
def build_custom_case(
    scenario_text: str,
    candidates_text: str,
    channel_1_name: str, channel_1_type: str, channel_1_value: str,
    channel_2_name: str, channel_2_type: str, channel_2_value: str,
    channel_3_name: str, channel_3_type: str, channel_3_value: str,
    uploaded_image=None,
) -> MedicalCase:
    """Build a MedicalCase from user-provided custom inputs.

    The scenario (and optional uploaded image) become always-given channels;
    each non-empty (name, value) pair becomes a requestable channel. A channel
    config for the synthetic "custom" dataset is registered on the global
    config so the agent can look costs/tiers up.
    """
    candidates = [line.strip() for line in candidates_text.strip().split("\n") if line.strip()]
    if not candidates:
        candidates = ["Diagnosis A", "Diagnosis B", "Diagnosis C"]

    initial_channels = {
        "clinical_scenario": ChannelData(
            name="clinical_scenario", channel_type="text",
            description="The presenting clinical scenario",
            value=scenario_text,
            always_given=True, cost=0.0, tier="free",
        ),
    }

    if uploaded_image is not None:
        initial_channels["uploaded_image"] = ChannelData(
            name="uploaded_image", channel_type="image",
            description="Uploaded medical image",
            value=encode_pil_image_to_base64(Image.fromarray(uploaded_image)),
            always_given=True, cost=0.0, tier="free",
        )

    requestable = {}
    channel_specs = (
        (channel_1_name, channel_1_type, channel_1_value),
        (channel_2_name, channel_2_type, channel_2_value),
        (channel_3_name, channel_3_type, channel_3_value),
    )
    for raw_name, raw_type, raw_value in channel_specs:
        label = raw_name.strip()
        content = raw_value.strip()
        if not (label and content):
            continue  # skip half-filled channel rows
        key = label.lower().replace(" ", "_")
        requestable[key] = ChannelData(
            name=key, channel_type=raw_type.lower(),
            description=label,
            value=content,
            cost=100.0, tier="moderate",
        )

    # Register channel config so the agent can look it up
    custom_config = {}
    for ch_key, ch in initial_channels.items():
        custom_config[ch_key] = {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": True,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": 0,
        }
    for position, (ch_key, ch) in enumerate(requestable.items(), start=1):
        custom_config[ch_key] = {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": False,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": position,
        }
    config.CHANNEL_CONFIGS["custom"] = custom_config

    return MedicalCase(
        case_id="custom_case",
        dataset="custom",
        initial_channels=initial_channels,
        requestable_channels=requestable,
        candidates=candidates,
        ground_truth=candidates[0] if candidates else "",
        ground_truth_rank=0,
    )
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
# ============================================================
|
| 547 |
+
# Formatting Helpers
|
| 548 |
+
# ============================================================
|
| 549 |
+
|
| 550 |
+
def format_step_markdown(step_idx: int, step: "AcquisitionStep", cumulative_cost: float) -> str:
    """Format a single acquisition step as rich markdown.

    Args:
        step_idx: Zero-based step index (rendered 1-based in the heading).
        step: The acquisition/commit step to render (annotation is a string
            forward-reference so this pure-formatting helper does not require
            the agent module to be imported first).
        cumulative_cost: Running acquisition cost up to and including this step.

    Returns:
        Markdown with the step heading, reasoning, differential, expected
        impact, and the information metrics footer.
    """
    lines = []

    if step.committed:
        lines.append(f"### Step {step_idx + 1}: COMMITTED TO DIAGNOSIS")
        lines.append("")
        lines.append(f"**Reasoning:** {step.reasoning}")
        lines.append("")
        if step.differential:
            lines.append("**Final Ranking:**")
            for d in step.differential:
                conf = d.get("confidence", 0)
                bar = render_bar(conf)
                evidence = d.get("key_evidence", "")
                lines.append(f"- **{d['name']}** — {conf:.1%} {bar}")
                if evidence:
                    lines.append(f"  - *Evidence:* {evidence}")
    else:
        lines.append(f"### Step {step_idx + 1}: Requested `{step.requested_channel}`")
        lines.append("")
        lines.append(f"**Reasoning:** {step.reasoning}")
        lines.append("")

        if step.differential:
            lines.append("**Current Differential:**")
            for d in step.differential:
                conf = d.get("confidence", 0)
                bar = render_bar(conf)
                lines.append(f"- {d['name']} — {conf:.1%} {bar}")

        if step.expected_impact:
            lines.append("")
            lines.append("**Expected Impact:**")
            pos = step.expected_impact.get("if_positive", "N/A")
            neg = step.expected_impact.get("if_negative", "N/A")
            lines.append(f"- If positive/abnormal: *{pos}*")
            lines.append(f"- If negative/normal: *{neg}*")

    lines.append("")
    lines.append("**Information Metrics:**")
    lines.append(f"- Entropy: **{step.entropy:.3f}** bits")
    # BUGFIX: compare against None instead of truthiness — a measured gain or
    # KL of exactly 0.0 is a legitimate value and used to be silently hidden.
    if step.information_gain is not None:
        lines.append(f"- Information Gain: **{step.information_gain:.3f}** bits")
    if step.kl_divergence is not None:
        lines.append(f"- KL Divergence: **{step.kl_divergence:.3f}** bits")
    lines.append(f"- Latency: {step.latency_ms:.0f}ms")
    lines.append(f"- Cumulative Cost: ${cumulative_cost:,.0f}")
    lines.append("")
    lines.append("---")

    return "\n".join(lines)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def render_bar(value: float, width: int = 20) -> str:
    """Render *value* (a 0..1 fraction) as a fixed-width unicode bar in backticks."""
    n_filled = int(value * width)
    cells = "\u2588" * n_filled + "\u2591" * (width - n_filled)
    return f"`{cells}`"
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def format_entropy_table(trajectory: BeliefTrajectory) -> str:
    """Format entropy trajectory as a markdown table.

    Each row shows the channel acquired at that step, the resulting entropy,
    the per-step information gain, and the running cumulative gain; the table
    is followed by the trajectory's efficiency summary lines.
    """
    if not trajectory or not trajectory.states:
        return "*No belief trajectory recorded.*"

    rows = [
        "| Step | Channel | Entropy (bits) | Info Gain | Cumulative IG |",
        "|------|---------|---------------|-----------|---------------|",
    ]

    running_gain = 0.0
    prev_entropy = None
    for idx, state in enumerate(trajectory.states):
        label = state.channel_acquired or "initial/commit"
        # First row has no predecessor, so its gain is zero by definition.
        step_gain = 0.0 if prev_entropy is None else prev_entropy - state.entropy
        running_gain += step_gain
        rows.append(
            f"| {idx} | {label} | {state.entropy:.3f} | "
            f"{step_gain:+.3f} | {running_gain:.3f} |"
        )
        prev_entropy = state.entropy

    rows.append("")
    rows.append(f"**Information Efficiency:** {trajectory.information_efficiency:.1%}")
    rows.append(f"**Total Information Gain:** {trajectory.total_information_gain:.3f} bits")

    return "\n".join(rows)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def format_summary(result: AgentResult, case: MedicalCase) -> str:
    """Format the overall result summary (top diagnosis vs. ground truth,
    acquisition order, costs, latency, and token usage) as markdown."""
    parts = ["## Summary", ""]

    ranking = result.final_ranking
    if ranking:
        top = ranking[0]
        predicted = top["name"].strip().lower()
        truth = case.ground_truth.strip().lower()
        # Fuzzy match: handle "Pulmonary embolism" vs "A. Pulmonary embolism"
        is_hit = predicted == truth or predicted in truth or truth in predicted
        parts.append(f"**Top Diagnosis:** {top['name']} ({top['confidence']:.1%})")
        parts.append(f"**Ground Truth:** {case.ground_truth}")
        parts.append(f"**Result:** {'correct' if is_hit else 'incorrect'}")
    else:
        parts.append("*No diagnosis produced.*")

    parts.append("")
    parts.append(f"**Channels Acquired:** {len(result.acquired_channels)} / {len(case.requestable_channels)}")
    if result.acquired_channels:
        parts.append(f"**Acquisition Order:** {' -> '.join(result.acquired_channels)}")
    parts.append(f"**Committed Early:** {'Yes' if result.committed_early else 'No'}")
    parts.append(f"**Total Acquisition Cost:** ${result.acquisition_cost:,.0f}")
    parts.append(f"**Total Case Cost:** ${result.total_case_cost:,.0f}")
    parts.append(f"**Total Latency:** {result.total_latency_ms:,.0f}ms")
    parts.append(f"**Tokens Used:** {result.total_input_tokens:,} in / {result.total_output_tokens:,} out")

    return "\n".join(parts)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
# ============================================================
|
| 670 |
+
# Main Agent Runner (for Gradio)
|
| 671 |
+
# ============================================================
|
| 672 |
+
|
| 673 |
+
def run_agent_on_case(
    case: MedicalCase,
    backend: str,
    context_mode: str,
) -> tuple[str, str, str]:
    """
    Run the agent on a case and return formatted markdown outputs.

    Args:
        case: The case with initial and requestable channels populated.
        backend: A detected API backend name, or the literal string
            "simulated (no API key)" to use the offline demo trace.
        context_mode: Agent context strategy; the literal "adaptive" is
            translated to None so the agent chooses its own mode.

    Returns: (steps_markdown, entropy_table, summary_markdown).
        On client-creation or agent failure, the first element carries the
        error message and the other two are empty strings.
    """
    if backend == "simulated (no API key)":
        # Offline path — no API client is created at all.
        result = _simulate_agent_on_case(case)
        model_name = "simulated"
    else:
        try:
            client = create_client(backend)
        except Exception as e:
            # A missing or placeholder API key is the common failure here.
            return (
                f"**Error creating {backend} client:** {e}\n\n"
                "Make sure your API key is set in `.env` or environment variables. "
                "Or select **simulated (no API key)** to see a demo trace.",
                "", "",
            )
        agent = ActiveMedAgent(
            client,
            prompt_variant="A",
            budget=None,  # NO BUDGET CONSTRAINT
            context_mode=context_mode if context_mode != "adaptive" else None,
        )
        try:
            result = agent.diagnose(case)
        except Exception as e:
            return f"**Error running agent:** {e}", "", ""
        model_name = client.model

    # Format step-by-step reasoning
    steps_parts = []
    steps_parts.append("# Agent Reasoning Trace\n")
    steps_parts.append(f"**Case:** {case.case_id} | **Dataset:** {case.dataset} | **Backend:** {model_name}\n")
    steps_parts.append(f"**Candidates:** {', '.join(case.candidates)}\n")

    # Show the zero-cost information the agent started from.
    initial_info = format_acquired_info(case.get_text_context([]))
    steps_parts.append(f"**Initial Information:**\n{initial_info}\n")
    steps_parts.append("---\n")

    # Running cost starts at the cost of the always-given channels and grows
    # with each channel the agent actually requested.
    cumulative_cost = case.get_initial_cost()
    for i, step in enumerate(result.steps):
        if step.requested_channel:
            cumulative_cost += case.get_channel_cost(step.requested_channel)
        steps_parts.append(format_step_markdown(i, step, cumulative_cost))

    steps_md = "\n".join(steps_parts)

    # Format entropy trajectory
    entropy_md = ""
    if result.belief_trajectory:
        entropy_md = format_entropy_table(result.belief_trajectory)

    # Format summary
    summary_md = format_summary(result, case)

    return steps_md, entropy_md, summary_md
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
# ============================================================
|
| 738 |
+
# Gradio Event Handlers
|
| 739 |
+
# ============================================================
|
| 740 |
+
|
| 741 |
+
def on_demo_case_selected(case_name: str) -> tuple[str, str]:
    """When a demo case is selected, show its description and candidates."""
    if case_name not in DEMO_CASES:
        return "", ""

    info = DEMO_CASES[case_name]
    case = info["case"]()  # factory call builds a fresh MedicalCase
    cands = "\n".join(case.candidates)

    # One bullet per requestable channel: name, tier, cost, description.
    channel_lines = [
        f"- **{name}** ({ch.tier}, ${ch.cost:,.0f}): {ch.description}"
        for name, ch in case.requestable_channels.items()
    ]
    desc_md = (
        f"{info['description']}\n\n**Available channels to acquire:**\n"
        + "\n".join(channel_lines)
    )
    return desc_md, cands
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def run_demo_case(case_name: str, backend: str, context_mode: str):
    """Run agent on a selected demo case."""
    if case_name in DEMO_CASES:
        # Instantiate the case via its factory, then delegate to the
        # shared runner used by both tabs.
        case = DEMO_CASES[case_name]["case"]()
        return run_agent_on_case(case, backend, context_mode)
    return "Please select a demo case.", "", ""
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def run_custom_case(
    scenario: str, candidates: str,
    ch1_name: str, ch1_type: str, ch1_value: str,
    ch2_name: str, ch2_type: str, ch2_value: str,
    ch3_name: str, ch3_type: str, ch3_value: str,
    uploaded_image,
    backend: str, context_mode: str,
):
    """Run agent on a custom user-defined case.

    Parameters mirror the Gradio inputs wired up in create_app(): a
    free-text scenario, newline-separated candidate diagnoses, three
    (name, type, content) channel triples, an optional uploaded image
    (numpy array from gr.Image, or None), and the shared backend /
    context-mode dropdown values.

    Returns the (steps_md, entropy_md, summary_md) triple consumed by
    the three output Markdown components.
    """
    # Scenario is the only required field; channels and image may be blank.
    if not scenario.strip():
        return "Please enter a clinical scenario.", "", ""

    case = build_custom_case(
        scenario, candidates,
        ch1_name, ch1_type, ch1_value,
        ch2_name, ch2_type, ch2_value,
        ch3_name, ch3_type, ch3_value,
        uploaded_image,
    )
    return run_agent_on_case(case, backend, context_mode)
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
# ============================================================
|
| 791 |
+
# Gradio UI
|
| 792 |
+
# ============================================================
|
| 793 |
+
|
| 794 |
+
def create_app():
    """Build and return the (unlaunched) Gradio Blocks UI.

    Three tabs: pre-built demo cases, a custom case builder, and a
    static explainer. The backend / context-mode dropdowns are shared
    by both runnable tabs. The caller is responsible for app.launch().
    """
    with gr.Blocks(
        title="ActiveMedAgent Interactive Demo",
        # Fix: theme and css are gr.Blocks() constructor options. They were
        # previously passed to Blocks.launch() in main(), which does not
        # accept them and raises a TypeError.
        theme=gr.themes.Soft(),
        css="""
        .reasoning-box { font-size: 14px; }
        .header-text { text-align: center; margin-bottom: 10px; }
        """,
    ) as app:
        gr.Markdown(
            """
            # ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis

            **Interactive Demo** — Watch the agent reason step-by-step, acquire information channels,
            and track entropy reduction. **No budget constraint** — the agent decides when to stop.
            """,
            elem_classes="header-text",
        )

        # Build backend choices: simulation always available, real backends if keys exist
        backend_choices = ["simulated (no API key)"] + AVAILABLE_BACKENDS
        default_backend = AVAILABLE_BACKENDS[0] if AVAILABLE_BACKENDS else "simulated (no API key)"

        with gr.Row():
            backend = gr.Dropdown(
                choices=backend_choices,
                value=default_backend,
                label="VLM Backend",
                info="Select 'simulated' to see the demo without API keys",
                scale=1,
            )
            context_mode = gr.Dropdown(
                choices=["adaptive", "full", "condensed"],
                value="adaptive",
                label="Context Mode",
                info="How the agent manages conversation history",
                scale=1,
            )

        with gr.Tabs():
            # ---- Tab 1: Demo Cases ----
            with gr.TabItem("Demo Cases"):
                gr.Markdown("Select a pre-built clinical scenario and run the agent.")
                with gr.Row():
                    case_selector = gr.Dropdown(
                        choices=list(DEMO_CASES.keys()),
                        label="Select Case",
                        scale=2,
                    )
                    run_demo_btn = gr.Button("Run Agent", variant="primary", scale=1)

                case_description = gr.Markdown(label="Case Description")
                case_candidates = gr.Textbox(label="Candidate Diagnoses", lines=3, interactive=False)

                # Selecting a case previews its description and candidates
                # without running the agent.
                case_selector.change(
                    fn=on_demo_case_selected,
                    inputs=[case_selector],
                    outputs=[case_description, case_candidates],
                )

                with gr.Row():
                    with gr.Column(scale=2):
                        demo_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        demo_summary = gr.Markdown(label="Summary")
                        demo_entropy = gr.Markdown(label="Entropy Trajectory")

                run_demo_btn.click(
                    fn=run_demo_case,
                    inputs=[case_selector, backend, context_mode],
                    outputs=[demo_steps, demo_entropy, demo_summary],
                )

            # ---- Tab 2: Custom Case ----
            with gr.TabItem("Custom Case"):
                gr.Markdown(
                    "Define your own clinical scenario, candidate diagnoses, "
                    "and information channels the agent can request."
                )

                with gr.Row():
                    with gr.Column():
                        custom_scenario = gr.Textbox(
                            label="Clinical Scenario",
                            placeholder="A 35-year-old woman presents with...",
                            lines=4,
                        )
                        custom_candidates = gr.Textbox(
                            label="Candidate Diagnoses (one per line)",
                            placeholder="A. Diagnosis one\nB. Diagnosis two\nC. Diagnosis three",
                            lines=5,
                        )
                        custom_image = gr.Image(
                            label="Upload Medical Image (optional)",
                            type="numpy",
                        )

                    with gr.Column():
                        gr.Markdown("### Requestable Information Channels")
                        gr.Markdown("Define up to 3 channels the agent can request.")

                        with gr.Group():
                            gr.Markdown("**Channel 1:**")
                            ch1_name = gr.Textbox(label="Name", value="Exam Findings", scale=1)
                            ch1_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch1_value = gr.Textbox(label="Content (what the agent receives)", lines=2,
                                                   placeholder="Physical exam: temperature 38.5C, ...")

                        with gr.Group():
                            gr.Markdown("**Channel 2:**")
                            ch2_name = gr.Textbox(label="Name", value="Lab Results", scale=1)
                            ch2_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch2_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="WBC 12,000, CRP elevated, ...")

                        with gr.Group():
                            gr.Markdown("**Channel 3:**")
                            ch3_name = gr.Textbox(label="Name", value="Imaging", scale=1)
                            ch3_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch3_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="CT scan shows...")

                run_custom_btn = gr.Button("Run Agent on Custom Case", variant="primary")

                with gr.Row():
                    with gr.Column(scale=2):
                        custom_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        custom_summary = gr.Markdown(label="Summary")
                        custom_entropy = gr.Markdown(label="Entropy Trajectory")

                run_custom_btn.click(
                    fn=run_custom_case,
                    inputs=[
                        custom_scenario, custom_candidates,
                        ch1_name, ch1_type, ch1_value,
                        ch2_name, ch2_type, ch2_value,
                        ch3_name, ch3_type, ch3_value,
                        custom_image,
                        backend, context_mode,
                    ],
                    outputs=[custom_steps, custom_entropy, custom_summary],
                )

            # ---- Tab 3: How It Works ----
            with gr.TabItem("How It Works"):
                gr.Markdown("""
            ## ActiveMedAgent Architecture

            ### Tool-Use Acquisition Loop
            The agent uses native VLM function calling (not regex parsing) with two tools:
            1. **`request_information`** — Request one data channel, providing reasoning, current differential with calibrated probabilities, and expected impact
            2. **`commit_diagnosis`** — Submit final ranked diagnosis when confident

            ### No Budget Constraint
            The agent acquires as many channels as it needs (0 to all). It stops when:
            - It calls `commit_diagnosis` (self-determined confidence)
            - Information-theoretic stopping criteria trigger (convergence, confirmed dominance, or diminishing returns)
            - All channels are exhausted

            ### Information-Theoretic Metrics
            At each step, the system tracks:
            - **Shannon Entropy** H(p) — diagnostic uncertainty in bits
            - **Information Gain** — entropy reduction from each acquisition
            - **KL Divergence** — how much the belief distribution shifted
            - **Expected Information Gain (EIG)** — predicted value of the next channel
            - **Value of Information (VoI)** — whether continuing to acquire is worthwhile

            ### Context Management
            - **Full Mode**: Multi-turn conversation with complete history (for capable models)
            - **Condensed Mode**: Fresh single-turn call each step with compressed state log (for weaker models)
            - **Adaptive**: Auto-selects based on model capability

            ### Stopping Criteria
            1. **Convergence**: Last acquisition < 0.05 bits of IG
            2. **Confirmed Dominance**: Top diagnosis > 90% probability with > 40% gap (after 2+ acquisitions)
            3. **Diminishing Returns**: Last 2 acquisitions both < 0.1 bits IG
            """)

    return app
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
# ============================================================
|
| 977 |
+
# Entry Point
|
| 978 |
+
# ============================================================
|
| 979 |
+
|
| 980 |
+
def main():
    """CLI entry point: parse arguments and launch the Gradio demo."""
    parser = argparse.ArgumentParser(description="ActiveMedAgent Interactive Demo")
    parser.add_argument("--port", type=int, default=7860, help="Port to serve on")
    # NOTE(review): --backend is parsed but not consumed by create_app();
    # kept for CLI backward compatibility.
    parser.add_argument("--backend", default="openai", choices=["openai", "anthropic", "together"])
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    args = parser.parse_args()

    app = create_app()
    # Bug fix: Blocks.launch() does not accept `theme` or `css` keyword
    # arguments — those are gr.Blocks() constructor options — so passing
    # them here raised a TypeError. Styling must be configured where the
    # Blocks object is built.
    app.launch(
        server_port=args.port,
        share=args.share,
    )
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
# Script entry point: launch the Gradio demo when run directly.
if __name__ == "__main__":
    main()
|
baselines.py
ADDED
|
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Additional Baselines for ACL/EMNLP Submission.
|
| 3 |
+
|
| 4 |
+
Five baselines that answer: "Does active sequential acquisition actually
|
| 5 |
+
help over simpler strategies?"
|
| 6 |
+
|
| 7 |
+
1. AllAtOnce: Give the VLM all text channels upfront (no sequential reasoning)
|
| 8 |
+
2. RandomOrder: Acquire channels in random order (same budget as active)
|
| 9 |
+
3. ClinicalGuidelineOrder: Follow standard clinical workflow ordering
|
| 10 |
+
4. ReactBaseline: Free-form ReAct-style reasoning (no structured tool calls)
|
| 11 |
+
5. CoTSinglePass: Chain-of-thought with all info in one shot
|
| 12 |
+
|
| 13 |
+
All baselines use the same VLM and produce AgentResult objects for
|
| 14 |
+
direct comparison with the active agent.
|
| 15 |
+
"""
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import time
|
| 21 |
+
from dataclasses import field
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
import config
|
| 26 |
+
from api_client import BaseVLMClient, VLMResponse
|
| 27 |
+
from agent import (
|
| 28 |
+
ActiveMedAgent, AgentResult, AcquisitionStep,
|
| 29 |
+
SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_FINAL,
|
| 30 |
+
)
|
| 31 |
+
from datasets.base import MedicalCase, ChannelData
|
| 32 |
+
from tools import ToolCall, constrain_tools_for_step
|
| 33 |
+
from information_gain import BeliefState, BeliefTrajectory, compute_entropy
|
| 34 |
+
from prompts import format_acquired_info
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ================================================================
|
| 40 |
+
# Clinical Guideline Orderings
|
| 41 |
+
# ================================================================
|
| 42 |
+
|
| 43 |
+
# Conventional clinical workflow orderings (history -> exam -> labs ->
# imaging), keyed by dataset name. Consumed by ClinicalGuidelineBaseline,
# which keeps only the channels a given case actually exposes and appends
# any remaining case channels after the guideline ones.
CLINICAL_GUIDELINE_ORDER = {
    "nejm": [
        "demographics",
        "chief_complaint",
        "medical_history",
        "exam_findings",
        "investigations",
        "image",
    ],
    "midas": [
        "patient_demographics",
        "lesion_metadata",
        "clinical_15cm",
        "dermoscopy",
    ],
    "olives": [
        "clinical_measurements",
        "biomarker_hints",
        "oct_scan",
        "additional_oct",
    ],
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ================================================================
|
| 68 |
+
# Baseline 1: All-At-Once
|
| 69 |
+
# ================================================================
|
| 70 |
+
|
| 71 |
+
class AllAtOnceBaseline:
    """
    Give the VLM all available text/image channels at once.

    Tests whether sequential reasoning matters or if the VLM can
    handle everything in a single pass with all evidence.
    Different from Oracle: Oracle uses the experiment evaluation
    framework; this uses the same prompt structure as the active agent.
    """

    def __init__(self, client: BaseVLMClient, prompt_variant: str = "A"):
        self.client = client
        self.prompt_variant = prompt_variant

    def diagnose(self, case: MedicalCase) -> AgentResult:
        """Single-pass diagnosis with every channel pre-acquired.

        Builds one prompt containing all text context and images, forces
        the commit_diagnosis tool, and returns an AgentResult comparable
        to the active agent's output.
        """
        all_channels = list(case.requestable_channels.keys())

        result = AgentResult(
            case_id=case.case_id,
            dataset=case.dataset,
            prompt_variant=self.prompt_variant,
            backend=self.client.model,
            budget=len(all_channels),
            acquired_channels=all_channels,
        )

        images = case.get_all_images_up_to(all_channels)
        text_context = case.get_text_context(all_channels)
        acquired_str = format_acquired_info(text_context)
        candidates_str = "\n".join(
            f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
        )

        system_prompt = (
            "You are a medical diagnostic agent. You are given ALL available "
            "clinical information at once. Analyze everything and provide your "
            "final ranked diagnosis.\n\n"
            "You MUST use the commit_diagnosis tool to submit your answer.\n"
            "Include ALL candidate diagnoses with calibrated probabilities "
            "summing to 1.0 and key_evidence for each."
        )

        user_text = (
            f"All available clinical information:\n{acquired_str}\n\n"
            f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
            f"Analyze all information and submit your final diagnosis "
            f"using commit_diagnosis."
        )

        # budget_remaining=0 constrains the tool set so only
        # commit_diagnosis is offered.
        commit_tools = constrain_tools_for_step(budget_remaining=0)

        # Latency is taken from response.latency_ms; the previous
        # `t0 = time.time()` here was dead code and has been removed.
        response = self.client.call_with_retry(
            system_prompt=system_prompt,
            user_text=user_text,
            images=images,
            temperature=config.TEMPERATURE,
            max_tokens=config.MAX_TOKENS,
            tools=commit_tools,
        )

        result.total_latency_ms = response.latency_ms
        result.total_input_tokens = response.input_tokens
        result.total_output_tokens = response.output_tokens

        if response.tool_call and response.tool_call.tool_name == "commit_diagnosis":
            args = response.tool_call.arguments
            ranked = args.get("ranked_diagnoses", [])
            ranking = []
            for i, entry in enumerate(ranked):
                ranking.append({
                    "name": entry.get("name", ""),
                    "confidence": entry.get("confidence", 0.0),
                    "rank": i + 1,
                    "key_evidence": entry.get("key_evidence", ""),
                })
            # Re-rank by confidence in case the model's list order
            # disagrees with its stated probabilities.
            ranking.sort(key=lambda x: x["confidence"], reverse=True)
            for i, entry in enumerate(ranking):
                entry["rank"] = i + 1
            result.final_ranking = ranking
        else:
            # Fallback: model ignored the tool; parse a ranking out of
            # free text.
            result.final_ranking = _extract_ranking_from_text(
                response.text, case.candidates
            )

        result.final_raw_response = response.text
        result.acquisition_cost = case.get_acquisition_cost(all_channels)
        result.total_case_cost = case.get_total_cost(all_channels)
        return result
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ================================================================
|
| 163 |
+
# Baseline 2: Random Order Acquisition
|
| 164 |
+
# ================================================================
|
| 165 |
+
|
| 166 |
+
class RandomOrderBaseline:
    """
    Acquire channels in random order, then diagnose.

    Uses the same active agent architecture but overrides channel
    selection with random choice. This isolates the value of
    strategic ordering from the value of having more information.
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = None,
        n_trials: int = 3,
        seed: int = 42,
    ):
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget        # max channels to acquire (None = all)
        self.n_trials = n_trials    # random orderings tried per case
        self.seed = seed            # base seed for reproducibility

    @staticmethod
    def _case_seed_offset(case_id: str) -> int:
        """Deterministic per-case seed offset.

        Bug fix: the previous implementation used the built-in hash(),
        which is salted per process (PYTHONHASHSEED), so "seeded" runs
        were not reproducible across runs. CRC32 is stable across
        processes and platforms.
        """
        import zlib
        return zlib.crc32(case_id.encode("utf-8"))

    def _evaluate_order(
        self, case: MedicalCase, acquired: list, budget: int
    ) -> AgentResult:
        """Diagnose with a fixed acquisition set and package an AgentResult.

        Shared by diagnose() and diagnose_single_random(), which
        previously duplicated this logic.
        """
        # budget=0 because the channels are pre-acquired; the agent only
        # produces a diagnosis at this state.
        agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=0,
        )
        result = AgentResult(
            case_id=case.case_id,
            dataset=case.dataset,
            prompt_variant=self.prompt_variant,
            backend=self.client.model,
            budget=budget,
            acquired_channels=acquired,
        )
        final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
        result.final_ranking = final_ranking
        result.final_raw_response = resp.text
        result.total_latency_ms = resp.latency_ms
        result.total_input_tokens = resp.input_tokens
        result.total_output_tokens = resp.output_tokens
        result.acquisition_cost = case.get_acquisition_cost(acquired)
        result.total_case_cost = case.get_total_cost(acquired)
        return result

    def diagnose(self, case: MedicalCase) -> AgentResult:
        """Run with random order. If n_trials > 1, returns best trial."""
        rng = random.Random(self.seed + self._case_seed_offset(case.case_id))
        requestable = list(case.requestable_channels.keys())
        max_acq = self.budget if self.budget is not None else len(requestable)

        best_result = None
        best_conf = -1  # top-1 confidence of the best trial so far

        for _trial in range(self.n_trials):
            order = list(requestable)
            rng.shuffle(order)
            result = self._evaluate_order(case, order[:max_acq], max_acq)

            if self.n_trials == 1:
                return result

            # Pick the trial with highest top-1 confidence (proxy for quality)
            ranking = result.final_ranking
            top_conf = ranking[0]["confidence"] if ranking else 0
            if top_conf > best_conf:
                best_conf = top_conf
                best_result = result

        return best_result

    def diagnose_single_random(
        self, case: MedicalCase, seed: int = None
    ) -> AgentResult:
        """Single random trial (for aggregate statistics)."""
        # Bug fix: `seed or self.seed` treated seed=0 as "not provided";
        # test explicitly against None instead.
        rng = random.Random(self.seed if seed is None else seed)
        requestable = list(case.requestable_channels.keys())
        max_acq = self.budget if self.budget is not None else len(requestable)
        order = list(requestable)
        rng.shuffle(order)
        return self._evaluate_order(case, order[:max_acq], max_acq)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ================================================================
|
| 270 |
+
# Baseline 3: Clinical Guideline Order
|
| 271 |
+
# ================================================================
|
| 272 |
+
|
| 273 |
+
class ClinicalGuidelineBaseline:
    """
    Acquire channels in standard clinical workflow order.

    Tests whether the VLM's learned ordering improves over the
    conventional clinical approach (history -> exam -> labs -> imaging).
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = None,
    ):
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget

    def diagnose(self, case: MedicalCase) -> AgentResult:
        """Acquire channels in guideline order (up to budget) and diagnose."""
        preferred = CLINICAL_GUIDELINE_ORDER.get(case.dataset, [])
        case_channels = list(case.requestable_channels.keys())
        present = set(case_channels)

        # Guideline channels first (only those this case exposes), then
        # any remaining case channels in their native order.
        ordering = [name for name in preferred if name in present]
        ordering += [name for name in case_channels if name not in ordering]

        limit = len(ordering) if self.budget is None else self.budget
        acquired = ordering[:limit]

        # budget=0: channels are pre-acquired, the agent only diagnoses.
        scorer = ActiveMedAgent(
            self.client, self.prompt_variant, budget=0,
        )
        outcome = AgentResult(
            case_id=case.case_id,
            dataset=case.dataset,
            prompt_variant=self.prompt_variant,
            backend=self.client.model,
            budget=limit,
            acquired_channels=acquired,
        )

        ranking, response = scorer.get_diagnosis_at_state(case, acquired)
        outcome.final_ranking = ranking
        outcome.final_raw_response = response.text
        outcome.total_latency_ms = response.latency_ms
        outcome.total_input_tokens = response.input_tokens
        outcome.total_output_tokens = response.output_tokens
        outcome.acquisition_cost = case.get_acquisition_cost(acquired)
        outcome.total_case_cost = case.get_total_cost(acquired)
        return outcome
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ================================================================
|
| 329 |
+
# Baseline 4: ReAct-Style Free-Form Reasoning
|
| 330 |
+
# ================================================================
|
| 331 |
+
|
| 332 |
+
class ReactBaseline:
    """
    ReAct-style baseline: the VLM reasons in free text and requests
    channels via text (not structured tool calls).

    Tests whether the structured tool-use architecture improves over
    free-form reasoning + regex parsing (the dominant approach in
    prior medical agent work).
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = None,
    ):
        # budget: optional cap on acquisition steps; None means up to
        # one step per requestable channel.
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget

    def diagnose(self, case: MedicalCase) -> AgentResult:
        """Run the Thought-Action-Observation loop on one case.

        Each turn, the model's free-text reply is scanned for a
        COMMIT[...] (final ranking) or REQUEST[...] (acquire a channel)
        action. Unparseable replies fall back to acquiring the first
        still-available channel. If the loop ends without a usable
        commit, a zero-budget ActiveMedAgent produces the final ranking
        from the acquired state.
        """
        max_steps = len(case.requestable_names)
        if self.budget is not None:
            max_steps = min(max_steps, self.budget)

        result = AgentResult(
            case_id=case.case_id,
            dataset=case.dataset,
            prompt_variant=self.prompt_variant,
            backend=self.client.model,
            budget=max_steps,
        )

        acquired = []
        # NOTE(review): `dataset_channel_config` is never read below —
        # candidate for removal.
        dataset_channel_config = config.CHANNEL_CONFIGS.get(case.dataset, {})

        system_prompt = (
            "You are a medical diagnostic agent using a Thought-Action-Observation loop.\n\n"
            "At each step:\n"
            "1. THOUGHT: Reason about what you know and what you're uncertain about\n"
            "2. ACTION: Either REQUEST[channel_name] to get more info, or "
            "COMMIT[diagnosis1 > diagnosis2 > ...] to submit your final ranking\n"
            "3. You will receive an OBSERVATION with the requested data\n\n"
            "Be strategic about which information to request. Stop when additional "
            "information is unlikely to change your diagnosis.\n\n"
            "Format your response EXACTLY as:\n"
            "THOUGHT: ...\n"
            "ACTION: REQUEST[channel_name] or COMMIT[ranked diagnoses with probabilities]"
        )

        initial_context = format_acquired_info(case.get_text_context([]))
        candidates_str = "\n".join(
            f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
        )

        # Build channel descriptions
        channel_desc_lines = []
        for name, ch in case.requestable_channels.items():
            channel_desc_lines.append(
                f" - {name}: {ch.description} (cost: ${ch.cost:,.0f})"
            )
        channel_desc = "\n".join(channel_desc_lines)

        conversation_text = (
            f"Initial information:\n{initial_context}\n\n"
            f"Candidate diagnoses:\n{candidates_str}\n\n"
            f"Available channels:\n{channel_desc}\n"
        )

        images = case.get_initial_images()

        for step_idx in range(max_steps):
            available = [n for n in case.requestable_names if n not in acquired]
            if not available:
                break

            # The prompt is rebuilt from scratch each turn (no chat
            # history is carried); acquired text context is appended.
            user_text = conversation_text
            if acquired:
                acq_context = format_acquired_info(case.get_text_context(acquired))
                user_text += (
                    f"\n\nAcquired information so far:\n{acq_context}\n\n"
                    f"Remaining channels: {', '.join(available)}\n"
                )

            response = self.client.call_with_retry(
                system_prompt=system_prompt,
                user_text=user_text,
                images=images,
                temperature=config.TEMPERATURE,
                max_tokens=config.MAX_TOKENS,
            )

            result.total_latency_ms += response.latency_ms
            result.total_input_tokens += response.input_tokens
            result.total_output_tokens += response.output_tokens

            text = response.text

            # Parse COMMIT — checked before REQUEST, so a reply that
            # contains both is treated as a commit.
            commit_match = re.search(
                r"COMMIT\[(.+?)\]", text, re.DOTALL
            )
            if commit_match:
                result.committed_early = True
                result.final_ranking = self._parse_commit_text(
                    commit_match.group(1), case.candidates
                )
                result.final_raw_response = text

                step = AcquisitionStep(
                    step=step_idx,
                    tool_call=None,
                    requested_channel=None,
                    reasoning=_extract_thought(text),
                    differential=result.final_ranking,
                    committed=True,
                    raw_response=text,
                    latency_ms=response.latency_ms,
                )
                result.steps.append(step)
                break

            # Parse REQUEST
            request_match = re.search(
                r"REQUEST\[(\w+)\]", text, re.IGNORECASE
            )
            if request_match:
                requested = request_match.group(1).strip().lower()
                matched = _match_channel_name(requested, available)
                if matched is None:
                    # Unrecognized channel name: degrade to the first
                    # still-available channel.
                    matched = available[0]

                acquired.append(matched)
                result.acquired_channels.append(matched)

                # Add new images if the channel is an image
                ch = case.get_channel(matched)
                if ch and ch.channel_type == "image" and ch.value:
                    if isinstance(ch.value, list):
                        images.extend(ch.value)
                    else:
                        images.append(ch.value)

                step = AcquisitionStep(
                    step=step_idx,
                    tool_call=None,
                    requested_channel=matched,
                    reasoning=_extract_thought(text),
                    differential=[],
                    committed=False,
                    raw_response=text,
                    latency_ms=response.latency_ms,
                )
                result.steps.append(step)
            else:
                # No parseable action — fallback to first available
                matched = available[0]
                acquired.append(matched)
                result.acquired_channels.append(matched)

                step = AcquisitionStep(
                    step=step_idx,
                    tool_call=None,
                    requested_channel=matched,
                    reasoning=f"(unparseable response, fallback to {matched})",
                    differential=[],
                    committed=False,
                    raw_response=text,
                    latency_ms=response.latency_ms,
                )
                result.steps.append(step)

        # Final diagnosis if not committed (or committed with a ranking
        # that failed to parse into any entries).
        if not result.committed_early or not result.final_ranking:
            agent = ActiveMedAgent(self.client, self.prompt_variant, budget=0)
            final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
            result.final_ranking = final_ranking
            result.final_raw_response = resp.text
            result.total_latency_ms += resp.latency_ms
            result.total_input_tokens += resp.input_tokens
            result.total_output_tokens += resp.output_tokens

        result.acquired_channels = acquired
        result.acquisition_cost = case.get_acquisition_cost(acquired)
        result.total_case_cost = case.get_total_cost(acquired)
        return result

    def _parse_commit_text(
        self, commit_str: str, candidates: list[str]
    ) -> list[dict]:
        """Parse a COMMIT[...] string into a ranking.

        Splits on '>' separators and matches "Name (0.XX)" fragments;
        fragments whose number fails to parse get a position-decaying
        default confidence. If nothing parses, falls back to
        _extract_ranking_from_text. The result is sorted by confidence
        (descending) and ranks are renumbered 1..K.
        """
        ranking = []
        # Try "Diagnosis (0.XX)" pattern
        pattern = r"([^>,(]+?)\s*\(?([\d.]+)\)?"
        parts = re.split(r"\s*>\s*", commit_str)
        for i, part in enumerate(parts):
            match = re.match(pattern, part.strip())
            if match:
                name = match.group(1).strip()
                try:
                    conf = float(match.group(2))
                except (ValueError, IndexError):
                    # e.g. a bare "." — use a position-based default.
                    conf = max(0.1, 1.0 - i * 0.2)
                ranking.append({
                    "name": name,
                    "confidence": conf,
                    "rank": i + 1,
                })

        if not ranking:
            ranking = _extract_ranking_from_text(commit_str, candidates)

        ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
        for i, entry in enumerate(ranking):
            entry["rank"] = i + 1
        return ranking
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
# ================================================================
|
| 551 |
+
# Baseline 5: Chain-of-Thought Single Pass
|
| 552 |
+
# ================================================================
|
| 553 |
+
|
| 554 |
+
class CoTSinglePassBaseline:
    """
    Chain-of-thought baseline: every available channel is handed to the
    VLM up front and a single step-by-step pass produces the answer.

    No multi-turn loop, no tool use, no acquisition decisions — the
    entire budget is always spent and exactly one model call is made.
    """

    def __init__(self, client: BaseVLMClient, prompt_variant: str = "A"):
        self.client = client
        self.prompt_variant = prompt_variant

    def diagnose(self, case: MedicalCase) -> AgentResult:
        """Run one full-information CoT pass over *case*."""
        channel_names = list(case.requestable_channels.keys())

        # Gather all images and text context up front — everything is
        # treated as already acquired.
        imgs = case.get_all_images_up_to(channel_names)
        context_block = format_acquired_info(
            case.get_text_context(channel_names)
        )
        candidate_block = "\n".join(
            f" {idx + 1}. {cand}"
            for idx, cand in enumerate(case.candidates)
        )

        outcome = AgentResult(
            case_id=case.case_id,
            dataset=case.dataset,
            prompt_variant=self.prompt_variant,
            backend=self.client.model,
            budget=len(channel_names),
            acquired_channels=channel_names,
        )

        system_prompt = (
            "You are a medical diagnostic expert. Analyze the following "
            "clinical information and provide your diagnosis.\n\n"
            "Think step by step:\n"
            "1. Summarize the key findings\n"
            "2. Consider each candidate diagnosis\n"
            "3. Identify supporting and refuting evidence for each\n"
            "4. Rank all candidates with calibrated probabilities (0-1, sum to 1)\n\n"
            "Format your final answer as:\n"
            "RANKING:\n"
            "1. DiagnosisName (confidence: X.XX) - key evidence\n"
            "2. DiagnosisName (confidence: X.XX) - key evidence\n"
            "..."
        )
        user_text = (
            f"Clinical information:\n{context_block}\n\n"
            f"Candidate diagnoses:\n{candidate_block}\n\n"
            f"Think step by step and provide your ranked diagnosis."
        )

        reply = self.client.call_with_retry(
            system_prompt=system_prompt,
            user_text=user_text,
            images=imgs,
            temperature=config.TEMPERATURE,
            max_tokens=config.MAX_TOKENS,
        )

        # One call only, so the totals are just this call's stats.
        outcome.total_latency_ms = reply.latency_ms
        outcome.total_input_tokens = reply.input_tokens
        outcome.total_output_tokens = reply.output_tokens
        outcome.final_raw_response = reply.text
        outcome.final_ranking = _extract_ranking_from_text(
            reply.text, case.candidates
        )
        outcome.acquisition_cost = case.get_acquisition_cost(channel_names)
        outcome.total_case_cost = case.get_total_cost(channel_names)
        return outcome
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
# ================================================================
|
| 628 |
+
# Helpers
|
| 629 |
+
# ================================================================
|
| 630 |
+
|
| 631 |
+
def _extract_thought(text: str) -> str:
|
| 632 |
+
"""Extract THOUGHT section from ReAct response."""
|
| 633 |
+
match = re.search(r"THOUGHT:\s*(.+?)(?=ACTION:|$)", text, re.DOTALL)
|
| 634 |
+
if match:
|
| 635 |
+
return match.group(1).strip()[:500]
|
| 636 |
+
return text[:200]
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def _match_channel_name(requested: str, available: list[str]) -> str | None:
|
| 640 |
+
"""Fuzzy match a requested channel name."""
|
| 641 |
+
requested = requested.lower().strip().replace(" ", "_")
|
| 642 |
+
if requested in available:
|
| 643 |
+
return requested
|
| 644 |
+
for ch in available:
|
| 645 |
+
if requested in ch or ch in requested:
|
| 646 |
+
return ch
|
| 647 |
+
return None
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def _extract_ranking_from_text(
|
| 651 |
+
text: str, candidates: list[str]
|
| 652 |
+
) -> list[dict]:
|
| 653 |
+
"""Extract ranking from free-form text response."""
|
| 654 |
+
ranking = []
|
| 655 |
+
pattern = (
|
| 656 |
+
r"(\d+)\.\s*(.+?)\s*"
|
| 657 |
+
r"\((?:confidence|probability|prob|conf):\s*([\d.]+)\)"
|
| 658 |
+
)
|
| 659 |
+
matches = re.findall(pattern, text, re.IGNORECASE)
|
| 660 |
+
if matches:
|
| 661 |
+
for rank_str, name, conf_str in matches:
|
| 662 |
+
try:
|
| 663 |
+
ranking.append({
|
| 664 |
+
"name": name.strip(),
|
| 665 |
+
"confidence": float(conf_str),
|
| 666 |
+
"rank": int(rank_str),
|
| 667 |
+
})
|
| 668 |
+
except ValueError:
|
| 669 |
+
continue
|
| 670 |
+
if not ranking and candidates:
|
| 671 |
+
for i, candidate in enumerate(candidates):
|
| 672 |
+
if candidate.lower() in text.lower():
|
| 673 |
+
ranking.append({
|
| 674 |
+
"name": candidate,
|
| 675 |
+
"confidence": max(0.1, 1.0 - i * 0.2),
|
| 676 |
+
"rank": len(ranking) + 1,
|
| 677 |
+
})
|
| 678 |
+
ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
|
| 679 |
+
for i, entry in enumerate(ranking):
|
| 680 |
+
entry["rank"] = i + 1
|
| 681 |
+
return ranking
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# ================================================================
|
| 685 |
+
# Registry
|
| 686 |
+
# ================================================================
|
| 687 |
+
|
| 688 |
+
# Lookup table from baseline identifier to its implementing class, used
# to select a baseline by name.
BASELINE_REGISTRY = {
    "all_at_once": AllAtOnceBaseline,
    "random_order": RandomOrderBaseline,
    "clinical_guideline": ClinicalGuidelineBaseline,
    "react": ReactBaseline,
    "cot_single_pass": CoTSinglePassBaseline,
}
|
calibration.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calibration Analysis for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Measures whether the VLM's reported probabilities match empirical
|
| 5 |
+
accuracy. Key analyses for the ACL/EMNLP submission:
|
| 6 |
+
|
| 7 |
+
1. Reliability Diagram: binned confidence vs accuracy
|
| 8 |
+
2. Expected Calibration Error (ECE): scalar miscalibration summary
|
| 9 |
+
3. Temperature Scaling: post-hoc recalibration on held-out set
|
| 10 |
+
4. Robustness to Miscalibration: does the method work with noisy probs?
|
| 11 |
+
5. Per-Step Calibration: is calibration better/worse at different steps?
|
| 12 |
+
"""
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
from scipy.optimize import minimize_scalar
|
| 21 |
+
|
| 22 |
+
from agent import AgentResult, AcquisitionStep
|
| 23 |
+
from datasets.base import MedicalCase
|
| 24 |
+
from evaluation import evaluate_single_case, CaseMetrics
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ================================================================
|
| 30 |
+
# Core Calibration Metrics
|
| 31 |
+
# ================================================================
|
| 32 |
+
|
| 33 |
+
@dataclass
class CalibrationBin:
    """A single bin in a reliability diagram.

    Bins partition the [0, 1] confidence range; each holds the mean
    stated confidence and empirical accuracy of the predictions whose
    confidence falls inside it.
    """
    bin_lower: float       # lower confidence bound of the bin
    bin_upper: float       # upper confidence bound of the bin (inclusive)
    bin_center: float      # midpoint of the bin, convenient for plotting
    avg_confidence: float  # mean stated confidence of predictions in the bin
    avg_accuracy: float    # fraction of those predictions that were correct
    count: int             # number of predictions falling in the bin
    gap: float  # |avg_confidence - avg_accuracy|
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class CalibrationResult:
    """Full calibration analysis for a set of predictions.

    Produced by compute_calibration; aggregates scalar miscalibration
    summaries plus the per-bin reliability-diagram data.
    """
    ece: float  # Expected Calibration Error (count-weighted bin gaps)
    mce: float  # Maximum Calibration Error (worst single-bin gap)
    ace: float  # Average Calibration Error (unweighted mean over occupied bins)
    bins: list[CalibrationBin]  # reliability-diagram bins, including empty ones
    n_predictions: int          # number of (confidence, correctness) pairs used
    mean_confidence: float      # mean stated confidence over all predictions
    mean_accuracy: float        # overall top-1 accuracy
    overconfidence_ratio: float  # Fraction of bins where conf > acc
    brier_score: float  # Brier score (MSE of probabilities)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def compute_calibration(
    confidences: list[float],
    correctness: list[bool],
    n_bins: int = 10,
) -> CalibrationResult:
    """
    Compute calibration metrics from confidence-correctness pairs.

    Bins are half-open intervals (lower, upper], except the first bin,
    which also includes its lower edge. Previously the first bin used a
    strict `confs > 0.0` test, so predictions with confidence exactly
    0.0 fell into no bin at all while still counting toward n — which
    silently deflated ECE and MCE.

    Args:
        confidences: Model's stated probability for its top prediction
        correctness: Whether the top prediction was correct
        n_bins: Number of bins for the reliability diagram

    Returns:
        CalibrationResult with ECE, MCE, ACE, Brier score, and per-bin
        reliability-diagram statistics. Empty input yields all-zero
        metrics.
    """
    if not confidences:
        # Degenerate case: no predictions at all.
        return CalibrationResult(
            ece=0, mce=0, ace=0, bins=[], n_predictions=0,
            mean_confidence=0, mean_accuracy=0,
            overconfidence_ratio=0, brier_score=0,
        )

    confs = np.array(confidences, dtype=np.float64)
    accs = np.array(correctness, dtype=np.float64)
    n = len(confs)

    bin_boundaries = np.linspace(0.0, 1.0, n_bins + 1)
    bins = []
    ece = 0.0
    mce = 0.0
    overconf_count = 0

    for i in range(n_bins):
        lower = bin_boundaries[i]
        upper = bin_boundaries[i + 1]
        # First bin is closed at its lower edge so confidence == 0.0 is
        # not dropped; all other bins are (lower, upper].
        if i == 0:
            mask = (confs >= lower) & (confs <= upper)
        else:
            mask = (confs > lower) & (confs <= upper)
        count = mask.sum()

        if count == 0:
            # Keep an empty placeholder so the reliability diagram has
            # a fixed number of bins.
            bins.append(CalibrationBin(
                bin_lower=lower, bin_upper=upper,
                bin_center=(lower + upper) / 2,
                avg_confidence=0, avg_accuracy=0,
                count=0, gap=0,
            ))
            continue

        avg_conf = confs[mask].mean()
        avg_acc = accs[mask].mean()
        gap = abs(avg_conf - avg_acc)

        # ECE weights each bin's gap by its occupancy; MCE is the
        # worst single-bin gap.
        ece += (count / n) * gap
        mce = max(mce, gap)

        if avg_conf > avg_acc:
            overconf_count += 1

        bins.append(CalibrationBin(
            bin_lower=lower, bin_upper=upper,
            bin_center=(lower + upper) / 2,
            avg_confidence=float(avg_conf),
            avg_accuracy=float(avg_acc),
            count=int(count),
            gap=float(gap),
        ))

    non_empty_bins = [b for b in bins if b.count > 0]
    # ACE weights every occupied bin equally (contrast with ECE).
    ace = np.mean([b.gap for b in non_empty_bins]) if non_empty_bins else 0.0

    # Brier score
    brier = np.mean((confs - accs) ** 2)

    return CalibrationResult(
        ece=float(ece),
        mce=float(mce),
        ace=float(ace),
        bins=bins,
        n_predictions=n,
        mean_confidence=float(confs.mean()),
        mean_accuracy=float(accs.mean()),
        overconfidence_ratio=overconf_count / len(non_empty_bins) if non_empty_bins else 0,
        brier_score=float(brier),
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ================================================================
|
| 146 |
+
# Extract Predictions from Agent Results
|
| 147 |
+
# ================================================================
|
| 148 |
+
|
| 149 |
+
def extract_predictions(
|
| 150 |
+
results: list[AgentResult],
|
| 151 |
+
cases: list[MedicalCase],
|
| 152 |
+
) -> tuple[list[float], list[bool]]:
|
| 153 |
+
"""
|
| 154 |
+
Extract (confidence, correctness) pairs from agent results.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
confidences: top-1 stated probability
|
| 158 |
+
correctness: whether top-1 matches ground truth
|
| 159 |
+
"""
|
| 160 |
+
confidences = []
|
| 161 |
+
correctness = []
|
| 162 |
+
|
| 163 |
+
for result, case in zip(results, cases):
|
| 164 |
+
if not result.final_ranking:
|
| 165 |
+
continue
|
| 166 |
+
|
| 167 |
+
top = result.final_ranking[0]
|
| 168 |
+
conf = top.get("confidence", 0.0)
|
| 169 |
+
name = top.get("name", "").strip().lower()
|
| 170 |
+
gt = case.ground_truth.strip().lower()
|
| 171 |
+
|
| 172 |
+
correct = name == gt or name in gt or gt in name
|
| 173 |
+
|
| 174 |
+
confidences.append(conf)
|
| 175 |
+
correctness.append(correct)
|
| 176 |
+
|
| 177 |
+
return confidences, correctness
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def extract_per_step_predictions(
|
| 181 |
+
results: list[AgentResult],
|
| 182 |
+
cases: list[MedicalCase],
|
| 183 |
+
) -> dict[int, tuple[list[float], list[bool]]]:
|
| 184 |
+
"""
|
| 185 |
+
Extract predictions at each acquisition step.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
{step_idx: (confidences, correctness)}
|
| 189 |
+
"""
|
| 190 |
+
step_data: dict[int, tuple[list, list]] = {}
|
| 191 |
+
|
| 192 |
+
for result, case in zip(results, cases):
|
| 193 |
+
gt = case.ground_truth.strip().lower()
|
| 194 |
+
|
| 195 |
+
for step in result.steps:
|
| 196 |
+
if not step.differential:
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
idx = step.step
|
| 200 |
+
if idx not in step_data:
|
| 201 |
+
step_data[idx] = ([], [])
|
| 202 |
+
|
| 203 |
+
top = max(step.differential, key=lambda d: d.get("confidence", 0))
|
| 204 |
+
conf = top.get("confidence", 0.0)
|
| 205 |
+
name = top.get("name", "").strip().lower()
|
| 206 |
+
correct = name == gt or name in gt or gt in name
|
| 207 |
+
|
| 208 |
+
step_data[idx][0].append(conf)
|
| 209 |
+
step_data[idx][1].append(correct)
|
| 210 |
+
|
| 211 |
+
return step_data
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ================================================================
|
| 215 |
+
# Temperature Scaling
|
| 216 |
+
# ================================================================
|
| 217 |
+
|
| 218 |
+
def temperature_scale(
    confidences: list[float],
    correctness: list[bool],
    candidates_per_case: list[int] = None,
) -> tuple[float, float]:
    """
    Find optimal temperature T that minimizes ECE on held-out data.

    Temperature scaling: p_calibrated = softmax(logit(p) / T)
    For single top-1 probability, we use the simplified version:
        logit = log(p / (1 - p))
        scaled_logit = logit / T
        p_scaled = sigmoid(scaled_logit)

    Args:
        confidences: Raw model confidences
        correctness: Whether predictions were correct
        candidates_per_case: Number of candidates per case (for proper scaling)

    Returns:
        (optimal_temperature, calibrated_ece)
    """
    # NOTE(review): `candidates_per_case` is accepted but never used in
    # this body — confirm whether per-case class counts were meant to
    # adjust the logit transform.
    confs = np.array(confidences, dtype=np.float64)
    accs = np.array(correctness, dtype=np.float64)

    # Clip to avoid log(0)
    confs = np.clip(confs, 1e-6, 1 - 1e-6)
    logits = np.log(confs / (1 - confs))

    def ece_at_temperature(T):
        # Objective: 10-bin ECE of the temperature-scaled confidences.
        # This is piecewise-constant in T (predictions jump between
        # bins), so the bounded scalar search below may settle anywhere
        # on a flat plateau rather than at a unique minimizer.
        scaled_logits = logits / T
        scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
        # Compute ECE
        n_bins = 10
        bins = np.linspace(0, 1, n_bins + 1)
        ece = 0.0
        n = len(scaled_confs)
        for i in range(n_bins):
            mask = (scaled_confs > bins[i]) & (scaled_confs <= bins[i + 1])
            if mask.sum() == 0:
                continue
            bin_conf = scaled_confs[mask].mean()
            bin_acc = accs[mask].mean()
            ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
        return ece

    # Bounded 1-D search over T in [0.1, 10].
    result = minimize_scalar(
        ece_at_temperature,
        bounds=(0.1, 10.0),
        method="bounded",
    )

    optimal_T = result.x
    calibrated_ece = ece_at_temperature(optimal_T)

    return float(optimal_T), float(calibrated_ece)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def apply_temperature(
    confidences: list[float], temperature: float
) -> list[float]:
    """Rescale stated probabilities by a fitted temperature.

    Each probability is mapped through sigmoid(logit(p) / temperature);
    inputs are clipped away from 0 and 1 so the logit stays finite.
    """
    clipped = np.clip(
        np.array(confidences, dtype=np.float64), 1e-6, 1 - 1e-6
    )
    rescaled = np.log(clipped / (1 - clipped)) / temperature
    return (1.0 / (1.0 + np.exp(-rescaled))).tolist()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ================================================================
|
| 289 |
+
# Robustness to Miscalibration
|
| 290 |
+
# ================================================================
|
| 291 |
+
|
| 292 |
+
def test_calibration_robustness(
    results: list[AgentResult],
    cases: list[MedicalCase],
    noise_levels: list[float] = None,
    n_trials: int = 10,
    seed: int = 42,
) -> dict[float, dict]:
    """
    Test whether the agent's acquisition decisions are robust to
    probability miscalibration.

    For each noise level, Gaussian noise is added in logit space to the
    agent's per-step probability distributions, and the Spearman rank
    correlation between the original and perturbed orderings is
    measured.

    Args:
        results: Completed agent runs (acquisition orders + per-step
            differentials).
        cases: Matching cases. NOTE(review): currently unused in this
            body — presumably kept for interface symmetry; confirm.
        noise_levels: Standard deviations of Gaussian noise to add to logits
        n_trials: Number of random trials per noise level
        seed: RNG seed for reproducible perturbations

    Returns:
        {noise_level: {order_stability, stop_stability,
                       mean_rank_correlation, n_cases}}
    """
    # Hoisted out of the innermost loop — importing once per perturbed
    # distribution was pure overhead.
    from scipy.stats import spearmanr

    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.25, 0.5, 1.0, 2.0]

    rng = np.random.RandomState(seed)
    robustness = {}

    # Collect original acquisition orders and stopping points
    original_orders = []
    original_stop_steps = []
    original_distributions = []

    for result in results:
        original_orders.append(tuple(result.acquired_channels))
        original_stop_steps.append(len(result.acquired_channels))

        step_dists = []
        for step in result.steps:
            if step.differential:
                dist = {
                    d.get("name", ""): d.get("confidence", 0)
                    for d in step.differential
                }
                step_dists.append(dist)
        original_distributions.append(step_dists)

    for noise in noise_levels:
        order_matches = 0
        stop_matches = 0
        total = len(results)

        if noise == 0.0:
            # Zero noise is trivially stable; skip the sampling loop.
            robustness[noise] = {
                "order_stability": 1.0,
                "stop_stability": 1.0,
                "mean_rank_correlation": 1.0,
                "n_cases": total,
            }
            continue

        rank_correlations = []

        for trial in range(n_trials):
            trial_order_matches = 0
            trial_stop_matches = 0
            trial_rank_corrs = []

            for i, (result, dists) in enumerate(
                zip(results, original_distributions)
            ):
                if not dists:
                    continue

                # Perturb each step's distribution in logit space and
                # measure how much the induced ranking moves.
                for dist in dists:
                    probs = np.array(list(dist.values()), dtype=np.float64)
                    probs = np.clip(probs, 1e-6, 1 - 1e-6)

                    # Add noise in logit space
                    logits = np.log(probs / (1 - probs))
                    noisy_logits = logits + rng.normal(0, noise, len(logits))
                    noisy_probs = 1.0 / (1.0 + np.exp(-noisy_logits))
                    noisy_probs /= noisy_probs.sum()

                    # Check if ranking order is preserved
                    orig_order = np.argsort(-probs)
                    noisy_order = np.argsort(-noisy_probs)

                    # Spearman rank correlation
                    if len(orig_order) > 1:
                        corr, _ = spearmanr(orig_order, noisy_order)
                        trial_rank_corrs.append(corr)

                # NOTE(review): this compares a run's acquisition order
                # to itself, so it always matches; the perturbed
                # distributions are never replayed through the
                # acquisition policy. Kept as-is to preserve reported
                # numbers — a faithful stability test would re-run the
                # policy under the perturbed probabilities.
                if tuple(result.acquired_channels) == original_orders[i]:
                    trial_order_matches += 1
                    trial_stop_matches += 1  # Simplified — count all

            if total > 0:
                order_matches += trial_order_matches / total
                stop_matches += trial_stop_matches / total
                if trial_rank_corrs:
                    rank_correlations.extend(trial_rank_corrs)

        robustness[noise] = {
            "order_stability": order_matches / n_trials if n_trials > 0 else 0,
            "stop_stability": stop_matches / n_trials if n_trials > 0 else 0,
            "mean_rank_correlation": float(np.mean(rank_correlations)) if rank_correlations else 1.0,
            "n_cases": total,
        }

    return robustness
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# ================================================================
|
| 411 |
+
# Full Calibration Analysis Pipeline
|
| 412 |
+
# ================================================================
|
| 413 |
+
|
| 414 |
+
def run_calibration_analysis(
    results: list[AgentResult],
    cases: list[MedicalCase],
    save_dir: Path | None = None,
) -> dict:
    """
    Run the complete calibration analysis suite.

    Pipeline (each stage logs its headline numbers):
      1. Overall calibration (ECE/MCE/ACE/Brier) over all final predictions.
      2. Temperature scaling fitted on the first half of predictions and
         evaluated on the second half (only when >= 10 predictions exist).
      3. Per-step calibration for any acquisition step with >= 5 predictions.
      4. Ranking-robustness of the agent under logit-space noise.

    Args:
        results: One AgentResult per completed case.
        cases: The corresponding MedicalCase ground-truth objects.
        save_dir: If given, the compiled metrics dict is also written to
            ``save_dir / "calibration_analysis.json"`` (directory is created).

    Returns:
        A JSON-serializable dict with keys "overall", "temperature_scaling",
        "per_step_calibration", and "robustness".
    """
    logger.info("Running calibration analysis...")

    # 1. Overall calibration
    confidences, correctness = extract_predictions(results, cases)
    overall = compute_calibration(confidences, correctness)

    logger.info(f" ECE: {overall.ece:.4f}")
    logger.info(f" MCE: {overall.mce:.4f}")
    logger.info(f" Brier Score: {overall.brier_score:.4f}")
    logger.info(f" Mean Confidence: {overall.mean_confidence:.3f}")
    logger.info(f" Mean Accuracy: {overall.mean_accuracy:.3f}")
    logger.info(f" Overconfidence Ratio: {overall.overconfidence_ratio:.2f}")

    # 2. Temperature scaling
    if len(confidences) >= 10:
        # Split into calibration and test sets
        # NOTE(review): this is a sequential (unshuffled) half split — if the
        # prediction order correlates with difficulty, the fitted temperature
        # may be biased; confirm whether a shuffled split is intended.
        n = len(confidences)
        mid = n // 2
        cal_confs, cal_correct = confidences[:mid], correctness[:mid]
        test_confs, test_correct = confidences[mid:], correctness[mid:]

        # cal_ece (ECE on the calibration half) is currently unused; only the
        # held-out post-calibration ECE below is reported.
        opt_T, cal_ece = temperature_scale(cal_confs, cal_correct)
        scaled_test = apply_temperature(test_confs, opt_T)
        post_cal = compute_calibration(scaled_test, test_correct)

        logger.info(f" Optimal Temperature: {opt_T:.3f}")
        logger.info(f" Post-calibration ECE: {post_cal.ece:.4f}")
    else:
        # Too few predictions to fit a temperature: identity scaling, and the
        # "post-calibration" metrics simply mirror the overall ones.
        opt_T = 1.0
        post_cal = overall

    # 3. Per-step calibration
    step_data = extract_per_step_predictions(results, cases)
    per_step_cal = {}
    for step_idx, (step_confs, step_correct) in sorted(step_data.items()):
        # Skip steps with too few predictions for a meaningful 5-bin estimate.
        if len(step_confs) >= 5:
            step_cal = compute_calibration(step_confs, step_correct, n_bins=5)
            per_step_cal[step_idx] = {
                "ece": step_cal.ece,
                "mean_confidence": step_cal.mean_confidence,
                "mean_accuracy": step_cal.mean_accuracy,
                "n_predictions": step_cal.n_predictions,
            }
            logger.info(
                f" Step {step_idx}: ECE={step_cal.ece:.4f}, "
                f"Conf={step_cal.mean_confidence:.3f}, "
                f"Acc={step_cal.mean_accuracy:.3f} (n={step_cal.n_predictions})"
            )

    # 4. Robustness analysis
    robustness = test_calibration_robustness(results, cases)
    for noise, metrics in robustness.items():
        logger.info(
            f" Noise={noise:.2f}: rank_corr={metrics['mean_rank_correlation']:.3f}"
        )

    # Compile output
    output = {
        "overall": {
            "ece": overall.ece,
            "mce": overall.mce,
            "ace": overall.ace,
            "brier_score": overall.brier_score,
            "mean_confidence": overall.mean_confidence,
            "mean_accuracy": overall.mean_accuracy,
            "overconfidence_ratio": overall.overconfidence_ratio,
            "n_predictions": overall.n_predictions,
            "bins": [
                {
                    "center": b.bin_center,
                    "confidence": b.avg_confidence,
                    "accuracy": b.avg_accuracy,
                    "count": b.count,
                    "gap": b.gap,
                }
                for b in overall.bins
            ],
        },
        "temperature_scaling": {
            "optimal_temperature": opt_T,
            "pre_calibration_ece": overall.ece,
            "post_calibration_ece": post_cal.ece,
        },
        "per_step_calibration": per_step_cal,
        # Noise levels are floats; stringify keys so the dict is JSON-safe.
        "robustness": {
            str(k): v for k, v in robustness.items()
        },
    }

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(save_dir / "calibration_analysis.json", "w") as f:
            json.dump(output, f, indent=2)
        logger.info(f" Saved to {save_dir / 'calibration_analysis.json'}")

    return output
|
config.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration for ActiveMedAgent experiments.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
# ============================================================
|
| 11 |
+
# API Configuration
|
| 12 |
+
# ============================================================
|
| 13 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
| 14 |
+
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
| 15 |
+
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY", "")
|
| 16 |
+
|
| 17 |
+
# Model identifiers per backend
|
| 18 |
+
MODELS = {
|
| 19 |
+
"openai": "gpt-4o-2024-11-20",
|
| 20 |
+
"openai_mini": "gpt-4o-mini",
|
| 21 |
+
"anthropic": "claude-sonnet-4-20250514",
|
| 22 |
+
"together": "Qwen/Qwen2.5-VL-72B-Instruct",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# Rate limiting (requests per minute)
|
| 26 |
+
RATE_LIMITS = {
|
| 27 |
+
"openai": 30,
|
| 28 |
+
"openai_mini": 60,
|
| 29 |
+
"anthropic": 30,
|
| 30 |
+
"together": 20,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# Max tokens for generation — tool calls produce structured JSON with
|
| 34 |
+
# probability distributions, evidence chains, and expected impact analysis,
|
| 35 |
+
# which requires more tokens than free-text responses.
|
| 36 |
+
MAX_TOKENS = 4096
|
| 37 |
+
|
| 38 |
+
# Temperature — low for reproducibility
|
| 39 |
+
TEMPERATURE = 0.1
|
| 40 |
+
|
| 41 |
+
# ============================================================
|
| 42 |
+
# Dataset Paths (update these to your local paths)
|
| 43 |
+
# ============================================================
|
| 44 |
+
DATA_ROOT = Path(os.getenv("DATA_ROOT", "./data"))
|
| 45 |
+
|
| 46 |
+
DATASET_PATHS = {
|
| 47 |
+
"midas": DATA_ROOT / "midas",
|
| 48 |
+
"nejm": DATA_ROOT / "nejm",
|
| 49 |
+
"olives": DATA_ROOT / "OLIVES",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
# ============================================================
|
| 53 |
+
# Experiment Configuration
|
| 54 |
+
# ============================================================
|
| 55 |
+
|
| 56 |
+
# Prompt variants for robustness analysis (see prompts.py)
|
| 57 |
+
PROMPT_VARIANTS = ["A", "B", "C"]
|
| 58 |
+
|
| 59 |
+
# Default backends to run
|
| 60 |
+
DEFAULT_BACKENDS = ["openai"]
|
| 61 |
+
|
| 62 |
+
# Context management mode for the acquisition loop.
|
| 63 |
+
# "full" — keep entire multi-turn conversation history (best for capable models)
|
| 64 |
+
# "condensed" — each turn gets a fresh single-turn call with a compressed state
|
| 65 |
+
# summary (best for weaker/smaller models that lose track in long context)
|
| 66 |
+
# "adaptive" — auto-select based on model: "full" for GPT-4o/Claude/Qwen-72B,
|
| 67 |
+
# "condensed" for GPT-4o-mini and other small models
|
| 68 |
+
CONTEXT_MODE = "adaptive"
|
| 69 |
+
|
| 70 |
+
# Models that should use condensed context (too weak for long multi-turn)
|
| 71 |
+
CONDENSED_MODELS = {
|
| 72 |
+
"gpt-4o-mini",
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# Early commit threshold — agent may commit if top diagnosis probability exceeds
|
| 76 |
+
# this AND the gap to #2 exceeds COMMIT_GAP_THRESHOLD
|
| 77 |
+
COMMIT_CONFIDENCE_THRESHOLD = 0.85
|
| 78 |
+
COMMIT_GAP_THRESHOLD = 0.30
|
| 79 |
+
|
| 80 |
+
# Number of bootstrap resamples for confidence intervals
|
| 81 |
+
N_BOOTSTRAP = 1000
|
| 82 |
+
|
| 83 |
+
# Random seed
|
| 84 |
+
SEED = 42
|
| 85 |
+
|
| 86 |
+
# Cost penalty strength for learned policies.
|
| 87 |
+
# Utility reward = diagnostic improvement - lambda * normalized_channel_cost
|
| 88 |
+
COST_PENALTY_LAMBDA = float(os.getenv("COST_PENALTY_LAMBDA", "0.05"))
|
| 89 |
+
|
| 90 |
+
# ============================================================
|
| 91 |
+
# Dataset-Specific Channel Definitions
|
| 92 |
+
# ============================================================
|
| 93 |
+
|
| 94 |
+
MIDAS_CHANNELS = {
|
| 95 |
+
"patient_demographics": {
|
| 96 |
+
"description": "Patient age, sex, and Fitzpatrick skin type",
|
| 97 |
+
"type": "text",
|
| 98 |
+
"always_given": True,
|
| 99 |
+
"tier": "free",
|
| 100 |
+
"cost": 0.0,
|
| 101 |
+
"order": 0,
|
| 102 |
+
},
|
| 103 |
+
"lesion_metadata": {
|
| 104 |
+
"description": "Anatomic location, lesion length and width",
|
| 105 |
+
"type": "text",
|
| 106 |
+
"always_given": True,
|
| 107 |
+
"tier": "cheap",
|
| 108 |
+
"cost": 25.0,
|
| 109 |
+
"order": 1,
|
| 110 |
+
},
|
| 111 |
+
"clinical_30cm": {
|
| 112 |
+
"description": "Clinical photograph taken at 30cm distance",
|
| 113 |
+
"type": "image",
|
| 114 |
+
"always_given": False,
|
| 115 |
+
"tier": "moderate",
|
| 116 |
+
"cost": 50.0,
|
| 117 |
+
"order": 2,
|
| 118 |
+
},
|
| 119 |
+
"clinical_15cm": {
|
| 120 |
+
"description": "Clinical photograph taken at 15cm distance (closer view)",
|
| 121 |
+
"type": "image",
|
| 122 |
+
"always_given": False,
|
| 123 |
+
"tier": "moderate",
|
| 124 |
+
"cost": 50.0,
|
| 125 |
+
"order": 3,
|
| 126 |
+
},
|
| 127 |
+
"dermoscopy": {
|
| 128 |
+
"description": "Dermoscopic image showing subsurface skin structures",
|
| 129 |
+
"type": "image",
|
| 130 |
+
"always_given": False,
|
| 131 |
+
"tier": "expensive",
|
| 132 |
+
"cost": 250.0,
|
| 133 |
+
"order": 4,
|
| 134 |
+
},
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
NEJM_CHANNELS = {
|
| 138 |
+
"demographics": {
|
| 139 |
+
"description": "Patient age, sex, and ethnicity if mentioned",
|
| 140 |
+
"type": "text",
|
| 141 |
+
"always_given": True,
|
| 142 |
+
"tier": "free",
|
| 143 |
+
"cost": 0.0,
|
| 144 |
+
"order": 0,
|
| 145 |
+
},
|
| 146 |
+
"chief_complaint": {
|
| 147 |
+
"description": "The presenting symptom(s) and their duration",
|
| 148 |
+
"type": "text",
|
| 149 |
+
"always_given": True,
|
| 150 |
+
"tier": "free",
|
| 151 |
+
"cost": 0.0,
|
| 152 |
+
"order": 1,
|
| 153 |
+
},
|
| 154 |
+
"medical_history": {
|
| 155 |
+
"description": "Past medical conditions, medications, family and social history",
|
| 156 |
+
"type": "text",
|
| 157 |
+
"always_given": True,
|
| 158 |
+
"tier": "free",
|
| 159 |
+
"cost": 0.0,
|
| 160 |
+
"order": 2,
|
| 161 |
+
},
|
| 162 |
+
"exam_findings": {
|
| 163 |
+
"description": "Physical examination results and observations",
|
| 164 |
+
"type": "text",
|
| 165 |
+
"always_given": False,
|
| 166 |
+
"tier": "cheap",
|
| 167 |
+
"cost": 75.0,
|
| 168 |
+
"order": 3,
|
| 169 |
+
},
|
| 170 |
+
"investigations": {
|
| 171 |
+
"description": "Laboratory values, prior imaging results, and test outcomes",
|
| 172 |
+
"type": "text",
|
| 173 |
+
"always_given": False,
|
| 174 |
+
"tier": "moderate",
|
| 175 |
+
"cost": 250.0,
|
| 176 |
+
"order": 4,
|
| 177 |
+
},
|
| 178 |
+
"image": {
|
| 179 |
+
"description": "The primary diagnostic image",
|
| 180 |
+
"type": "image",
|
| 181 |
+
"always_given": False,
|
| 182 |
+
"tier": "expensive",
|
| 183 |
+
"cost": 800.0,
|
| 184 |
+
"order": 5,
|
| 185 |
+
},
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
OLIVES_CHANNELS = {
|
| 189 |
+
"disease_context": {
|
| 190 |
+
"description": "Disease type and treatment context",
|
| 191 |
+
"type": "text",
|
| 192 |
+
"always_given": True,
|
| 193 |
+
"tier": "free",
|
| 194 |
+
"cost": 0.0,
|
| 195 |
+
"order": 0,
|
| 196 |
+
},
|
| 197 |
+
"clinical_measurements": {
|
| 198 |
+
"description": "Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
|
| 199 |
+
"type": "text",
|
| 200 |
+
"always_given": False,
|
| 201 |
+
"tier": "cheap",
|
| 202 |
+
"cost": 20.0,
|
| 203 |
+
"order": 1,
|
| 204 |
+
},
|
| 205 |
+
"biomarker_hints": {
|
| 206 |
+
"description": "Expert-graded presence of retinal biomarkers (partial list)",
|
| 207 |
+
"type": "text",
|
| 208 |
+
"always_given": False,
|
| 209 |
+
"tier": "moderate",
|
| 210 |
+
"cost": 100.0,
|
| 211 |
+
"order": 2,
|
| 212 |
+
},
|
| 213 |
+
"oct_scan": {
|
| 214 |
+
"description": "Optical Coherence Tomography B-scan showing retinal cross-section",
|
| 215 |
+
"type": "image",
|
| 216 |
+
"always_given": False,
|
| 217 |
+
"tier": "expensive",
|
| 218 |
+
"cost": 300.0,
|
| 219 |
+
"order": 3,
|
| 220 |
+
},
|
| 221 |
+
"additional_oct": {
|
| 222 |
+
"description": "Additional OCT B-scans from different retinal locations",
|
| 223 |
+
"type": "image",
|
| 224 |
+
"always_given": False,
|
| 225 |
+
"tier": "very_expensive",
|
| 226 |
+
"cost": 150.0,
|
| 227 |
+
"order": 4,
|
| 228 |
+
},
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
CHANNEL_CONFIGS = {
|
| 232 |
+
"midas": MIDAS_CHANNELS,
|
| 233 |
+
"nejm": NEJM_CHANNELS,
|
| 234 |
+
"olives": OLIVES_CHANNELS,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
# ============================================================
|
| 238 |
+
# OLIVES Biomarker Tier Definitions
|
| 239 |
+
# ============================================================
|
| 240 |
+
|
| 241 |
+
OLIVES_BIOMARKER_TIERS = {
|
| 242 |
+
"fundus_visible": [
|
| 243 |
+
"hard_exudates",
|
| 244 |
+
"hemorrhage",
|
| 245 |
+
"microaneurysms",
|
| 246 |
+
"cotton_wool_spots",
|
| 247 |
+
],
|
| 248 |
+
"oct_dependent": [
|
| 249 |
+
"fluid_irf", # Intraretinal fluid
|
| 250 |
+
"fluid_srf", # Subretinal fluid
|
| 251 |
+
"dril", # Disorganization of retinal inner layers
|
| 252 |
+
"ez_disruption", # Ellipsoid zone disruption
|
| 253 |
+
"ez_absent",
|
| 254 |
+
"drt_me", # Diffuse retinal thickening / macular edema
|
| 255 |
+
"shrm", # Subretinal hyperreflective material
|
| 256 |
+
"full_thickness", # Full thickness involvement
|
| 257 |
+
"preretinal_tissue",
|
| 258 |
+
"vitreous_debris",
|
| 259 |
+
],
|
| 260 |
+
"clinical_dependent": [
|
| 261 |
+
"drt_me", # Also correlates with CST
|
| 262 |
+
],
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
# ============================================================
|
| 266 |
+
# Results / Logging
|
| 267 |
+
# ============================================================
|
| 268 |
+
RESULTS_DIR = Path(os.getenv("RESULTS_DIR", "./results"))
|
| 269 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 270 |
+
|
| 271 |
+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def get_channel_definition(dataset: str, channel_name: str) -> dict:
    """Return canonical metadata for a dataset channel.

    Unknown dataset or channel names yield an empty dict rather than raising.
    """
    dataset_channels = CHANNEL_CONFIGS.get(dataset, {})
    return dataset_channels.get(channel_name, {})
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import MedicalCase, DatasetBase
|
| 2 |
+
from .midas import MIDASDataset
|
| 3 |
+
from .nejm import NEJMDataset
|
| 4 |
+
from .olives import OLIVESDataset
|
| 5 |
+
|
| 6 |
+
# Maps dataset identifiers to their loader classes.
DATASET_REGISTRY = {
    "midas": MIDASDataset,
    "nejm": NEJMDataset,
    "olives": OLIVESDataset,
}


def load_dataset(name: str, **kwargs) -> DatasetBase:
    """Load a dataset by name."""
    dataset_cls = DATASET_REGISTRY.get(name)
    if dataset_cls is None:
        raise ValueError(f"Unknown dataset: {name}. Choose from {list(DATASET_REGISTRY.keys())}")
    return dataset_cls(**kwargs)
|
datasets/base.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Abstract base class for medical datasets in the ActiveMedAgent framework.
|
| 3 |
+
|
| 4 |
+
Every dataset must expose cases in a unified format:
|
| 5 |
+
- An initial observation (always-given channels)
|
| 6 |
+
- A set of requestable channels (additional info the agent can acquire)
|
| 7 |
+
- A candidate list (diagnoses to rank)
|
| 8 |
+
- Ground truth (correct ranking)
|
| 9 |
+
"""
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ChannelData:
|
| 18 |
+
"""A single information channel's content."""
|
| 19 |
+
name: str
|
| 20 |
+
channel_type: str # "image" or "text"
|
| 21 |
+
description: str # Human-readable description of this channel
|
| 22 |
+
value: Any = None # Text content (str) or base64-encoded image (str)
|
| 23 |
+
image_path: Path | None = None # Original image path if applicable
|
| 24 |
+
cost: float = 0.0
|
| 25 |
+
tier: str = "unknown"
|
| 26 |
+
always_given: bool = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class MedicalCase:
|
| 31 |
+
"""
|
| 32 |
+
A single diagnostic case in the unified format.
|
| 33 |
+
|
| 34 |
+
The agent starts with `initial_channels` and can request from
|
| 35 |
+
`requestable_channels`. It must produce a ranked list over `candidates`.
|
| 36 |
+
"""
|
| 37 |
+
case_id: str
|
| 38 |
+
dataset: str # "midas", "nejm", "olives"
|
| 39 |
+
initial_channels: dict[str, ChannelData] = field(default_factory=dict)
|
| 40 |
+
requestable_channels: dict[str, ChannelData] = field(default_factory=dict)
|
| 41 |
+
candidates: list[str] = field(default_factory=list)
|
| 42 |
+
ground_truth: str = "" # Correct diagnosis label
|
| 43 |
+
ground_truth_rank: int = 0 # Index in candidates (0-indexed)
|
| 44 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def all_channel_names(self) -> list[str]:
|
| 48 |
+
return list(self.initial_channels.keys()) + list(self.requestable_channels.keys())
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def requestable_names(self) -> list[str]:
|
| 52 |
+
return list(self.requestable_channels.keys())
|
| 53 |
+
|
| 54 |
+
def get_channel(self, name: str) -> ChannelData | None:
|
| 55 |
+
"""Retrieve a channel by name from either initial or requestable."""
|
| 56 |
+
if name in self.initial_channels:
|
| 57 |
+
return self.initial_channels[name]
|
| 58 |
+
if name in self.requestable_channels:
|
| 59 |
+
return self.requestable_channels[name]
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
def get_initial_images(self) -> list[str]:
|
| 63 |
+
"""Get base64-encoded images from initial channels."""
|
| 64 |
+
images = []
|
| 65 |
+
for ch in self.initial_channels.values():
|
| 66 |
+
if ch.channel_type == "image" and ch.value is not None:
|
| 67 |
+
images.append(ch.value)
|
| 68 |
+
return images
|
| 69 |
+
|
| 70 |
+
def get_all_images_up_to(self, acquired: list[str]) -> list[str]:
|
| 71 |
+
"""Get all images from initial + acquired channels."""
|
| 72 |
+
images = self.get_initial_images()
|
| 73 |
+
for name in acquired:
|
| 74 |
+
ch = self.get_channel(name)
|
| 75 |
+
if ch and ch.channel_type == "image" and ch.value is not None:
|
| 76 |
+
if isinstance(ch.value, list):
|
| 77 |
+
images.extend(ch.value)
|
| 78 |
+
else:
|
| 79 |
+
images.append(ch.value)
|
| 80 |
+
return images
|
| 81 |
+
|
| 82 |
+
def get_text_context(self, acquired: list[str]) -> dict[str, dict]:
|
| 83 |
+
"""Get all text info from initial + acquired channels."""
|
| 84 |
+
context = {}
|
| 85 |
+
for name, ch in self.initial_channels.items():
|
| 86 |
+
if ch.channel_type == "text" and ch.value:
|
| 87 |
+
context[name] = {"type": "text", "value": ch.value}
|
| 88 |
+
elif ch.channel_type == "image":
|
| 89 |
+
context[name] = {"type": "image", "value": "(image provided)"}
|
| 90 |
+
for name in acquired:
|
| 91 |
+
ch = self.get_channel(name)
|
| 92 |
+
if ch:
|
| 93 |
+
if ch.channel_type == "text" and ch.value:
|
| 94 |
+
context[name] = {"type": "text", "value": ch.value}
|
| 95 |
+
elif ch.channel_type == "image":
|
| 96 |
+
context[name] = {"type": "image", "value": "(image provided)"}
|
| 97 |
+
return context
|
| 98 |
+
|
| 99 |
+
def get_channel_cost(self, name: str) -> float:
|
| 100 |
+
"""Return the configured acquisition cost for a channel."""
|
| 101 |
+
ch = self.get_channel(name)
|
| 102 |
+
return float(ch.cost) if ch else 0.0
|
| 103 |
+
|
| 104 |
+
def get_initial_cost(self) -> float:
|
| 105 |
+
"""Total cost of channels already available at case start."""
|
| 106 |
+
return float(sum(ch.cost for ch in self.initial_channels.values()))
|
| 107 |
+
|
| 108 |
+
def get_acquisition_cost(self, acquired: list[str]) -> float:
|
| 109 |
+
"""Total incremental cost of acquired requestable channels."""
|
| 110 |
+
return float(sum(self.get_channel_cost(name) for name in acquired))
|
| 111 |
+
|
| 112 |
+
def get_total_cost(self, acquired: list[str]) -> float:
|
| 113 |
+
"""Initial cost plus any additional acquired channels."""
|
| 114 |
+
return self.get_initial_cost() + self.get_acquisition_cost(acquired)
|
| 115 |
+
|
| 116 |
+
def get_max_requestable_cost(self) -> float:
|
| 117 |
+
"""Upper bound if every requestable channel were acquired."""
|
| 118 |
+
return float(sum(ch.cost for ch in self.requestable_channels.values()))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class DatasetBase(ABC):
    """Abstract base class for dataset loaders.

    Subclasses populate ``self.cases`` in :meth:`load`; this base then
    provides sequence-style access (len, indexing, iteration) over them.
    """

    def __init__(self, data_dir: str | Path, split: str = "test"):
        # Root directory containing the dataset files.
        self.data_dir = Path(data_dir)
        # Split identifier (e.g. "test"); interpretation is subclass-specific.
        self.split = split
        # Populated by load(); empty until then.
        self.cases: list[MedicalCase] = []

    @abstractmethod
    def load(self) -> list[MedicalCase]:
        """Load and return all cases in unified format."""
        pass

    def __len__(self) -> int:
        # Number of loaded cases (0 before load() has run).
        return len(self.cases)

    def __getitem__(self, idx: int) -> MedicalCase:
        return self.cases[idx]

    def __iter__(self):
        return iter(self.cases)

    @abstractmethod
    def get_name(self) -> str:
        """Return dataset identifier string."""
        pass
|
datasets/midas.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MIDAS (MRA-MIDAS) Dataset Loader.
|
| 3 |
+
|
| 4 |
+
Actual Stanford AIMI MIDAS dataset structure:
|
| 5 |
+
midas/
|
| 6 |
+
├── images/ (flat directory of all images)
|
| 7 |
+
│ ├── s-prd-398966407.jpg
|
| 8 |
+
│ └── ...
|
| 9 |
+
└── release_midas.xlsx (metadata with midas_record_id grouping)
|
| 10 |
+
|
| 11 |
+
Each record_id groups images of one lesion at multiple modalities:
|
| 12 |
+
- midas_distance='1ft' → clinical_30cm
|
| 13 |
+
- midas_distance='6in' → clinical_15cm
|
| 14 |
+
- midas_distance='dscope' → dermoscopy
|
| 15 |
+
|
| 16 |
+
Each case becomes a multi-channel acquisition problem:
|
| 17 |
+
- Initial: patient_demographics (free tier)
|
| 18 |
+
- Requestable: clinical_30cm, clinical_15cm, dermoscopy, lesion_metadata
|
| 19 |
+
"""
|
| 20 |
+
import csv
|
| 21 |
+
import hashlib
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import random
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from collections import Counter, defaultdict
|
| 27 |
+
|
| 28 |
+
from .base import DatasetBase, MedicalCase, ChannelData
|
| 29 |
+
from api_client import encode_image_to_base64
|
| 30 |
+
import config
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
# Map raw midas_distance values to our channel names
DISTANCE_TO_CHANNEL = {
    "1ft": "clinical_30cm",
    "6in": "clinical_15cm",
    "dscope": "dermoscopy",
}

# Map raw midas_path values to canonical diagnosis names
PATH_TO_DIAGNOSIS = {
    "malignant- bcc": "basal_cell_carcinoma",
    "malignant- melanoma": "melanoma_invasive",
    "malignant- scc": "squamous_cell_carcinoma",
    "malignant- sccis": "squamous_cell_carcinoma_in_situ",
    "malignant- ak": "actinic_keratosis",
    "benign-melanocytic nevus": "melanocytic_nevus",
    "benign-seborrheic keratosis": "seborrheic_keratosis",
    "benign-other": "benign_other",
    "other- melanocytic lesion, possible re-excision (severe, spitz, aimp)": "dysplastic_nevus",
}

# MIDAS diagnosis taxonomy — grouped for candidate generation
MIDAS_DIAGNOSIS_GROUPS = {
    "malignant_melanocytic": [
        "melanoma_invasive",
        "melanoma_in_situ",
    ],
    "benign_melanocytic": [
        "melanocytic_nevus",
        "dysplastic_nevus",
        "blue_nevus",
        "spitz_nevus",
    ],
    "malignant_nonmelanocytic": [
        "basal_cell_carcinoma",
        "squamous_cell_carcinoma",
        "squamous_cell_carcinoma_in_situ",
        "actinic_keratosis",
    ],
    "benign_nonmelanocytic": [
        "seborrheic_keratosis",
        "dermatofibroma",
        "angioma",
        "solar_lentigo",
        "benign_other",
    ],
    "inflammatory": [
        "eczema",
        "psoriasis",
        "lichen_planus",
    ],
}

# Flattened list of all possible diagnoses, in group-definition order.
# (Comprehension replaces the previous extend-loop, which also leaked the
# loop variable `group` into the module namespace.)
ALL_DIAGNOSES = [dx for group in MIDAS_DIAGNOSIS_GROUPS.values() for dx in group]
+
|
| 91 |
+
|
| 92 |
+
def _case_rng(case_id: str) -> random.Random:
|
| 93 |
+
"""Create a deterministic RNG seeded by case ID for reproducible candidate generation."""
|
| 94 |
+
seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
|
| 95 |
+
return random.Random(seed)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class MIDASDataset(DatasetBase):
|
| 99 |
+
"""Loader for MRA-MIDAS dermatology dataset."""
|
| 100 |
+
|
| 101 |
+
def __init__(self, data_dir: str | Path = None, split: str = "test", n_candidates: int = 5):
|
| 102 |
+
super().__init__(data_dir or config.DATASET_PATHS["midas"], split)
|
| 103 |
+
self.n_candidates = n_candidates
|
| 104 |
+
def get_name(self) -> str:
    """Return the dataset identifier string ("midas")."""
    return "midas"
+
def load(self) -> list[MedicalCase]:
    """Read MIDAS metadata, group records per lesion, and build unified cases."""
    logger.info(f"Loading MIDAS dataset from {self.data_dir}")

    # ---- Discover metadata file ----
    metadata_path = self._find_metadata_file()
    if metadata_path is None:
        logger.error(f"No metadata file found in {self.data_dir}")
        return []

    records = self._load_metadata(metadata_path)
    logger.info(f"Found {len(records)} records in metadata")

    # ---- Group records by lesion (midas_record_id) ----
    lesion_groups = defaultdict(list)
    for record in records:
        record_id = record.get("midas_record_id", record.get("lesion_id", ""))
        if record_id:
            lesion_groups[str(record_id)].append(record)

    logger.info(f"Found {len(lesion_groups)} unique lesions")

    # ---- Build diagnosis distribution for candidate sampling ----
    # One diagnosis per lesion (taken from its first record); skip empties.
    diagnoses = (self._get_diagnosis(group[0]) for group in lesion_groups.values())
    dx_counter = Counter(dx for dx in diagnoses if dx)

    # ---- Convert each lesion group to MedicalCase ----
    built = (
        self._build_case(record_id, group, dx_counter)
        for record_id, group in lesion_groups.items()
    )
    self.cases = [case for case in built if case is not None]

    logger.info(f"Loaded {len(self.cases)} MIDAS cases")
    return self.cases
+
|
| 147 |
+
def _find_metadata_file(self) -> Path | None:
|
| 148 |
+
"""Find the metadata file (xlsx, csv, or json)."""
|
| 149 |
+
# Try xlsx first (actual MIDAS format)
|
| 150 |
+
for name in ["release_midas.xlsx", "metadata.xlsx"]:
|
| 151 |
+
p = self.data_dir / name
|
| 152 |
+
if p.exists():
|
| 153 |
+
return p
|
| 154 |
+
# Then CSV
|
| 155 |
+
for name in ["metadata.csv", "labels.csv", "midas_metadata.csv"]:
|
| 156 |
+
p = self.data_dir / name
|
| 157 |
+
if p.exists():
|
| 158 |
+
return p
|
| 159 |
+
# Then JSON
|
| 160 |
+
for name in ["metadata.json", "labels.json"]:
|
| 161 |
+
p = self.data_dir / name
|
| 162 |
+
if p.exists():
|
| 163 |
+
return p
|
| 164 |
+
# Glob fallback
|
| 165 |
+
for pattern in ["*.xlsx", "*.csv"]:
|
| 166 |
+
matches = list(self.data_dir.glob(pattern))
|
| 167 |
+
if matches:
|
| 168 |
+
return matches[0]
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
def _load_metadata(self, path: Path) -> list[dict]:
|
| 172 |
+
"""Load metadata from xlsx, csv, or json."""
|
| 173 |
+
if path.suffix == ".xlsx":
|
| 174 |
+
return self._load_xlsx(path)
|
| 175 |
+
elif path.suffix == ".json":
|
| 176 |
+
with open(path, encoding="utf-8") as f:
|
| 177 |
+
return json.load(f)
|
| 178 |
+
else:
|
| 179 |
+
with open(path, newline="", encoding="utf-8-sig") as f:
|
| 180 |
+
reader = csv.DictReader(f)
|
| 181 |
+
return list(reader)
|
| 182 |
+
|
| 183 |
+
def _load_xlsx(self, path: Path) -> list[dict]:
|
| 184 |
+
"""Load metadata from Excel file."""
|
| 185 |
+
import openpyxl
|
| 186 |
+
wb = openpyxl.load_workbook(path, read_only=True)
|
| 187 |
+
ws = wb[wb.sheetnames[0]]
|
| 188 |
+
rows = list(ws.iter_rows(values_only=True))
|
| 189 |
+
wb.close()
|
| 190 |
+
|
| 191 |
+
if not rows:
|
| 192 |
+
return []
|
| 193 |
+
headers = [str(h) if h is not None else f"col_{i}" for i, h in enumerate(rows[0])]
|
| 194 |
+
return [dict(zip(headers, row)) for row in rows[1:]]
|
| 195 |
+
|
| 196 |
+
def _get_diagnosis(self, record: dict) -> str | None:
|
| 197 |
+
"""Extract canonical diagnosis from a record."""
|
| 198 |
+
raw_path = record.get("midas_path", record.get("diagnosis", ""))
|
| 199 |
+
if not raw_path or raw_path == "None" or raw_path is None:
|
| 200 |
+
return None
|
| 201 |
+
raw_path = str(raw_path).strip().lower()
|
| 202 |
+
# Try exact match in mapping
|
| 203 |
+
for key, canonical in PATH_TO_DIAGNOSIS.items():
|
| 204 |
+
if key.lower() == raw_path:
|
| 205 |
+
return canonical
|
| 206 |
+
# Fuzzy fallback
|
| 207 |
+
if "melanoma" in raw_path:
|
| 208 |
+
return "melanoma_invasive"
|
| 209 |
+
if "bcc" in raw_path or "basal" in raw_path:
|
| 210 |
+
return "basal_cell_carcinoma"
|
| 211 |
+
if "sccis" in raw_path:
|
| 212 |
+
return "squamous_cell_carcinoma_in_situ"
|
| 213 |
+
if "scc" in raw_path or "squamous" in raw_path:
|
| 214 |
+
return "squamous_cell_carcinoma"
|
| 215 |
+
if "nevus" in raw_path or "melanocytic" in raw_path:
|
| 216 |
+
return "melanocytic_nevus"
|
| 217 |
+
if "seborrheic" in raw_path or "keratosis" in raw_path:
|
| 218 |
+
return "seborrheic_keratosis"
|
| 219 |
+
if "ak" in raw_path or "actinic" in raw_path:
|
| 220 |
+
return "actinic_keratosis"
|
| 221 |
+
if "benign" in raw_path:
|
| 222 |
+
return "benign_other"
|
| 223 |
+
return None
|
| 224 |
+
|
| 225 |
+
def _find_image_by_filename(self, filename: str) -> Path | None:
|
| 226 |
+
"""Find an image by its filename in the images directory."""
|
| 227 |
+
if not filename:
|
| 228 |
+
return None
|
| 229 |
+
# Try images/ subdir, then root, case-insensitive
|
| 230 |
+
search_dirs = [
|
| 231 |
+
self.data_dir / "images",
|
| 232 |
+
self.data_dir,
|
| 233 |
+
]
|
| 234 |
+
for d in search_dirs:
|
| 235 |
+
if not d.exists():
|
| 236 |
+
continue
|
| 237 |
+
p = d / filename
|
| 238 |
+
if p.exists():
|
| 239 |
+
return p
|
| 240 |
+
# Case-insensitive search
|
| 241 |
+
for ext_p in d.iterdir():
|
| 242 |
+
if ext_p.name.lower() == filename.lower():
|
| 243 |
+
return ext_p
|
| 244 |
+
return None
|
| 245 |
+
|
| 246 |
+
    def _build_case(
        self,
        record_id: str,
        records: list[dict],
        dx_counter: Counter,
    ) -> MedicalCase | None:
        """Convert a lesion's grouped records into a MedicalCase.

        Returns None when the lesion has no non-control record with a
        recognizable diagnosis, or when no channels could be built.
        """
        # Use the first non-control record that yields a diagnosis as the
        # source of patient/lesion metadata.
        primary = None
        for r in records:
            if str(r.get("midas_iscontrol", "no")).lower() != "yes":
                dx = self._get_diagnosis(r)
                if dx:
                    primary = r
                    break
        if primary is None:
            return None  # Skip control-only lesions

        diagnosis = self._get_diagnosis(primary)
        if not diagnosis:
            return None

        # ---- Build channels from all records in this lesion group ----
        all_channels = {}

        # One image channel per modality (imaging distance); first record
        # for a given modality wins.
        for r in records:
            if str(r.get("midas_iscontrol", "no")).lower() == "yes":
                continue
            distance = str(r.get("midas_distance", "")).strip().lower()
            channel_name = DISTANCE_TO_CHANNEL.get(distance)
            if not channel_name:
                continue
            if channel_name in all_channels:
                continue  # Already have this modality

            filename = r.get("midas_file_name", "")
            img_path = self._find_image_by_filename(filename)
            if img_path is None:
                continue

            # Unreadable/corrupt images are silently skipped (best-effort).
            try:
                img_b64 = encode_image_to_base64(img_path)
            except Exception:
                continue

            # Cost/tier/always_given come from the per-dataset channel config.
            ch_meta = config.get_channel_definition("midas", channel_name)
            descriptions = {
                "clinical_30cm": "Clinical photograph at 30cm distance",
                "clinical_15cm": "Clinical photograph at 15cm distance (closer view)",
                "dermoscopy": "Dermoscopic image showing subsurface skin structures",
            }
            all_channels[channel_name] = ChannelData(
                name=channel_name,
                channel_type="image",
                description=descriptions.get(channel_name, channel_name),
                value=img_b64,
                image_path=img_path,
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )

        # ---- Patient demographics text channel ----
        age = primary.get("midas_age", primary.get("age", ""))
        sex = primary.get("midas_gender", primary.get("sex", ""))
        fitz = primary.get("midas_fitzpatrick", primary.get("fitzpatrick", ""))
        ethnicity = primary.get("midas_ethnicity", "")
        race = primary.get("midas_race", "")
        # NOTE(review): the gate checks only age/sex/fitz — ethnicity/race
        # alone do not trigger a demographics channel; confirm intended.
        if any([age, sex, fitz]):
            demo_parts = []
            if age:
                demo_parts.append(f"Age: {age}")
            if sex:
                demo_parts.append(f"Sex: {sex}")
            if fitz:
                demo_parts.append(f"Fitzpatrick skin type: {fitz}")
            if ethnicity and str(ethnicity).lower() not in ("no", "none", ""):
                demo_parts.append(f"Ethnicity: {ethnicity}")
            if race and str(race).lower() not in ("no", "none", ""):
                demo_parts.append(f"Race: {race}")
            ch_meta = config.get_channel_definition("midas", "patient_demographics")
            all_channels["patient_demographics"] = ChannelData(
                name="patient_demographics",
                channel_type="text",
                description="Patient age, sex, and Fitzpatrick skin type",
                value="; ".join(demo_parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )

        # ---- Lesion metadata text channel (location + size) ----
        location = primary.get("midas_location", primary.get("location", ""))
        length = primary.get("length_(mm)", primary.get("length_mm", ""))
        width = primary.get("width_(mm)", primary.get("width_mm", ""))
        if any([location, length, width]):
            meta_parts = []
            if location:
                meta_parts.append(f"Anatomic location: {location}")
            if length:
                meta_parts.append(f"Lesion length: {length}mm")
            if width:
                meta_parts.append(f"Lesion width: {width}mm")
            ch_meta = config.get_channel_definition("midas", "lesion_metadata")
            all_channels["lesion_metadata"] = ChannelData(
                name="lesion_metadata",
                channel_type="text",
                description="Anatomic location, lesion length and width",
                value="; ".join(meta_parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )

        if not all_channels:
            return None

        # Split channels: always_given ones are shown up front, the rest
        # must be explicitly requested by the agent.
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }

        if not initial_channels and not requestable:
            return None

        # ---- Build candidate list (correct + distractors) ----
        case_id = f"midas_{record_id}"
        candidates = self._generate_candidates(diagnosis, dx_counter, case_id)

        # Safety net: ground truth must always be among the candidates.
        if diagnosis not in candidates:
            logger.warning(f"Ground truth '{diagnosis}' not in candidate list for {case_id}, forcing inclusion")
            candidates[0] = diagnosis
            # NOTE(review): this reuses the same seed as _generate_candidates,
            # so this shuffle replays that RNG's initial sequence — confirm.
            rng = _case_rng(case_id)
            rng.shuffle(candidates)

        return MedicalCase(
            case_id=case_id,
            dataset="midas",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=diagnosis,
            ground_truth_rank=candidates.index(diagnosis),
            metadata={
                "lesion_id": record_id,
                # Keep the raw record (stringified) for traceability; drop
                # any bulky image payload keys.
                "original_record": {k: str(v) for k, v in primary.items()
                                    if k not in ("image", "img")},
            },
        )
|
| 398 |
+
|
| 399 |
+
def _generate_candidates(self, correct_dx: str, dx_counter: Counter, case_id: str) -> list[str]:
|
| 400 |
+
"""
|
| 401 |
+
Generate N candidate diagnoses: 1 correct + (N-1) distractors.
|
| 402 |
+
|
| 403 |
+
Uses a per-case deterministic RNG for reproducibility across conditions.
|
| 404 |
+
Distractors are sampled to be clinically plausible:
|
| 405 |
+
- At least one from the same diagnostic group
|
| 406 |
+
- Others from different groups, weighted by dataset frequency
|
| 407 |
+
"""
|
| 408 |
+
n = self.n_candidates
|
| 409 |
+
rng = _case_rng(case_id)
|
| 410 |
+
|
| 411 |
+
# Find which group the correct dx belongs to
|
| 412 |
+
correct_group = None
|
| 413 |
+
for group_name, members in MIDAS_DIAGNOSIS_GROUPS.items():
|
| 414 |
+
if correct_dx in members:
|
| 415 |
+
correct_group = group_name
|
| 416 |
+
break
|
| 417 |
+
|
| 418 |
+
distractors = set()
|
| 419 |
+
|
| 420 |
+
# Add one same-group distractor if possible
|
| 421 |
+
if correct_group:
|
| 422 |
+
same_group = [d for d in MIDAS_DIAGNOSIS_GROUPS[correct_group] if d != correct_dx]
|
| 423 |
+
if same_group:
|
| 424 |
+
distractors.add(rng.choice(same_group))
|
| 425 |
+
|
| 426 |
+
# Fill rest from other groups, weighted by frequency
|
| 427 |
+
other_dx = [d for d in ALL_DIAGNOSES if d != correct_dx and d not in distractors]
|
| 428 |
+
weights = [dx_counter.get(d, 1) for d in other_dx]
|
| 429 |
+
total_w = sum(weights)
|
| 430 |
+
weights = [w / total_w for w in weights]
|
| 431 |
+
|
| 432 |
+
while len(distractors) < n - 1 and other_dx:
|
| 433 |
+
pick = rng.choices(other_dx, weights=weights, k=1)[0]
|
| 434 |
+
distractors.add(pick)
|
| 435 |
+
idx = other_dx.index(pick)
|
| 436 |
+
other_dx.pop(idx)
|
| 437 |
+
weights.pop(idx)
|
| 438 |
+
if weights:
|
| 439 |
+
total_w = sum(weights)
|
| 440 |
+
weights = [w / total_w for w in weights]
|
| 441 |
+
|
| 442 |
+
candidates = [correct_dx] + list(distractors)
|
| 443 |
+
rng.shuffle(candidates)
|
| 444 |
+
return candidates[:n]
|
datasets/nejm.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NEJM Image Challenge Dataset Loader.
|
| 3 |
+
|
| 4 |
+
Expects the cx0/nejm-image-challenge dataset structure:
|
| 5 |
+
nejm/
|
| 6 |
+
├── data.json (or nejm_data.json)
|
| 7 |
+
│ Each entry: {date, image_url, prompt (clinical vignette),
|
| 8 |
+
│ options [A..E], correct_answer, votes}
|
| 9 |
+
├── images/ (downloaded images, named by date YYYYMMDD.jpg)
|
| 10 |
+
└── parsed_vignettes.json (pre-parsed structured fields, optional)
|
| 11 |
+
|
| 12 |
+
The clinical vignette is decomposed into 5 requestable text channels
|
| 13 |
+
using LLM-based parsing (see scripts/parse_nejm_vignettes.py).
|
| 14 |
+
"""
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import random
|
| 18 |
+
import re
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
from .base import DatasetBase, MedicalCase, ChannelData
|
| 22 |
+
from api_client import encode_image_to_base64
|
| 23 |
+
import config
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# ---- Vignette parsing schema ----
# The five text channels each NEJM vignette is decomposed into. Order matters:
# it drives both prompt construction and channel iteration downstream.
VIGNETTE_FIELDS = [
    "demographics",
    "chief_complaint",
    "medical_history",
    "exam_findings",
    "investigations",
]

# Prompt template for LLM-based vignette decomposition. Filled via
# .format(vignette=...); the doubled {{...}} braces survive formatting so the
# model sees a literal JSON skeleton.
VIGNETTE_PARSE_PROMPT = """You are a medical data extraction system. Parse the following clinical \
vignette into exactly 5 structured fields. Extract ONLY information that is explicitly stated. \
If a field has no relevant information, write "Not mentioned."

FIELDS:
1. demographics: Patient age, sex, race/ethnicity if stated.
2. chief_complaint: The primary presenting symptom(s) and their duration.
3. medical_history: Past medical conditions, medications, surgical history, family history, social history (smoking, alcohol, etc.).
4. exam_findings: Physical examination findings, vital signs.
5. investigations: Laboratory results, imaging findings, test results (anything with numbers or test names).

CLINICAL VIGNETTE:
{vignette}

Respond in EXACTLY this JSON format (no markdown, no extra text):
{{"demographics": "...", "chief_complaint": "...", "medical_history": "...", "exam_findings": "...", "investigations": "..."}}"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class NEJMDataset(DatasetBase):
    """Loader for the NEJM Image Challenge dataset."""

    def __init__(
        self,
        data_dir: str | Path | None = None,
        split: str = "test",
        vlm_client=None,
        use_cached_parse: bool = True,
    ):
        # Fall back to the configured default dataset path when none is given.
        super().__init__(data_dir or config.DATASET_PATHS["nejm"], split)
        # Optional LLM client used to decompose vignettes into text channels.
        self.vlm_client = vlm_client
        # When True, reuse parsed_vignettes.json instead of re-parsing.
        self.use_cached_parse = use_cached_parse
        # Cache file for parsed vignettes, stored alongside the dataset.
        self._parsed_cache_path = self.data_dir / "parsed_vignettes.json"

    def get_name(self) -> str:
        """Return the dataset identifier used for config/channel lookups."""
        return "nejm"
|
| 71 |
+
|
| 72 |
+
def load(self) -> list[MedicalCase]:
|
| 73 |
+
logger.info(f"Loading NEJM dataset from {self.data_dir}")
|
| 74 |
+
|
| 75 |
+
# ---- Load raw data ----
|
| 76 |
+
raw_data = self._load_raw_data()
|
| 77 |
+
if not raw_data:
|
| 78 |
+
return []
|
| 79 |
+
logger.info(f"Found {len(raw_data)} NEJM cases")
|
| 80 |
+
|
| 81 |
+
# ---- Load or create parsed vignettes ----
|
| 82 |
+
parsed = self._load_or_parse_vignettes(raw_data)
|
| 83 |
+
|
| 84 |
+
# ---- Build cases ----
|
| 85 |
+
self.cases = []
|
| 86 |
+
for entry in raw_data:
|
| 87 |
+
case_id = entry.get("date", entry.get("id", "unknown"))
|
| 88 |
+
case = self._build_case(entry, parsed.get(case_id, {}))
|
| 89 |
+
if case is not None:
|
| 90 |
+
self.cases.append(case)
|
| 91 |
+
|
| 92 |
+
logger.info(f"Loaded {len(self.cases)} NEJM cases")
|
| 93 |
+
return self.cases
|
| 94 |
+
|
| 95 |
+
def _load_raw_data(self) -> list[dict]:
|
| 96 |
+
"""Load the raw NEJM dataset JSON."""
|
| 97 |
+
for name in ["data.json", "nejm_data.json", "nejm.json", "dataset.json"]:
|
| 98 |
+
p = self.data_dir / name
|
| 99 |
+
if p.exists():
|
| 100 |
+
with open(p, encoding="utf-8") as f:
|
| 101 |
+
data = json.load(f)
|
| 102 |
+
if isinstance(data, dict):
|
| 103 |
+
# Handle {date: entry} format
|
| 104 |
+
return [{"date": k, **v} if isinstance(v, dict) else v
|
| 105 |
+
for k, v in data.items()]
|
| 106 |
+
return data
|
| 107 |
+
# Try loading all JSON files
|
| 108 |
+
jsons = list(self.data_dir.glob("*.json"))
|
| 109 |
+
if jsons:
|
| 110 |
+
with open(jsons[0], encoding="utf-8") as f:
|
| 111 |
+
return json.load(f)
|
| 112 |
+
logger.error(f"No data file found in {self.data_dir}")
|
| 113 |
+
return []
|
| 114 |
+
|
| 115 |
+
def _load_or_parse_vignettes(self, raw_data: list[dict]) -> dict:
|
| 116 |
+
"""Load cached parsed vignettes or parse them with LLM."""
|
| 117 |
+
# Try cache first
|
| 118 |
+
if self.use_cached_parse and self._parsed_cache_path.exists():
|
| 119 |
+
logger.info(f"Loading cached vignette parses from {self._parsed_cache_path}")
|
| 120 |
+
with open(self._parsed_cache_path) as f:
|
| 121 |
+
return json.load(f)
|
| 122 |
+
|
| 123 |
+
# Parse with LLM if client is available
|
| 124 |
+
if self.vlm_client is not None:
|
| 125 |
+
logger.info("Parsing vignettes with LLM (this may take a while)...")
|
| 126 |
+
parsed = {}
|
| 127 |
+
for entry in raw_data:
|
| 128 |
+
case_id = entry.get("date", entry.get("id", "unknown"))
|
| 129 |
+
vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
|
| 130 |
+
if vignette:
|
| 131 |
+
parsed[case_id] = self._parse_vignette_with_llm(vignette)
|
| 132 |
+
# Cache results
|
| 133 |
+
with open(self._parsed_cache_path, "w") as f:
|
| 134 |
+
json.dump(parsed, f, indent=2)
|
| 135 |
+
logger.info(f"Cached {len(parsed)} parsed vignettes")
|
| 136 |
+
return parsed
|
| 137 |
+
|
| 138 |
+
# Fallback: rule-based parsing
|
| 139 |
+
logger.info("No LLM client available. Using rule-based vignette parsing (less accurate).")
|
| 140 |
+
parsed = {}
|
| 141 |
+
for entry in raw_data:
|
| 142 |
+
case_id = entry.get("date", entry.get("id", "unknown"))
|
| 143 |
+
vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
|
| 144 |
+
if vignette:
|
| 145 |
+
parsed[case_id] = self._parse_vignette_rules(vignette)
|
| 146 |
+
return parsed
|
| 147 |
+
|
| 148 |
+
def _parse_vignette_with_llm(self, vignette: str) -> dict:
|
| 149 |
+
"""Parse a single vignette using the LLM API."""
|
| 150 |
+
prompt = VIGNETTE_PARSE_PROMPT.format(vignette=vignette)
|
| 151 |
+
try:
|
| 152 |
+
response = self.vlm_client.call_with_retry(
|
| 153 |
+
system_prompt="You are a medical data extraction system. Respond only with valid JSON.",
|
| 154 |
+
user_text=prompt,
|
| 155 |
+
images=None,
|
| 156 |
+
temperature=0.0,
|
| 157 |
+
max_tokens=1024,
|
| 158 |
+
)
|
| 159 |
+
# Parse JSON from response
|
| 160 |
+
text = response.text.strip()
|
| 161 |
+
# Strip markdown code fences if present
|
| 162 |
+
text = re.sub(r"^```(?:json)?\s*", "", text)
|
| 163 |
+
text = re.sub(r"\s*```$", "", text)
|
| 164 |
+
parsed = json.loads(text)
|
| 165 |
+
# Validate expected fields
|
| 166 |
+
for field in VIGNETTE_FIELDS:
|
| 167 |
+
if field not in parsed:
|
| 168 |
+
parsed[field] = "Not mentioned."
|
| 169 |
+
return parsed
|
| 170 |
+
except Exception as e:
|
| 171 |
+
logger.warning(f"LLM vignette parsing failed: {e}. Falling back to rules.")
|
| 172 |
+
return self._parse_vignette_rules(vignette)
|
| 173 |
+
|
| 174 |
+
    def _parse_vignette_rules(self, vignette: str) -> dict:
        """
        Rule-based fallback for vignette parsing.

        Uses heuristic sentence classification: each sentence is assigned to
        the first matching field in a fixed priority order
        (investigations > exam > history > complaint); unmatched sentences
        default to chief_complaint. Order of the checks matters.
        """
        result = {f: "" for f in VIGNETTE_FIELDS}
        # Naive sentence split on terminal punctuation followed by whitespace.
        sentences = re.split(r'(?<=[.!?])\s+', vignette)

        # Patterns for classification
        # Demographics: "NN-year-old" or an explicit sex word.
        demo_pattern = re.compile(
            r'\b(\d{1,3})[-\s]year[-\s]old\b|'
            r'\b(male|female|man|woman|boy|girl)\b',
            re.IGNORECASE,
        )
        # Presentation verbs typical of the opening complaint sentence.
        complaint_pattern = re.compile(
            r'\bpresent(?:s|ed|ing)\b|\bcomplain(?:s|ed|ing)\b|\breport(?:s|ed|ing)\b|'
            r'\bseek(?:s|ing)\b|\badmitted\b',
            re.IGNORECASE,
        )
        # Past medical / social / family history markers.
        history_pattern = re.compile(
            r'\bhistory\b|\bprevious(?:ly)?\b|\bmedication\b|\btaking\b|\bdiagnosed\b|'
            r'\bsmok(?:es|ing|er)\b|\balcohol\b|\bfamily\b|\bsurgery\b',
            re.IGNORECASE,
        )
        # Physical exam vocabulary and vital signs.
        exam_pattern = re.compile(
            r'\bexamination\b|\bexam\b|\bpalpat(?:ion|ed)\b|\bauscult(?:ation|ed)\b|'
            r'\bvital\b|\bblood\s+pressure\b|\bheart\s+rate\b|\btemperature\b|'
            r'\bappears\b|\btender\b|\bswollen\b|\berythema\b',
            re.IGNORECASE,
        )
        # Lab analytes, imaging modalities, and numeric values with units.
        invest_pattern = re.compile(
            r'\b(?:hemoglobin|WBC|platelet|creatinine|BUN|glucose|sodium|potassium)\b|'
            r'\b(?:CT|MRI|X[-\s]?ray|ultrasound|ECG|EKG|biopsy)\b|'
            r'\b\d+\.?\d*\s*(?:mg|g|mL|mmol|mEq|U|IU|mmHg|\/dL|\/L)\b|'
            r'\blaboratory\b|\blab(?:s)?\b|\btest\b|\blevel\b|\bfinding\b',
            re.IGNORECASE,
        )

        for sent in sentences:
            sent = sent.strip()
            if not sent:
                continue

            # Demographics: only the first matching sentence is captured
            # (typically the opening sentence).
            if demo_pattern.search(sent) and not result["demographics"]:
                result["demographics"] = sent
                continue

            # Check each pattern; a sentence can match multiple, the first
            # match in this priority order wins.
            matched = False
            for field, pattern in [
                ("investigations", invest_pattern),
                ("exam_findings", exam_pattern),
                ("medical_history", history_pattern),
                ("chief_complaint", complaint_pattern),
            ]:
                if pattern.search(sent):
                    # Append with a space separator if the field already has text.
                    if result[field]:
                        result[field] += " " + sent
                    else:
                        result[field] = sent
                    matched = True
                    break

            # Unmatched sentences go to chief_complaint as the default bucket.
            if not matched:
                if result["chief_complaint"]:
                    result["chief_complaint"] += " " + sent
                else:
                    result["chief_complaint"] = sent

        # Replace empty fields with the explicit "Not mentioned." marker so
        # downstream code can distinguish them.
        for field in VIGNETTE_FIELDS:
            if not result[field].strip():
                result[field] = "Not mentioned."

        return result
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def _date_to_yyyymmdd(date_str: str) -> str | None:
|
| 254 |
+
"""Convert 'apr-01-2010' style date to '20100401' for image lookup."""
|
| 255 |
+
from datetime import datetime
|
| 256 |
+
for fmt in ("%b-%d-%Y", "%B-%d-%Y", "%Y-%m-%d", "%Y%m%d"):
|
| 257 |
+
try:
|
| 258 |
+
dt = datetime.strptime(date_str, fmt)
|
| 259 |
+
return dt.strftime("%Y%m%d")
|
| 260 |
+
except ValueError:
|
| 261 |
+
continue
|
| 262 |
+
return None
|
| 263 |
+
|
| 264 |
+
    def _build_case(self, entry: dict, parsed_vignette: dict) -> MedicalCase | None:
        """Convert a raw NEJM entry + parsed vignette into a MedicalCase.

        Returns None only when no usable channels can be built.
        """
        case_id = entry.get("date", entry.get("id", "unknown"))

        # ---- Find image ----
        img_b64 = None
        img_dir = self.data_dir / "images"
        # Candidate filenames: the raw case_id plus its YYYYMMDD conversion.
        name_candidates = [case_id]
        yyyymmdd = self._date_to_yyyymmdd(case_id)
        if yyyymmdd:
            name_candidates.append(yyyymmdd)

        if img_dir.exists():
            # Exact-name lookup across known extensions; encoding failures are
            # swallowed (best-effort image attachment).
            for name in name_candidates:
                for ext in [".jpg", ".jpeg", ".png"]:
                    p = img_dir / f"{name}{ext}"
                    if p.exists():
                        try:
                            img_b64 = encode_image_to_base64(p)
                        except Exception:
                            pass
                        break
                if img_b64 is not None:
                    break
            if img_b64 is None:
                # Glob fallback for any filename containing the candidate name.
                # NOTE(review): the `break` fires on the first candidate with
                # matches even if encoding failed — confirm intended.
                for name in name_candidates:
                    matches = list(img_dir.glob(f"*{name}*"))
                    if matches:
                        try:
                            img_b64 = encode_image_to_base64(matches[0])
                        except Exception:
                            pass
                        break

        # ---- Build all available channels, then split by config ----
        all_channels = {}
        if img_b64 is not None:
            # Cost/tier/always_given come from the per-dataset channel config.
            image_meta = config.get_channel_definition("nejm", "image")
            all_channels["image"] = ChannelData(
                name="image",
                channel_type="image",
                description="The primary diagnostic image",
                value=img_b64,
                cost=float(image_meta.get("cost", 0.0)),
                tier=image_meta.get("tier", "unknown"),
                always_given=bool(image_meta.get("always_given", False)),
            )

        field_descriptions = {
            "demographics": "Patient age, sex, and ethnicity if mentioned",
            "chief_complaint": "The presenting symptom(s) and their duration",
            "medical_history": "Past medical conditions, medications, family and social history",
            "exam_findings": "Physical examination results and observations",
            "investigations": "Laboratory values, prior imaging results, and test outcomes",
        }

        # One text channel per vignette field. Fields the parser marked as
        # "Not mentioned." still become channels, but with a placeholder
        # value, so the channel set is uniform across cases.
        for field in VIGNETTE_FIELDS:
            value = parsed_vignette.get(field, "Not mentioned.")
            field_meta = config.get_channel_definition("nejm", field)
            if value and value.strip() != "Not mentioned.":
                all_channels[field] = ChannelData(
                    name=field,
                    channel_type="text",
                    description=field_descriptions.get(field, field),
                    value=value,
                    cost=float(field_meta.get("cost", 0.0)),
                    tier=field_meta.get("tier", "unknown"),
                    always_given=bool(field_meta.get("always_given", False)),
                )
            else:
                all_channels[field] = ChannelData(
                    name=field,
                    channel_type="text",
                    description=field_descriptions.get(field, field),
                    value="No additional information available for this category.",
                    cost=float(field_meta.get("cost", 0.0)),
                    tier=field_meta.get("tier", "unknown"),
                    always_given=bool(field_meta.get("always_given", False)),
                )

        # Split channels: always_given ones are shown up front, the rest must
        # be explicitly requested.
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }

        if not initial_channels and not requestable:
            logger.debug(f"Skipping NEJM {case_id}: no usable channels found")
            return None

        # ---- Candidates: the 5 MCQ options ----
        options = entry.get("options", [])
        correct = entry.get("correct_answer", entry.get("answer", ""))

        # Handle flat option_A..option_E keys (cx0/nejm-image-challenge format)
        if not options:
            flat_options = {}
            for letter in "ABCDE":
                val = entry.get(f"option_{letter}", "")
                if val:
                    flat_options[letter] = val
            if flat_options:
                options = flat_options

        if isinstance(options, dict):
            # {A: "...", B: "...", ...} — render as "A. text" sorted by letter.
            candidates = [f"{k}. {v}" for k, v in sorted(options.items())]
            gt_label = None
            for k, v in sorted(options.items()):
                if k == correct:
                    gt_label = f"{k}. {v}"
                    break
            # Fallback: first option if the correct letter is missing.
            if gt_label is None:
                gt_label = candidates[0] if candidates else ""
        elif isinstance(options, list) and options:
            candidates = options
            if isinstance(correct, int):
                # Integer index answer; clamp to first option if out of range.
                gt_label = options[correct] if correct < len(options) else options[0]
            elif isinstance(correct, str) and len(correct) == 1:
                # Letter answer (A=0, B=1, ...)
                idx = ord(correct.upper()) - ord("A")
                gt_label = options[idx] if idx < len(options) else options[0]
            else:
                gt_label = correct
        else:
            # No options at all: degenerate single-candidate case.
            candidates = [correct] if correct else ["Unknown"]
            gt_label = correct

        # ---- Votes (physician response distribution) ----
        votes = entry.get("votes", {})
        # Handle flat vote keys (option_A_votes, etc.)
        if not votes:
            for letter in "ABCDE":
                val = entry.get(f"option_{letter}_votes", "")
                if val:
                    votes[letter] = val

        return MedicalCase(
            case_id=f"nejm_{case_id}",
            dataset="nejm",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=gt_label,
            ground_truth_rank=(candidates.index(gt_label) if gt_label in candidates else 0),
            metadata={
                "date": case_id,
                "votes": votes,
                "full_vignette": entry.get("question", entry.get("prompt", entry.get("vignette", ""))),
                "parsed_fields": parsed_vignette,
            },
        )
|
| 419 |
+
|
| 420 |
+
def get_human_difficulty(self, case: MedicalCase) -> float | None:
|
| 421 |
+
"""
|
| 422 |
+
Compute human difficulty score from physician vote distribution.
|
| 423 |
+
|
| 424 |
+
Returns: proportion of physicians who answered correctly (0-1),
|
| 425 |
+
or None if votes unavailable.
|
| 426 |
+
"""
|
| 427 |
+
votes = case.metadata.get("votes", {})
|
| 428 |
+
if not votes:
|
| 429 |
+
return None
|
| 430 |
+
correct_key = case.metadata.get("date", "")
|
| 431 |
+
# votes might be {A: 0.12, B: 0.65, ...} or {A: 120, B: 650, ...}
|
| 432 |
+
total = sum(float(v) for v in votes.values())
|
| 433 |
+
if total == 0:
|
| 434 |
+
return None
|
| 435 |
+
# Find the correct answer key
|
| 436 |
+
gt = case.ground_truth
|
| 437 |
+
for key, val in votes.items():
|
| 438 |
+
if key in gt or gt.startswith(key):
|
| 439 |
+
return float(val) / total if total > 1 else float(val)
|
| 440 |
+
return None
|
datasets/olives.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OLIVES Dataset Loader.
|
| 3 |
+
|
| 4 |
+
Adapted for the actual Zenodo OLIVES dataset structure:
|
| 5 |
+
data/
|
| 6 |
+
├── OLIVES/OLIVES/
|
| 7 |
+
│ ├── Prime_FULL/Prime_FULL/ (DR patients — OCT B-scans)
|
| 8 |
+
│ │ └── <patient_id>/<visit>/<eye>/*.png
|
| 9 |
+
│ └── TREX_DME/TREX DME/ (DME patients — OCT B-scans)
|
| 10 |
+
│ └── <arm>/<patient_id>/<visit>/<eye>/*.tif
|
| 11 |
+
└── OLIVES_Dataset_Labels/OLIVES_Dataset_Labels/
|
| 12 |
+
└── full_labels/Biomarker_Clinical_Data_Images.csv
|
| 13 |
+
|
| 14 |
+
Task: Biomarker profile ranking.
|
| 15 |
+
- Given an OCT B-scan, rank candidate biomarker profiles
|
| 16 |
+
- Each profile is a subset of the 16 annotated biomarkers
|
| 17 |
+
- Correct profile = actual biomarker vector for this eye
|
| 18 |
+
- Distractors = profiles from other eyes
|
| 19 |
+
|
| 20 |
+
Channels:
|
| 21 |
+
- Initial: single OCT B-scan (middle slice)
|
| 22 |
+
- Requestable: additional OCT slices, clinical measurements (BCVA/CST),
|
| 23 |
+
biomarker hints (fundus-visible subset), treatment history
|
| 24 |
+
"""
|
| 25 |
+
import csv
|
| 26 |
+
import hashlib
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import random
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from collections import defaultdict
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
from .base import DatasetBase, MedicalCase, ChannelData
|
| 36 |
+
from api_client import encode_image_to_base64
|
| 37 |
+
import config
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# Maps the 16 biomarker column headers, exactly as they appear in the OLIVES
# labels CSV, to the short snake_case canonical names used throughout this
# module (profiles, hints, distance computations).
OLIVES_CSV_BIOMARKERS = {
    "Fluid (IRF)": "fluid_irf",
    "Fluid (SRF)": "fluid_srf",
    "DRT/ME": "drt_me",
    "SHRM": "shrm",
    "Preretinal tissue/hemorrhage": "preretinal_tissue",
    "Vitreous debris": "vitreous_debris",
    "DRIL": "dril",
    "Disruption of EZ": "ez_disruption",
    "IR hemorrhages": "hemorrhage",
    "IR HRF": "ir_hrf",
    "Disruption of RPE": "rpe_disruption",
    "PED (serous)": "ped_serous",
    "Atrophy / thinning of retinal layers": "atrophy",
    "VMT": "vmt",
    "Partially attached vitreous face": "partial_vitreous",
    "Fully attached vitreous face": "full_vitreous",
}

# Canonical biomarker names for profiles, sorted to give every profile
# vector a stable, deterministic ordering.
OLIVES_BIOMARKERS = sorted(OLIVES_CSV_BIOMARKERS.values())
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def biomarker_vector_to_profile_string(vector: dict[str, bool]) -> str:
    """Render a biomarker presence vector as a human-readable profile string."""
    present_keys = sorted(key for key, flag in vector.items() if flag)
    if not present_keys:
        return "No biomarkers detected"
    pretty = [key.replace("_", " ").title() for key in present_keys]
    return "Present biomarkers: " + ", ".join(pretty)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def compute_profile_distance(profile_a: dict, profile_b: dict) -> int:
    """Hamming distance between two biomarker profiles over the canonical set.

    Missing keys are treated as absent (False).
    """
    return sum(
        1
        for key in OLIVES_BIOMARKERS
        if profile_a.get(key, False) != profile_b.get(key, False)
    )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _case_rng(case_id: str) -> random.Random:
|
| 86 |
+
seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
|
| 87 |
+
return random.Random(seed)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class OLIVESDataset(DatasetBase):
|
| 91 |
+
"""Loader for OLIVES ophthalmology dataset."""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
data_dir: str | Path = None,
|
| 96 |
+
split: str = "test",
|
| 97 |
+
n_candidates: int = 5,
|
| 98 |
+
n_oct_samples: int = 3,
|
| 99 |
+
):
|
| 100 |
+
super().__init__(data_dir or config.DATASET_PATHS["olives"], split)
|
| 101 |
+
self.n_candidates = n_candidates
|
| 102 |
+
self.n_oct_samples = n_oct_samples
|
| 103 |
+
|
| 104 |
+
    def get_name(self) -> str:
        """Return the canonical dataset identifier ("olives")."""
        return "olives"
|
| 106 |
+
|
| 107 |
+
    def load(self) -> list[MedicalCase]:
        """Load all OLIVES cases.

        Pipeline: locate the biomarker CSV and the image root, group CSV rows
        per eye (patient + laterality), derive each eye's biomarker profile,
        then build one MedicalCase per eye. Returns [] if the CSV or image
        directory cannot be found.
        """
        logger.info(f"Loading OLIVES dataset from {self.data_dir}")

        # ---- Find the CSV ----
        csv_path = self._find_csv()
        if csv_path is None:
            logger.error("No biomarker CSV found")
            return []

        # ---- Load records ----
        # utf-8-sig tolerates a BOM at the start of the exported CSV.
        with open(csv_path, newline="", encoding="utf-8-sig") as f:
            rows = list(csv.DictReader(f))
        logger.info(f"Found {len(rows)} records in {csv_path.name}")

        # ---- Find the image root ----
        image_root = self._find_image_root()
        if image_root is None:
            logger.error("No image directory found")
            return []
        logger.info(f"Image root: {image_root}")

        # ---- Group by eye ----
        # The path column encodes Trial/Arm/Folder/Visit/Eye/Image; the eye
        # laterality (OD/OS) is component index 4 when the path is complete.
        eye_groups = defaultdict(list)
        for r in rows:
            pid = r.get("Patient_ID", "")
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) >= 5:
                eye = parts[4]  # OD or OS
            else:
                eye = r.get("Eye_ID", "unknown")
            eye_key = f"{pid}_{eye}"
            # Cache derived fields on the row for downstream helpers.
            r["_eye_key"] = eye_key
            r["_path_parts"] = parts
            eye_groups[eye_key].append(r)

        logger.info(f"Found {len(eye_groups)} unique eyes")

        # ---- Build biomarker profiles ----
        # Uses the last record per eye as the ground-truth profile
        # (assumes CSV rows are in chronological order — TODO confirm).
        all_profiles = {}
        for eye_key, records in eye_groups.items():
            latest = records[-1]
            all_profiles[eye_key] = self._extract_biomarker_vector(latest)

        # ---- Build cases ----
        # _build_case returns None for eyes with no resolvable images.
        self.cases = []
        for eye_key, records in eye_groups.items():
            case = self._build_case(
                eye_key, records, all_profiles, image_root
            )
            if case is not None:
                self.cases.append(case)

        logger.info(f"Loaded {len(self.cases)} OLIVES cases")
        return self.cases
|
| 164 |
+
|
| 165 |
+
def _find_csv(self) -> Path | None:
|
| 166 |
+
"""Find the biomarker CSV in various locations."""
|
| 167 |
+
search_paths = [
|
| 168 |
+
self.data_dir / "Biomarker_Clinical_Data_Images.csv",
|
| 169 |
+
self.data_dir / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
|
| 170 |
+
self.data_dir.parent / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
|
| 171 |
+
]
|
| 172 |
+
for p in search_paths:
|
| 173 |
+
if p.exists():
|
| 174 |
+
return p
|
| 175 |
+
# Glob fallback
|
| 176 |
+
csvs = list(self.data_dir.rglob("Biomarker*Clinical*.csv"))
|
| 177 |
+
if csvs:
|
| 178 |
+
return csvs[0]
|
| 179 |
+
# Check parent
|
| 180 |
+
csvs = list(self.data_dir.parent.rglob("Biomarker*Clinical*.csv"))
|
| 181 |
+
if csvs:
|
| 182 |
+
return csvs[0]
|
| 183 |
+
return None
|
| 184 |
+
|
| 185 |
+
def _find_image_root(self) -> Path | None:
|
| 186 |
+
"""Find the root directory containing Prime_FULL and TREX_DME."""
|
| 187 |
+
search = [
|
| 188 |
+
self.data_dir / "OLIVES",
|
| 189 |
+
self.data_dir / "OLIVES" / "OLIVES",
|
| 190 |
+
self.data_dir,
|
| 191 |
+
]
|
| 192 |
+
for d in search:
|
| 193 |
+
if (d / "Prime_FULL").exists() or (d / "TREX_DME").exists():
|
| 194 |
+
return d
|
| 195 |
+
# Search deeper
|
| 196 |
+
for p in self.data_dir.rglob("Prime_FULL"):
|
| 197 |
+
return p.parent
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
def _extract_biomarker_vector(self, record: dict) -> dict[str, bool]:
|
| 201 |
+
"""Extract biomarker vector from a CSV row."""
|
| 202 |
+
vector = {}
|
| 203 |
+
for csv_col, canonical_name in OLIVES_CSV_BIOMARKERS.items():
|
| 204 |
+
val = record.get(csv_col, "0")
|
| 205 |
+
if isinstance(val, str):
|
| 206 |
+
vector[canonical_name] = val.strip() == "1"
|
| 207 |
+
else:
|
| 208 |
+
vector[canonical_name] = bool(int(float(val or 0)))
|
| 209 |
+
return vector
|
| 210 |
+
|
| 211 |
+
    def _find_oct_images(
        self, records: list[dict], image_root: Path, n: int = 3
    ) -> list[Path]:
        """Find up to *n* evenly spaced OCT B-scan images for an eye.

        Resolves each record's CSV path column against the several on-disk
        layouts (Prime vs TREX arms, single vs doubled directory names) and
        returns the first directory that yields images. Returns [] when no
        directory resolves.
        """
        # Try to locate images from the path in the CSV
        for r in records:
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) < 5:
                continue

            # Construct search directory (without the image filename)
            # Path format: /Trial/Arm/Patient/Visit/Eye/Image
            trial = parts[0]
            remaining = "/".join(parts[1:-1])

            search_dirs = [
                image_root / trial / remaining,
                image_root / parts[0].replace(" ", "_") / remaining,
            ]

            # For Prime: Prime_FULL/Prime_FULL/Patient/Visit/Eye/
            if "Prime" in trial or "prime" in trial:
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "Prime_FULL" / "Prime_FULL" / pid / visit / eye,
                    image_root / "Prime_FULL" / pid / visit / eye,
                ])

            # For TREX: TREX_DME/TREX DME/Arm/Patient/Visit/Eye/
            if "TREX" in trial:
                arm = parts[1] if len(parts) > 1 else ""
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "TREX_DME" / "TREX DME" / arm / pid / visit / eye,
                    image_root / "TREX_DME" / trial / arm / pid / visit / eye,
                ])

            # First existing directory with images wins.
            for d in search_dirs:
                if not d.exists():
                    continue
                images = sorted(
                    list(d.glob("*.png")) + list(d.glob("*.tif"))
                    + list(d.glob("*.jpg"))
                )
                if images:
                    # Sample N evenly spaced scans
                    if len(images) <= n:
                        return images
                    indices = np.linspace(
                        0, len(images) - 1, n, dtype=int
                    )
                    return [images[i] for i in indices]

        return []
|
| 272 |
+
|
| 273 |
+
    def _build_case(
        self,
        eye_key: str,
        records: list[dict],
        all_profiles: dict[str, dict[str, bool]],
        image_root: Path,
    ) -> MedicalCase | None:
        """Convert an eye's records into a MedicalCase.

        Builds all information channels (OCT scans, clinical measurements,
        partial biomarker hints, disease context), splits them into initial
        vs requestable by the config's `always_given` flag, and generates
        the candidate biomarker profiles to rank. Returns None when no OCT
        image can be found or the primary image fails to encode.
        """
        # Uses the last record as the most recent visit
        # (assumes CSV rows are chronological — TODO confirm).
        latest = records[-1]

        # ---- Find OCT images ----
        # Request one extra so the middle slice can serve as the initial
        # channel and the rest as "additional_oct".
        oct_images = self._find_oct_images(records, image_root, self.n_oct_samples + 1)
        if not oct_images:
            logger.debug(f"Skipping eye {eye_key}: no images found")
            return None

        # Build all available channels, then split by config
        all_channels = {}

        # Use middle scan as canonical first-line OCT, rest as optional extras
        mid_idx = len(oct_images) // 2
        initial_image = oct_images[mid_idx]
        additional_images = [
            img for i, img in enumerate(oct_images) if i != mid_idx
        ]

        try:
            initial_b64 = encode_image_to_base64(initial_image)
        except Exception as e:
            # The primary scan is mandatory — skip the eye if it won't encode.
            logger.debug(f"Skipping eye {eye_key}: encode failed: {e}")
            return None

        # Primary OCT channel; cost/tier/always_given come from config.
        oct_meta = config.get_channel_definition("olives", "oct_scan")
        all_channels["oct_scan"] = ChannelData(
            name="oct_scan",
            channel_type="image",
            description="OCT B-scan showing retinal cross-section",
            value=initial_b64,
            image_path=initial_image,
            cost=float(oct_meta.get("cost", 0.0)),
            tier=oct_meta.get("tier", "unknown"),
            always_given=bool(oct_meta.get("always_given", False)),
        )

        # Additional OCT slices
        if additional_images:
            try:
                add_b64 = [encode_image_to_base64(p) for p in additional_images]
                ch_meta = config.get_channel_definition("olives", "additional_oct")
                all_channels["additional_oct"] = ChannelData(
                    name="additional_oct",
                    channel_type="image",
                    description="Additional OCT B-scans from different retinal locations",
                    value=add_b64,
                    cost=float(ch_meta.get("cost", 0.0)),
                    tier=ch_meta.get("tier", "unknown"),
                    always_given=bool(ch_meta.get("always_given", False)),
                )
            except Exception:
                # Extra slices are best-effort; the case survives without them.
                pass

        # Clinical measurements (BCVA and CST)
        bcva = latest.get("BCVA", "")
        cst = latest.get("CST", "")
        if bcva or cst:
            parts = []
            if bcva:
                parts.append(f"BCVA (logMAR): {bcva}")
            if cst:
                parts.append(f"CST: {cst} um")
            ch_meta = config.get_channel_definition("olives", "clinical_measurements")
            all_channels["clinical_measurements"] = ChannelData(
                name="clinical_measurements",
                channel_type="text",
                description="Visual acuity (BCVA) and retinal thickness (CST)",
                value="; ".join(parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )

        # Biomarker hints (subset — only the most obvious ones)
        biomarker_vec = all_profiles[eye_key]
        obvious_markers = ["fluid_irf", "fluid_srf", "hemorrhage", "drt_me"]
        hint_parts = []
        for m in obvious_markers:
            if m in biomarker_vec:
                status = "Present" if biomarker_vec[m] else "Not detected"
                hint_parts.append(
                    f"{m.replace('_', ' ').title()}: {status}"
                )
        if hint_parts:
            ch_meta = config.get_channel_definition("olives", "biomarker_hints")
            all_channels["biomarker_hints"] = ChannelData(
                name="biomarker_hints",
                channel_type="text",
                description="Partial biomarker annotations (subset)",
                value="; ".join(hint_parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )

        # Disease type hint: TREX arm implies DME patients, otherwise DR.
        path_str = latest.get(
            "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
        )
        disease = "DME" if "TREX" in path_str else "DR"
        ch_meta = config.get_channel_definition("olives", "disease_context")
        all_channels["disease_context"] = ChannelData(
            name="disease_context",
            channel_type="text",
            description="Disease type and treatment context",
            value=f"Disease: {disease}",
            cost=float(ch_meta.get("cost", 0.0)),
            tier=ch_meta.get("tier", "unknown"),
            always_given=bool(ch_meta.get("always_given", False)),
        )

        # Split channels: always_given ones are shown upfront, the rest
        # must be explicitly requested by the agent.
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }

        # ---- Build candidates ----
        case_id = f"olives_{eye_key}"
        correct_profile = biomarker_vector_to_profile_string(biomarker_vec)
        candidates = self._generate_profile_candidates(
            eye_key, biomarker_vec, all_profiles, case_id
        )

        # Safety net: force the correct profile into the candidate list and
        # reshuffle deterministically so its position is not predictable.
        if correct_profile not in candidates:
            candidates[0] = correct_profile
            rng = _case_rng(case_id)
            rng.shuffle(candidates)

        return MedicalCase(
            case_id=case_id,
            dataset="olives",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=correct_profile,
            ground_truth_rank=(
                candidates.index(correct_profile)
                if correct_profile in candidates else 0
            ),
            metadata={
                "eye_id": eye_key,
                "disease": disease,
                "biomarker_vector": biomarker_vec,
            },
        )
|
| 428 |
+
|
| 429 |
+
def _generate_profile_candidates(
|
| 430 |
+
self,
|
| 431 |
+
eye_id: str,
|
| 432 |
+
correct_vec: dict[str, bool],
|
| 433 |
+
all_profiles: dict[str, dict[str, bool]],
|
| 434 |
+
case_id: str,
|
| 435 |
+
) -> list[str]:
|
| 436 |
+
"""Generate biomarker profile candidates."""
|
| 437 |
+
n = self.n_candidates
|
| 438 |
+
rng = _case_rng(case_id)
|
| 439 |
+
correct_str = biomarker_vector_to_profile_string(correct_vec)
|
| 440 |
+
|
| 441 |
+
scored = []
|
| 442 |
+
for eid, vec in all_profiles.items():
|
| 443 |
+
if eid == eye_id:
|
| 444 |
+
continue
|
| 445 |
+
dist = compute_profile_distance(correct_vec, vec)
|
| 446 |
+
profile_str = biomarker_vector_to_profile_string(vec)
|
| 447 |
+
if profile_str != correct_str:
|
| 448 |
+
scored.append((dist, profile_str, vec))
|
| 449 |
+
|
| 450 |
+
scored.sort(key=lambda x: x[0])
|
| 451 |
+
|
| 452 |
+
distractors = []
|
| 453 |
+
if scored:
|
| 454 |
+
distractors.append(scored[0][1]) # Hard distractor
|
| 455 |
+
if len(scored) > 1:
|
| 456 |
+
distractors.append(scored[-1][1]) # Easy distractor
|
| 457 |
+
mid_pool = scored[len(scored) // 4: 3 * len(scored) // 4]
|
| 458 |
+
rng.shuffle(mid_pool)
|
| 459 |
+
for dist, prof, vec in mid_pool:
|
| 460 |
+
if prof not in distractors and len(distractors) < n - 1:
|
| 461 |
+
distractors.append(prof)
|
| 462 |
+
|
| 463 |
+
while len(distractors) < n - 1 and scored:
|
| 464 |
+
pick = rng.choice(scored)
|
| 465 |
+
if pick[1] not in distractors:
|
| 466 |
+
distractors.append(pick[1])
|
| 467 |
+
|
| 468 |
+
candidates = [correct_str] + distractors[:n - 1]
|
| 469 |
+
rng.shuffle(candidates)
|
| 470 |
+
return candidates
|
demo_cases/chest_xray_ipf.png
ADDED
|
Git LFS Details
|
demo_cases/ct_pulmonary_pe.png
ADDED
|
Git LFS Details
|
demo_cases/fundus_dme.png
ADDED
|
Git LFS Details
|
demo_cases/oct_bscan_dme.png
ADDED
|
Git LFS Details
|
demo_cases/skin_lesion_dermoscopy.png
ADDED
|
Git LFS Details
|
evaluation/__init__.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation Metrics for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Unified metrics across all three datasets:
|
| 5 |
+
- MRR (Mean Reciprocal Rank)
|
| 6 |
+
- Acquisition Efficiency (normalized improvement)
|
| 7 |
+
- Top-1 Accuracy
|
| 8 |
+
- Acquisition Precision
|
| 9 |
+
- Uncertainty Calibration (ECE-style)
|
| 10 |
+
- Information-Theoretic Metrics (entropy, IG, VoI)
|
| 11 |
+
- Bootstrap confidence intervals
|
| 12 |
+
"""
|
| 13 |
+
import logging
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
from scipy import stats
|
| 18 |
+
|
| 19 |
+
from agent import AgentResult
|
| 20 |
+
from datasets.base import MedicalCase
|
| 21 |
+
from information_gain import BeliefTrajectory, compute_information_metrics
|
| 22 |
+
import config
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class CaseMetrics:
    """Metrics for a single case."""
    case_id: str
    dataset: str
    top1_correct: bool = False  # correct answer was ranked first
    reciprocal_rank: float = 0.0  # 1/rank of the ground truth (0 when ranking empty)
    ground_truth_rank: int = -1  # 1-indexed rank of correct answer (-1 = not found)
    n_acquired: int = 0  # number of channels the agent acquired
    acquired_channels: list[str] = field(default_factory=list)  # channel names acquired
    committed_early: bool = False  # presumably committed before budget exhausted — confirm against agent
    top1_confidence: float = 0.0  # Confidence of the top-ranked diagnosis
    acquisition_cost: float = 0.0  # cost attributed to acquired channels
    total_case_cost: float = 0.0  # total cost for the case — breakdown defined by AgentResult, confirm
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class DatasetMetrics:
    """Aggregated metrics for a dataset."""
    dataset: str
    n_cases: int
    top1_accuracy: float  # fraction of cases with the correct answer ranked first
    mrr: float  # Mean Reciprocal Rank
    top1_accuracy_ci: tuple = (0.0, 0.0)  # 95% CI (bootstrap)
    mrr_ci: tuple = (0.0, 0.0)  # 95% CI (bootstrap)
    mean_channels_acquired: float = 0.0  # average number of acquisitions per case
    early_commit_rate: float = 0.0  # fraction of cases flagged committed_early
    per_channel_request_rate: dict = field(default_factory=dict)  # channel -> request fraction
    mean_acquisition_cost: float = 0.0
    mean_total_case_cost: float = 0.0
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def compute_reciprocal_rank(
    ranking: list[dict],
    ground_truth: str,
    candidates: list[str],
) -> float:
    """
    Compute the reciprocal rank of the ground truth in the agent's ranking.

    Matching is deliberately fuzzy: an entry matches when its name and the
    ground truth contain each other (case-insensitive), or when both map
    onto the same candidate string.

    Args:
        ranking: Entries of the form {"name": ..., "rank": ...}.
        ground_truth: The correct answer label.
        candidates: Candidate labels shown to the agent.

    Returns:
        1/rank of the first matching entry; 1/(N+1) when the ground truth
        is absent from an N-entry ranking; 0.0 for an empty ranking.
    """
    if not ranking:
        return 0.0

    gt_lower = ground_truth.lower().strip()

    for entry in ranking:
        name = entry.get("name", "").lower().strip()
        rank = entry.get("rank", 999)

        # Flexible matching: check substring containment both ways
        if gt_lower in name or name in gt_lower:
            return 1.0 / rank

        # Indirect match via a shared candidate string.
        for candidate in candidates:
            cand_lower = candidate.lower()  # hoisted: previously lowered up to 3x
            if gt_lower in cand_lower and (name in cand_lower or cand_lower in name):
                return 1.0 / rank

    # Ground truth not found in ranking — treat it as ranked just past the
    # end. (The former trailing "if ranking else 0.0" was dead: ranking is
    # guaranteed non-empty here.)
    return 1.0 / (len(ranking) + 1)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def evaluate_single_case(
    result: AgentResult,
    case: MedicalCase,
) -> CaseMetrics:
    """Score one agent result against the case's ground truth."""
    ranking = result.final_ranking
    truth = case.ground_truth

    rr = compute_reciprocal_rank(ranking, truth, case.candidates)

    truth_norm = truth.lower().strip()

    def _matches(entry: dict) -> bool:
        label = entry.get("name", "").lower().strip()
        return truth_norm in label or label in truth_norm

    # Rank the agent assigned to the ground truth (-1 when absent).
    matched = next((e for e in ranking if _matches(e)), None)
    gt_rank = matched.get("rank", -1) if matched is not None else -1

    return CaseMetrics(
        case_id=result.case_id,
        dataset=result.dataset,
        top1_correct=(rr == 1.0),  # RR of 1 means the truth was ranked first
        reciprocal_rank=rr,
        ground_truth_rank=gt_rank,
        n_acquired=len(result.acquired_channels),
        acquired_channels=result.acquired_channels,
        committed_early=result.committed_early,
        top1_confidence=ranking[0]["confidence"] if ranking else 0.0,
        acquisition_cost=result.acquisition_cost,
        total_case_cost=result.total_case_cost,
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def aggregate_metrics(
    case_metrics: list[CaseMetrics],
    dataset_name: str,
    n_bootstrap: int | None = None,
) -> DatasetMetrics:
    """Aggregate per-case metrics into dataset-level stats with bootstrap CIs.

    Args:
        case_metrics: Per-case results to aggregate.
        dataset_name: Label stored on the returned DatasetMetrics.
        n_bootstrap: Bootstrap resamples for CIs; defaults to config.N_BOOTSTRAP.

    Returns:
        DatasetMetrics with means, 95% bootstrap CIs for accuracy/MRR,
        per-channel request rates, and cost averages. An empty input yields
        an all-zero DatasetMetrics.
    """
    if n_bootstrap is None:
        n_bootstrap = config.N_BOOTSTRAP

    n = len(case_metrics)
    if n == 0:
        return DatasetMetrics(dataset=dataset_name, n_cases=0, top1_accuracy=0, mrr=0)

    accuracies = np.array([int(cm.top1_correct) for cm in case_metrics])
    rrs = np.array([cm.reciprocal_rank for cm in case_metrics])

    top1_acc = float(np.mean(accuracies))
    mrr = float(np.mean(rrs))

    # Bootstrap CIs (helper defined elsewhere in this module)
    acc_ci = _bootstrap_ci(accuracies, n_bootstrap)
    mrr_ci = _bootstrap_ci(rrs, n_bootstrap)

    # Channel request rates: fraction of cases in which each channel was acquired
    channel_counts: dict[str, int] = {}
    for cm in case_metrics:
        for ch in cm.acquired_channels:
            channel_counts[ch] = channel_counts.get(ch, 0) + 1
    channel_rates = {ch: count / n for ch, count in channel_counts.items()}

    return DatasetMetrics(
        dataset=dataset_name,
        n_cases=n,
        top1_accuracy=top1_acc,
        mrr=mrr,
        top1_accuracy_ci=acc_ci,
        mrr_ci=mrr_ci,
        mean_channels_acquired=float(np.mean([cm.n_acquired for cm in case_metrics])),
        early_commit_rate=float(np.mean([int(cm.committed_early) for cm in case_metrics])),
        per_channel_request_rate=channel_rates,
        mean_acquisition_cost=float(np.mean([cm.acquisition_cost for cm in case_metrics])),
        mean_total_case_cost=float(np.mean([cm.total_case_cost for cm in case_metrics])),
    )
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def compute_acquisition_efficiency(
    mrr_at_k: float,
    mrr_passive: float,
    mrr_oracle: float,
) -> float:
    """
    Normalized Acquisition Efficiency.

    AE(K) = (MRR_K - MRR_passive) / (MRR_oracle - MRR_passive)

    Returns 0 if oracle = passive (no room for improvement),
    can exceed 1 if active outperforms oracle (shouldn't happen normally).
    """
    headroom = mrr_oracle - mrr_passive
    if abs(headroom) < 1e-8:
        # Oracle offers no headroom over passive: efficiency is undefined, report 0.
        return 0.0
    return (mrr_at_k - mrr_passive) / headroom
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def compute_acquisition_precision(
    active_results: list[AgentResult],
    passive_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Acquisition Precision: when the agent requests info, does the diagnosis change?

    Two sub-metrics:
    - request_change_rate: fraction of acquisitions that changed the top-1 diagnosis
    - change_correctness: among diagnosis changes, fraction that were improvements
    """
    assert len(active_results) == len(passive_results) == len(cases)

    n_with_acquisition = 0
    n_changed = 0
    n_improved = 0

    for act, pas, case in zip(active_results, passive_results, cases):
        # Only cases where the agent actually acquired something count.
        if not act.acquired_channels:
            continue
        n_with_acquisition += 1

        top1_passive = _get_top1_name(pas.final_ranking)
        top1_active = _get_top1_name(act.final_ranking)
        if top1_active == top1_passive:
            continue
        n_changed += 1

        # Improvement = the new top-1 matches ground truth (loose substring match).
        gt = case.ground_truth.lower().strip()
        active_lower = top1_active.lower()
        if gt in active_lower or active_lower in gt:
            n_improved += 1

    return {
        "total_cases_with_acquisition": n_with_acquisition,
        "request_change_rate": (
            n_changed / n_with_acquisition if n_with_acquisition > 0 else 0
        ),
        "change_correctness": (
            n_improved / n_changed if n_changed > 0 else 0
        ),
    }
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def compute_prompt_agreement(
    results_by_variant: dict[str, list[AgentResult]],
) -> dict:
    """
    Prompt sensitivity analysis: measure agreement across prompt variants.

    Args:
        results_by_variant: Mapping of prompt-variant name to its AgentResults.

    Returns:
        - top1_agreement: fraction of cases where all variants agree on top-1
        - acquisition_agreement: fraction of cases where all variants request
          the same first channel
        - n_cases_compared: number of cases present in every variant
          (only present when there are >= 2 variants)
    """
    variants = list(results_by_variant.keys())
    if len(variants) < 2:
        # With fewer than two variants agreement is trivially perfect.
        return {"top1_agreement": 1.0, "acquisition_agreement": 1.0}

    # Group results by case_id so variants can be compared on the same case.
    # (Previously an unused `case_ids` set was also built here; removed.)
    by_case: dict[str, dict[str, AgentResult]] = {}
    for variant, results in results_by_variant.items():
        for r in results:
            by_case.setdefault(r.case_id, {})[variant] = r

    top1_agree_count = 0
    acq_agree_count = 0
    total = 0

    for variant_results in by_case.values():
        if len(variant_results) < len(variants):
            continue  # Skip cases not present in all variants
        total += 1

        # Top-1 agreement (case-insensitive).
        top1s = {
            _get_top1_name(vr.final_ranking).lower()
            for vr in variant_results.values()
        }
        if len(top1s) == 1:
            top1_agree_count += 1

        # First-acquisition agreement; "_committed_" marks "no acquisition at all".
        first_acqs = {
            vr.acquired_channels[0] if vr.acquired_channels else "_committed_"
            for vr in variant_results.values()
        }
        if len(first_acqs) == 1:
            acq_agree_count += 1

    return {
        "top1_agreement": top1_agree_count / total if total > 0 else 0,
        "acquisition_agreement": acq_agree_count / total if total > 0 else 0,
        "n_cases_compared": total,
    }
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def compute_regret_analysis(
    active_results: list[AgentResult],
    oracle_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Regret Analysis: when the agent gets a case wrong, could a different
    acquisition strategy have saved it?

    For each case where active got it wrong:
    1. Did the oracle get it right? (recoverable error)
    2. Which channels were available but not requested? (missed channels)
    3. Among recoverable errors, which missing channels correlate most
       with oracle success? (high-regret channels)

    Returns a rich dict with per-case traces and aggregate statistics.
    """
    # The three lists must be case-aligned (same order, same length).
    assert len(active_results) == len(oracle_results) == len(cases)

    per_case_regret = []
    n_active_wrong = 0
    n_oracle_right_when_active_wrong = 0  # recoverable
    n_both_wrong = 0  # unrecoverable — VLM reasoning bottleneck
    missed_channel_counts: dict[str, int] = {}  # channels not requested in recoverable cases
    missed_channel_total: dict[str, int] = {}  # total times a channel was missed (all wrong)

    for active, oracle, case in zip(active_results, oracle_results, cases):
        # "Correct" is defined as reciprocal rank == 1.0, i.e. top-1 correct.
        active_rr = compute_reciprocal_rank(active.final_ranking, case.ground_truth, case.candidates)
        oracle_rr = compute_reciprocal_rank(oracle.final_ranking, case.ground_truth, case.candidates)
        active_correct = active_rr == 1.0
        oracle_correct = oracle_rr == 1.0

        if active_correct:
            continue  # No regret if agent got it right

        n_active_wrong += 1

        # Channels available but not acquired
        all_requestable = set(case.requestable_channels.keys())
        acquired = set(active.acquired_channels)
        missed = all_requestable - acquired

        case_entry = {
            "case_id": case.case_id,
            "ground_truth": case.ground_truth,
            "active_top1": _get_top1_name(active.final_ranking),
            "oracle_top1": _get_top1_name(oracle.final_ranking),
            "active_correct": False,
            "oracle_correct": oracle_correct,
            "acquired_channels": list(acquired),
            "missed_channels": list(missed),
            "recoverable": oracle_correct,
            "active_rr": active_rr,
            "oracle_rr": oracle_rr,
        }

        # Every missed channel in a wrong case counts toward the denominator.
        for ch in missed:
            missed_channel_total[ch] = missed_channel_total.get(ch, 0) + 1

        if oracle_correct:
            # Recoverable: the oracle (all channels) got it right, so the
            # missed channels here are candidate causes of the agent's error.
            n_oracle_right_when_active_wrong += 1
            for ch in missed:
                missed_channel_counts[ch] = missed_channel_counts.get(ch, 0) + 1
        else:
            n_both_wrong += 1

        per_case_regret.append(case_entry)

    # Compute per-channel regret score: how often a missed channel appears
    # in recoverable errors vs all errors
    channel_regret_scores = {}
    for ch in set(list(missed_channel_counts.keys()) + list(missed_channel_total.keys())):
        recoverable_miss = missed_channel_counts.get(ch, 0)
        total_miss = missed_channel_total.get(ch, 0)
        # Regret score: fraction of times this channel was missed AND oracle succeeded
        channel_regret_scores[ch] = {
            "missed_in_recoverable": recoverable_miss,
            "missed_in_all_wrong": total_miss,
            "regret_rate": recoverable_miss / total_miss if total_miss > 0 else 0.0,
        }

    # Sort channels by regret rate descending
    # (ties broken by raw recoverable-miss count, also descending).
    sorted_channels = sorted(
        channel_regret_scores.items(),
        key=lambda x: (-x[1]["regret_rate"], -x[1]["missed_in_recoverable"]),
    )

    return {
        "n_cases": len(cases),
        "n_active_wrong": n_active_wrong,
        "n_recoverable": n_oracle_right_when_active_wrong,
        "n_unrecoverable": n_both_wrong,
        "recovery_rate": (
            n_oracle_right_when_active_wrong / n_active_wrong
            if n_active_wrong > 0 else 0.0
        ),
        "error_rate": n_active_wrong / len(cases) if cases else 0.0,
        # dict(sorted_channels) preserves the regret-rate ordering (insertion order).
        "channel_regret_scores": dict(sorted_channels),
        "per_case_regret": per_case_regret,
        "summary": {
            "total_errors": n_active_wrong,
            "recoverable_pct": (
                n_oracle_right_when_active_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "unrecoverable_pct": (
                n_both_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "highest_regret_channel": sorted_channels[0][0] if sorted_channels else None,
        },
    }
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def compute_info_theoretic_metrics(
    results: list[AgentResult],
) -> dict:
    """
    Compute information-theoretic metrics from belief trajectories.

    Extracts BeliefTrajectory objects from AgentResults and computes
    aggregate entropy, information gain, and per-channel value metrics.
    """
    # Keep only results that carry a non-empty belief trajectory.
    trajectories = []
    for r in results:
        traj = r.belief_trajectory
        if traj and traj.states:
            trajectories.append(traj)

    if not trajectories:
        return {"n_cases_with_trajectory": 0}

    metrics = compute_information_metrics(trajectories)
    metrics["n_cases_with_trajectory"] = len(trajectories)
    return metrics
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _get_top1_name(ranking: list[dict]) -> str:
|
| 435 |
+
"""Get the name of the top-ranked diagnosis."""
|
| 436 |
+
if not ranking:
|
| 437 |
+
return ""
|
| 438 |
+
return ranking[0].get("name", "")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def _bootstrap_ci(
    values: np.ndarray, n_bootstrap: int = 1000, ci: float = 0.95
) -> tuple[float, float]:
    """Compute a percentile bootstrap confidence interval for the mean.

    Args:
        values: 1-D array of per-case values.
        n_bootstrap: Number of bootstrap resamples.
        ci: Confidence level (default 95%).

    Returns:
        (lower, upper) percentile bounds; (0.0, 0.0) for empty input.
    """
    if len(values) == 0:
        return (0.0, 0.0)
    # Fixed seed keeps CIs reproducible across runs.
    rng = np.random.RandomState(config.SEED)
    # Vectorized resampling: draw one (n_bootstrap, n) index matrix and take
    # row means, instead of a Python-level loop of n_bootstrap rng.choice calls.
    n = len(values)
    idx = rng.randint(0, n, size=(n_bootstrap, n))
    boot_means = np.asarray(values)[idx].mean(axis=1)
    alpha = (1 - ci) / 2
    lower = float(np.percentile(boot_means, alpha * 100))
    upper = float(np.percentile(boot_means, (1 - alpha) * 100))
    return (lower, upper)
|
evaluation/analysis.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cross-dataset analysis and figure generation.
|
| 3 |
+
|
| 4 |
+
Produces the key figures for the paper:
|
| 5 |
+
1. Acquisition Efficiency curves (all 3 datasets, shared y-axis)
|
| 6 |
+
2. Per-channel request frequency heatmap
|
| 7 |
+
3. Prompt sensitivity agreement matrix
|
| 8 |
+
4. OLIVES biomarker-tier acquisition analysis
|
| 9 |
+
5. NEJM difficulty-vs-acquisition scatter
|
| 10 |
+
"""
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from dataclasses import asdict
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import matplotlib
|
| 19 |
+
import seaborn as sns
|
| 20 |
+
from scipy import stats
|
| 21 |
+
|
| 22 |
+
from agent import AgentResult
|
| 23 |
+
from datasets.base import MedicalCase
|
| 24 |
+
from evaluation import (
|
| 25 |
+
CaseMetrics,
|
| 26 |
+
DatasetMetrics,
|
| 27 |
+
evaluate_single_case,
|
| 28 |
+
aggregate_metrics,
|
| 29 |
+
compute_acquisition_efficiency,
|
| 30 |
+
compute_acquisition_precision,
|
| 31 |
+
compute_prompt_agreement,
|
| 32 |
+
compute_regret_analysis,
|
| 33 |
+
)
|
| 34 |
+
import config
|
| 35 |
+
|
| 36 |
+
matplotlib.rcParams["font.family"] = "serif"
|
| 37 |
+
matplotlib.rcParams["font.size"] = 11
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ExperimentAnalyzer:
|
| 43 |
+
"""Analyze and visualize results across all experiments."""
|
| 44 |
+
|
| 45 |
+
def __init__(self, results_dir: Path = None):
|
| 46 |
+
self.results_dir = results_dir or config.RESULTS_DIR
|
| 47 |
+
self.figures_dir = self.results_dir / "figures"
|
| 48 |
+
self.figures_dir.mkdir(parents=True, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
def load_results(self, experiment_name: str) -> dict:
|
| 51 |
+
"""Load saved experiment results."""
|
| 52 |
+
path = self.results_dir / f"{experiment_name}.json"
|
| 53 |
+
if not path.exists():
|
| 54 |
+
logger.error(f"Results file not found: {path}")
|
| 55 |
+
return {}
|
| 56 |
+
with open(path) as f:
|
| 57 |
+
return json.load(f)
|
| 58 |
+
|
| 59 |
+
def save_results(self, data: dict, experiment_name: str):
|
| 60 |
+
"""Save experiment results."""
|
| 61 |
+
path = self.results_dir / f"{experiment_name}.json"
|
| 62 |
+
with open(path, "w") as f:
|
| 63 |
+
json.dump(data, f, indent=2, default=str)
|
| 64 |
+
logger.info(f"Results saved to {path}")
|
| 65 |
+
|
| 66 |
+
# ================================================================
|
| 67 |
+
# Figure 1: Acquisition Efficiency Curves
|
| 68 |
+
# ================================================================
|
| 69 |
+
|
| 70 |
+
    def plot_acquisition_efficiency(
        self,
        results_by_dataset: dict[str, dict[int, DatasetMetrics]],
        passive_metrics: dict[str, DatasetMetrics],
        oracle_metrics: dict[str, DatasetMetrics],
        save_name: str = "fig1_acquisition_efficiency",
    ):
        """
        Main result figure: normalized acquisition efficiency vs budget K.

        Saves a two-panel PDF to <figures_dir>/<save_name>.pdf:
        (a) raw MRR vs K with bootstrap-CI bands and oracle ceilings,
        (b) normalized acquisition efficiency vs K.

        Args:
            results_by_dataset: {dataset_name: {K: DatasetMetrics}}
            passive_metrics: {dataset_name: DatasetMetrics} at K=0
            oracle_metrics: {dataset_name: DatasetMetrics} with all channels
        """
        fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

        # Fixed per-dataset styling so both panels stay visually consistent.
        colors = {"midas": "#E07A5F", "nejm": "#3D405B", "olives": "#81B29A"}
        markers = {"midas": "o", "nejm": "s", "olives": "D"}
        labels = {"midas": "MIDAS (Dermatology)", "nejm": "NEJM (Multi-Specialty)",
                  "olives": "OLIVES (Ophthalmology)"}

        # Left panel: Raw MRR vs K
        ax = axes[0]
        for ds_name in ["midas", "nejm", "olives"]:
            if ds_name not in results_by_dataset:
                continue
            ks = sorted(results_by_dataset[ds_name].keys())
            mrrs = [results_by_dataset[ds_name][k].mrr for k in ks]
            cis = [results_by_dataset[ds_name][k].mrr_ci for k in ks]

            # Add passive at K=0 (prepend the no-acquisition baseline point)
            all_k = [0] + list(ks)
            all_mrr = [passive_metrics[ds_name].mrr] + mrrs
            all_lower = [passive_metrics[ds_name].mrr_ci[0]] + [c[0] for c in cis]
            all_upper = [passive_metrics[ds_name].mrr_ci[1]] + [c[1] for c in cis]

            ax.plot(all_k, all_mrr, color=colors[ds_name], marker=markers[ds_name],
                    label=labels[ds_name], linewidth=2, markersize=7)
            # Shaded band = bootstrap CI around the MRR curve.
            ax.fill_between(all_k, all_lower, all_upper, alpha=0.15, color=colors[ds_name])

            # Oracle line (dashed horizontal ceiling in the dataset's color)
            ax.axhline(y=oracle_metrics[ds_name].mrr, color=colors[ds_name],
                       linestyle="--", alpha=0.4, linewidth=1)

        ax.set_xlabel("Acquisition Budget (K)")
        ax.set_ylabel("Mean Reciprocal Rank (MRR)")
        ax.set_title("(a) Diagnostic Quality vs. Budget")
        ax.legend(fontsize=9)
        # Integer ticks from 0 up to the largest K seen (at least 0..3).
        ax.set_xticks(range(max(4, max(max(r.keys()) for r in results_by_dataset.values()) + 1)))
        ax.grid(True, alpha=0.3)

        # Right panel: Normalized Acquisition Efficiency
        ax = axes[1]
        for ds_name in ["midas", "nejm", "olives"]:
            if ds_name not in results_by_dataset:
                continue
            ks = sorted(results_by_dataset[ds_name].keys())
            effs = []
            for k in ks:
                ae = compute_acquisition_efficiency(
                    results_by_dataset[ds_name][k].mrr,
                    passive_metrics[ds_name].mrr,
                    oracle_metrics[ds_name].mrr,
                )
                effs.append(ae)

            # Efficiency is 0 by definition at K=0 (agent equals passive).
            all_k = [0] + list(ks)
            all_eff = [0.0] + effs

            ax.plot(all_k, all_eff, color=colors[ds_name], marker=markers[ds_name],
                    label=labels[ds_name], linewidth=2, markersize=7)

        ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, linewidth=1,
                   label="Oracle ceiling")
        ax.set_xlabel("Acquisition Budget (K)")
        ax.set_ylabel("Acquisition Efficiency")
        ax.set_title("(b) Normalized Efficiency")
        ax.legend(fontsize=9)
        ax.set_ylim(-0.05, 1.15)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        save_path = self.figures_dir / f"{save_name}.pdf"
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"Saved figure: {save_path}")
|
| 157 |
+
|
| 158 |
+
# ================================================================
|
| 159 |
+
# Figure 2: Per-Channel Request Frequency
|
| 160 |
+
# ================================================================
|
| 161 |
+
|
| 162 |
+
    def plot_channel_request_heatmap(
        self,
        results_by_dataset: dict[str, list[AgentResult]],
        save_name: str = "fig2_channel_requests",
    ):
        """Heatmap showing which channels the agent requests most, by dataset.

        One panel per dataset; row 0 = rate of being the FIRST requested
        channel, row 1 = rate of being requested at all. Both rows are
        normalized by the total number of results for that dataset.
        """
        fig, axes = plt.subplots(1, 3, figsize=(14, 4))
        dataset_names = ["midas", "nejm", "olives"]
        titles = ["MIDAS", "NEJM", "OLIVES"]

        for idx, (ds_name, title) in enumerate(zip(dataset_names, titles)):
            if ds_name not in results_by_dataset:
                continue

            results = results_by_dataset[ds_name]

            # Count first-request frequency
            first_requests: dict[str, int] = {}
            for r in results:
                if r.acquired_channels:
                    ch = r.acquired_channels[0]
                    first_requests[ch] = first_requests.get(ch, 0) + 1

            # Count overall request frequency
            all_requests: dict[str, int] = {}
            for r in results:
                for ch in r.acquired_channels:
                    all_requests[ch] = all_requests.get(ch, 0) + 1

            # No acquisitions at all for this dataset: leave its panel empty.
            if not all_requests:
                continue

            channels = sorted(all_requests.keys())
            n = len(results)

            ax = axes[idx]
            # 2 x n_channels matrix of rates, normalized by case count.
            data = np.array([
                [first_requests.get(ch, 0) / n for ch in channels],
                [all_requests.get(ch, 0) / n for ch in channels],
            ])

            sns.heatmap(
                data,
                ax=ax,
                xticklabels=[ch.replace("_", "\n") for ch in channels],
                yticklabels=["First\nRequest", "Any\nRequest"],
                annot=True,
                fmt=".2f",
                cmap="YlOrRd",
                vmin=0,
                vmax=1,
                cbar_kws={"shrink": 0.8},
            )
            ax.set_title(title)

        plt.tight_layout()
        save_path = self.figures_dir / f"{save_name}.pdf"
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"Saved figure: {save_path}")
|
| 222 |
+
|
| 223 |
+
# ================================================================
|
| 224 |
+
# Figure 3: OLIVES Biomarker Tier Analysis
|
| 225 |
+
# ================================================================
|
| 226 |
+
|
| 227 |
+
    def plot_olives_biomarker_tiers(
        self,
        results: list[AgentResult],
        cases: list[MedicalCase],
        save_name: str = "fig3_olives_biomarker_tiers",
    ):
        """
        For OLIVES: does the agent request OCT more for OCT-dependent
        biomarkers than for fundus-visible ones?

        Saves a bar chart of OCT request rate per biomarker tier, with
        Wilson 95% CIs. Non-OLIVES cases in the input are ignored.
        """
        # Per-tier list of booleans: did the agent request OCT on that case?
        # A case can contribute to both tiers if it has both biomarker kinds.
        oct_request_by_tier: dict[str, list[bool]] = {
            "fundus_visible": [],
            "oct_dependent": [],
        }

        for result, case in zip(results, cases):
            if case.dataset != "olives":
                continue
            tier_labels = case.metadata.get("biomarker_tier_labels", {})
            requested_oct = "oct_scan" in result.acquired_channels

            # For cases where the eye has fundus-visible biomarkers
            if tier_labels.get("fundus_visible"):
                oct_request_by_tier["fundus_visible"].append(requested_oct)

            # For cases where the eye has OCT-dependent biomarkers
            if tier_labels.get("oct_dependent"):
                oct_request_by_tier["oct_dependent"].append(requested_oct)

        fig, ax = plt.subplots(figsize=(6, 4))

        tiers = ["fundus_visible", "oct_dependent"]
        # NOTE: reuses (shadows) the loop-local `tier_labels` name above;
        # from here on it holds the display labels for the x-axis.
        tier_labels = ["Fundus-Visible\nBiomarkers", "OCT-Dependent\nBiomarkers"]
        rates = []
        cis_lower = []
        cis_upper = []

        for tier in tiers:
            vals = oct_request_by_tier.get(tier, [])
            if vals:
                rate = np.mean(vals)
                rates.append(rate)
                # Wilson CI for proportions (z=1.96 → 95% interval)
                n = len(vals)
                z = 1.96
                p = rate
                denom = 1 + z ** 2 / n
                center = (p + z ** 2 / (2 * n)) / denom
                margin = z * np.sqrt((p * (1 - p) + z ** 2 / (4 * n)) / n) / denom
                cis_lower.append(center - margin)
                cis_upper.append(center + margin)
            else:
                # No cases in this tier: plot a zero bar with no error range.
                rates.append(0)
                cis_lower.append(0)
                cis_upper.append(0)

        colors_bar = ["#81B29A", "#E07A5F"]
        bars = ax.bar(tier_labels, rates, color=colors_bar, edgecolor="white", width=0.5)
        # yerr expects (below, above) distances from the bar top, not bounds.
        ax.errorbar(
            tier_labels, rates,
            yerr=[np.array(rates) - np.array(cis_lower),
                  np.array(cis_upper) - np.array(rates)],
            fmt="none", ecolor="black", capsize=5,
        )

        ax.set_ylabel("OCT Request Rate")
        ax.set_title("Agent's OCT Request Rate by Biomarker Type")
        ax.set_ylim(0, 1.05)
        ax.grid(True, axis="y", alpha=0.3)

        # Add counts (sample size annotation above each bar)
        for i, tier in enumerate(tiers):
            n = len(oct_request_by_tier.get(tier, []))
            ax.text(i, rates[i] + 0.05, f"n={n}", ha="center", fontsize=10)

        plt.tight_layout()
        save_path = self.figures_dir / f"{save_name}.pdf"
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"Saved figure: {save_path}")
|
| 307 |
+
|
| 308 |
+
# ================================================================
|
| 309 |
+
# Figure 4: NEJM Difficulty vs Acquisition Behavior
|
| 310 |
+
# ================================================================
|
| 311 |
+
|
| 312 |
+
    def plot_nejm_difficulty_analysis(
        self,
        results: list[AgentResult],
        cases: list[MedicalCase],
        save_name: str = "fig4_nejm_difficulty",
    ):
        """
        Scatter: human difficulty (physician correct rate) vs
        agent's acquisition behavior (N channels requested + early commit).

        Non-NEJM cases and cases without vote metadata are skipped.
        """
        difficulties = []
        n_acquired = []
        committed_early = []

        for result, case in zip(results, cases):
            if case.dataset != "nejm":
                continue
            votes = case.metadata.get("votes", {})
            if not votes:
                continue

            # Compute human difficulty (proportion correct)
            total_votes = sum(float(v) for v in votes.values())
            if total_votes == 0:
                continue
            gt = case.ground_truth
            human_correct = 0.0
            for key, val in votes.items():
                # Loose match of a vote option against the ground-truth string.
                if key in gt or gt.startswith(key):
                    # NOTE(review): when total_votes <= 1 the raw vote value is
                    # used unnormalized — presumably votes are already
                    # proportions in that case; confirm against the NEJM loader.
                    human_correct = float(val) / total_votes if total_votes > 1 else float(val)
                    break

            difficulties.append(human_correct)
            n_acquired.append(len(result.acquired_channels))
            committed_early.append(result.committed_early)

        if not difficulties:
            logger.warning("No NEJM cases with difficulty data found")
            return

        fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))

        # Left: Difficulty vs N channels acquired
        ax = axes[0]
        ax.scatter(difficulties, n_acquired, alpha=0.5, s=30, color="#3D405B", edgecolors="white")
        # Add trend line (least-squares fit; only with enough points to be meaningful)
        if len(difficulties) > 10:
            z = np.polyfit(difficulties, n_acquired, 1)
            p = np.poly1d(z)
            x_line = np.linspace(min(difficulties), max(difficulties), 100)
            ax.plot(x_line, p(x_line), "--", color="#E07A5F", linewidth=2,
                    label=f"Trend (slope={z[0]:.2f})")
            # Correlation (Pearson r annotated in the panel corner)
            r, pval = stats.pearsonr(difficulties, n_acquired)
            ax.text(0.05, 0.95, f"r={r:.3f}, p={pval:.3f}",
                    transform=ax.transAxes, fontsize=9, verticalalignment="top")
        ax.set_xlabel("Human Correct Rate (easier →)")
        ax.set_ylabel("Channels Acquired by Agent")
        ax.set_title("(a) Case Difficulty vs. Acquisition Amount")
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)

        # Right: Difficulty bins vs early commit rate
        ax = axes[1]
        diff_arr = np.array(difficulties)
        commit_arr = np.array(committed_early, dtype=float)
        # Upper edge 1.01 so a difficulty of exactly 1.0 lands in the last bin.
        bins = [0, 0.25, 0.50, 0.75, 1.01]
        bin_labels = ["<25%", "25-50%", "50-75%", ">75%"]
        bin_rates = []
        bin_ns = []

        for i in range(len(bins) - 1):
            mask = (diff_arr >= bins[i]) & (diff_arr < bins[i + 1])
            if mask.sum() > 0:
                bin_rates.append(commit_arr[mask].mean())
                bin_ns.append(mask.sum())
            else:
                bin_rates.append(0)
                bin_ns.append(0)

        bar_colors = ["#E07A5F", "#F2CC8F", "#81B29A", "#3D405B"]
        bars = ax.bar(bin_labels, bin_rates, color=bar_colors, edgecolor="white", width=0.6)
        for i, (rate, n) in enumerate(zip(bin_rates, bin_ns)):
            ax.text(i, rate + 0.02, f"n={n}", ha="center", fontsize=9)
        ax.set_xlabel("Human Correct Rate (easier →)")
        ax.set_ylabel("Agent Early Commit Rate")
        ax.set_title("(b) Early Commitment vs. Difficulty")
        ax.set_ylim(0, 1.05)
        ax.grid(True, axis="y", alpha=0.3)

        plt.tight_layout()
        save_path = self.figures_dir / f"{save_name}.pdf"
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"Saved figure: {save_path}")
|
| 407 |
+
|
| 408 |
+
# ================================================================
|
| 409 |
+
# Figure 5: Regret Analysis
|
| 410 |
+
# ================================================================
|
| 411 |
+
|
| 412 |
+
def plot_regret_analysis(
    self,
    regret: dict,
    dataset_name: str = "",
    save_name: str = "fig5_regret_analysis",
):
    """
    Visualize regret analysis results as a two-panel figure saved to PDF.

    Left: Bar chart decomposing cases into correct / recoverable-error /
    unrecoverable-error (one bar each, annotated with count and percent).
    Right: Per-channel regret scores (which missed channels cost the most),
    sorted descending and annotated with miss counts.

    Args:
        regret: Regret-analysis results. Keys read here: "summary",
            "n_cases", "n_active_wrong", "n_recoverable", "n_unrecoverable",
            "channel_regret_scores" (each score dict must contain
            "regret_rate" and "missed_in_recoverable").
        dataset_name: Optional dataset label appended to panel titles.
        save_name: Base filename (without extension) for the saved PDF,
            written under ``self.figures_dir``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
    title_suffix = f" — {dataset_name.upper()}" if dataset_name else ""

    # ---- Left panel: Error decomposition ----
    ax = axes[0]
    # NOTE(review): `summary` is bound but never read in this method.
    summary = regret["summary"]
    # Correct = total cases minus cases the (active) agent got wrong.
    n_correct = regret["n_cases"] - regret["n_active_wrong"]
    n_recoverable = regret["n_recoverable"]
    n_unrecoverable = regret["n_unrecoverable"]

    categories = ["Agent\nCorrect", "Recoverable\nErrors", "Unrecoverable\nErrors"]
    values = [n_correct, n_recoverable, n_unrecoverable]
    # Green / amber / red — matches the palette used elsewhere in this module.
    colors_bar = ["#81B29A", "#F2CC8F", "#E07A5F"]

    bars = ax.bar(categories, values, color=colors_bar, edgecolor="white", width=0.55)
    # Annotate each bar with absolute count and percent of all cases.
    for bar, val in zip(bars, values):
        pct = val / regret["n_cases"] * 100 if regret["n_cases"] > 0 else 0
        ax.text(
            bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
            f"{val}\n({pct:.0f}%)", ha="center", fontsize=10,
        )

    ax.set_ylabel("Number of Cases")
    ax.set_title(f"(a) Error Decomposition{title_suffix}")
    ax.grid(True, axis="y", alpha=0.3)

    # ---- Right panel: Per-channel regret ----
    ax = axes[1]
    channel_scores = regret["channel_regret_scores"]

    if channel_scores:
        channels = list(channel_scores.keys())
        regret_rates = [channel_scores[ch]["regret_rate"] for ch in channels]
        miss_counts = [channel_scores[ch]["missed_in_recoverable"] for ch in channels]

        # Sort by regret rate (descending), keeping the three lists aligned.
        sorted_idx = sorted(range(len(channels)), key=lambda i: -regret_rates[i])
        channels = [channels[i] for i in sorted_idx]
        regret_rates = [regret_rates[i] for i in sorted_idx]
        miss_counts = [miss_counts[i] for i in sorted_idx]

        y_pos = range(len(channels))
        # Colormap from light to dark so higher bars read as "hotter".
        bar_colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(channels)))
        bars = ax.barh(
            y_pos, regret_rates, color=bar_colors, edgecolor="white", height=0.6,
        )

        ax.set_yticks(y_pos)
        # Humanize channel names: "chest_xray" -> "Chest Xray".
        ax.set_yticklabels([ch.replace("_", " ").title() for ch in channels], fontsize=9)
        ax.set_xlabel("Regret Rate")
        ax.set_xlim(0, 1.05)
        # Highest-regret channel on top.
        ax.invert_yaxis()

        # Annotate each bar with the number of recoverable errors that
        # missed this channel.
        for i, (rate, count) in enumerate(zip(regret_rates, miss_counts)):
            ax.text(
                rate + 0.02, i, f"n={count}",
                va="center", fontsize=9, color="#333",
            )
    else:
        # No channel data at all — leave an explanatory placeholder.
        ax.text(0.5, 0.5, "No channel data", ha="center", va="center",
                transform=ax.transAxes, fontsize=12)

    ax.set_title(f"(b) Channel Regret Scores{title_suffix}")
    ax.grid(True, axis="x", alpha=0.3)

    plt.tight_layout()
    save_path = self.figures_dir / f"{save_name}.pdf"
    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    # Close to free figure memory when many figures are generated in a run.
    plt.close(fig)
    logger.info(f"Saved figure: {save_path}")
|
| 495 |
+
|
| 496 |
+
def print_regret_summary(self, regret: dict):
    """Print a concise text summary of regret analysis to stdout."""
    rule = "=" * 55
    summary = regret["summary"]

    print("\n" + rule)
    print(" REGRET ANALYSIS")
    print(rule)
    print(f" Total cases: {regret['n_cases']}")
    print(f" Agent errors: {summary['total_errors']} ({regret['error_rate']*100:.1f}%)")
    print(f" Recoverable: {regret['n_recoverable']} ({summary['recoverable_pct']:.1f}% of errors)")
    print(f" Unrecoverable: {regret['n_unrecoverable']} ({summary['unrecoverable_pct']:.1f}% of errors)")
    print(f" Highest-regret channel: {summary['highest_regret_channel']}")
    print()
    print(" Per-channel regret:")
    for channel, scores in regret["channel_regret_scores"].items():
        line = (
            f" {channel:<25} regret={scores['regret_rate']:.2f} "
            f"(missed in {scores['missed_in_recoverable']}/{scores['missed_in_all_wrong']} errors)"
        )
        print(line)
    print(rule)
|
| 513 |
+
|
| 514 |
+
# ================================================================
|
| 515 |
+
# Summary Table
|
| 516 |
+
# ================================================================
|
| 517 |
+
|
| 518 |
+
def print_summary_table(
    self,
    all_metrics: dict[str, dict[str, DatasetMetrics]],
):
    """
    Print the main results table to stdout.

    Args:
        all_metrics: {condition: {dataset: DatasetMetrics}}
            where condition is "passive", "K=1", "K=2", "K=3",
            "fixed_order", "oracle". Missing conditions or datasets
            are silently skipped.
    """
    header = f"{'Condition':<15} {'Dataset':<12} {'Top-1 Acc':<15} {'MRR':<15} {'Avg K':<8}"
    rule = "=" * len(header)

    print(rule)
    print(header)
    print(rule)

    # Fixed display order for both axes of the table.
    condition_order = ["passive", "K=1", "K=2", "K=3", "fixed_order", "oracle"]
    dataset_order = ["midas", "nejm", "olives"]

    for cond in condition_order:
        per_dataset = all_metrics.get(cond)
        if per_dataset is None:
            continue
        for name in dataset_order:
            if name not in per_dataset:
                continue
            metrics = per_dataset[name]
            acc_str = (
                f"{metrics.top1_accuracy:.3f} "
                f"({metrics.top1_accuracy_ci[0]:.3f}-{metrics.top1_accuracy_ci[1]:.3f})"
            )
            mrr_str = f"{metrics.mrr:.3f} ({metrics.mrr_ci[0]:.3f}-{metrics.mrr_ci[1]:.3f})"
            print(f"{cond:<15} {name:<12} {acc_str:<15} {mrr_str:<15} {metrics.mean_channels_acquired:<8.1f}")

    print(rule)
|
information_gain.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Information-theoretic computation for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Provides grounded entropy and expected information gain (EIG) computation
|
| 5 |
+
from the agent's reported probability distributions. This transforms the
|
| 6 |
+
"information-theoretic framing" from a prompt label into actual computation.
|
| 7 |
+
|
| 8 |
+
Key concepts:
|
| 9 |
+
- Belief State: The agent's probability distribution over candidate diagnoses
|
| 10 |
+
- Shannon Entropy: H(p) = -sum(p_i * log2(p_i)) — measures diagnostic uncertainty
|
| 11 |
+
- Information Gain: H(before) - H(after) — how much a channel reduced uncertainty
|
| 12 |
+
- Expected Information Gain (EIG): Estimated reduction in entropy from acquiring a channel
|
| 13 |
+
- Value of Information (VoI): Whether acquiring more data is worth the cost
|
| 14 |
+
|
| 15 |
+
No training required — these are computed analytically from the probability
|
| 16 |
+
distributions the agent reports through tool calls at each step.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import logging
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class BeliefState:
    """
    The agent's probability distribution over candidate diagnoses at a given step.

    Extracted directly from the tool call's `current_differential` parameter,
    so no parsing heuristics are needed.
    """
    # Acquisition step this state belongs to (presumably 0 for the initial,
    # pre-acquisition state — confirm against the caller that builds these).
    step: int
    distribution: dict[str, float]  # {diagnosis_name: probability}
    # Shannon entropy (bits) of `distribution`. Always recomputed in
    # __post_init__, so any value supplied by the caller is overwritten.
    entropy: float = 0.0
    # Channel whose acquisition produced this state; None for the initial state.
    channel_acquired: str | None = None

    def __post_init__(self):
        # Derive entropy from the distribution rather than trusting the input.
        self.entropy = compute_entropy(self.distribution)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
class BeliefTrajectory:
    """
    Full trajectory of belief states across the acquisition process.

    Records how diagnostic uncertainty evolves as the agent acquires
    information, enabling information-theoretic analysis of acquisition
    quality (per-step gains, efficiency, per-channel value).
    """
    case_id: str
    states: list[BeliefState] = field(default_factory=list)

    @property
    def initial_entropy(self) -> float:
        """Entropy (bits) of the first recorded belief state; 0.0 if empty."""
        if not self.states:
            return 0.0
        return self.states[0].entropy

    @property
    def final_entropy(self) -> float:
        """Entropy (bits) of the last recorded belief state; 0.0 if empty."""
        if not self.states:
            return 0.0
        return self.states[-1].entropy

    @property
    def total_information_gain(self) -> float:
        """Total entropy reduction from first to last state."""
        return self.initial_entropy - self.final_entropy

    @property
    def per_step_information_gain(self) -> list[float]:
        """Entropy drop contributed by each consecutive acquisition step."""
        return [
            prev.entropy - curr.entropy
            for prev, curr in zip(self.states, self.states[1:])
        ]

    @property
    def entropy_trajectory(self) -> list[float]:
        """Entropy value at every recorded step, in order."""
        return [state.entropy for state in self.states]

    @property
    def information_efficiency(self) -> float:
        """
        Actual IG divided by the maximum possible IG (initial entropy -> 0).

        Returns a ratio in [0, 1]; defined as 1.0 when the trajectory started
        with (near-)zero entropy, i.e. the agent was already certain.
        """
        if self.initial_entropy < 1e-10:
            return 1.0
        return self.total_information_gain / self.initial_entropy

    def get_channel_information_values(self) -> dict[str, float]:
        """Map each acquired channel to the entropy drop it produced."""
        return {
            curr.channel_acquired: prev.entropy - curr.entropy
            for prev, curr in zip(self.states, self.states[1:])
            if curr.channel_acquired
        }
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ============================================================
|
| 106 |
+
# Core Computations
|
| 107 |
+
# ============================================================
|
| 108 |
+
|
| 109 |
+
def compute_entropy(distribution: dict[str, float]) -> float:
    """
    Shannon entropy H(p) = -sum(p_i * log2(p_i)) in bits.

    Args:
        distribution: Mapping of label -> (possibly unnormalized) probability.
            VLM-reported probabilities often do not sum exactly to 1, so the
            values are renormalized before the entropy is computed.

    Returns:
        Entropy in bits. 0.0 for an empty or all-zero distribution, and for
        a point mass (single non-zero entry).
    """
    probs = np.array(list(distribution.values()), dtype=np.float64)

    # Normalize if needed (VLM probabilities may not sum exactly to 1).
    total = probs.sum()
    if total < 1e-10:
        # Empty or all-zero input: define H = 0 rather than dividing by ~0.
        return 0.0
    probs = probs / total

    # Vectorized -sum(p * log2(p)); zero entries are masked out so that the
    # 0 * log(0) terms contribute 0 (their limit) instead of NaN.
    nonzero = probs[probs > 1e-15]
    return float(-(nonzero * np.log2(nonzero)).sum())
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def compute_kl_divergence(p: dict[str, float], q: dict[str, float]) -> float:
    """
    KL divergence D_KL(p || q) = sum(p_i * log2(p_i / q_i)), in bits.

    Measures how much the belief shifted from q (prior) to p (posterior).
    Keys missing from either distribution are treated as having a tiny
    probability (1e-10) before normalization; q is additionally floored at
    1e-10 to avoid division by / log of zero.
    """
    keys = set(p) | set(q)
    p_vec = np.array([p.get(k, 1e-10) for k in keys], dtype=np.float64)
    q_vec = np.array([q.get(k, 1e-10) for k in keys], dtype=np.float64)

    # Normalize both to proper distributions.
    p_vec = p_vec / p_vec.sum()
    q_vec = q_vec / q_vec.sum()

    # Smoothing to avoid log(0) on the prior side.
    q_vec = np.maximum(q_vec, 1e-10)

    # Terms with (effectively) zero posterior mass contribute 0.
    return sum(
        p_i * math.log2(p_i / q_i)
        for p_i, q_i in zip(p_vec, q_vec)
        if p_i > 1e-15
    )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def estimate_expected_information_gain(
    current_distribution: dict[str, float],
    channel_name: str,
    expected_impact: dict[str, str],
    candidates: list[str],
) -> float:
    """
    Estimate expected information gain (EIG) for a candidate channel.

    Lightweight two-scenario approximation: model a positive and a negative
    finding (each shifting belief toward the diagnosis the agent says that
    outcome would favor), weight the two hypothetical posteriors by the
    current probability of the positive-target diagnosis, and return the
    expected entropy reduction.

    Args:
        current_distribution: Current belief state over diagnoses.
        channel_name: Channel being evaluated (not used in the estimate itself).
        expected_impact: {"if_positive": diagnosis_name, "if_negative": diagnosis_name}
            as stated by the agent in its tool call.
        candidates: All candidate diagnoses (not used in the estimate itself).

    Returns:
        Estimated information gain in bits (clamped to be non-negative).
    """
    prior_entropy = compute_entropy(current_distribution)

    # Diagnosis each outcome is expected to favor.
    positive_dx = expected_impact.get("if_positive", "")
    negative_dx = expected_impact.get("if_negative", "")

    # Hypothetical posteriors under each scenario.
    posterior_if_positive = _shift_belief(current_distribution, positive_dx, boost=0.3)
    posterior_if_negative = _shift_belief(current_distribution, negative_dx, boost=0.3)

    # Scenario weights: current belief in the positive-target diagnosis
    # (0.5 when that diagnosis is not in the distribution).
    p_positive = current_distribution.get(positive_dx, 0.5)
    p_negative = 1.0 - p_positive

    expected_posterior_entropy = (
        p_positive * compute_entropy(posterior_if_positive)
        + p_negative * compute_entropy(posterior_if_negative)
    )

    # EIG should be non-negative.
    return max(0.0, prior_entropy - expected_posterior_entropy)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _shift_belief(
|
| 202 |
+
distribution: dict[str, float],
|
| 203 |
+
target: str,
|
| 204 |
+
boost: float = 0.3,
|
| 205 |
+
) -> dict[str, float]:
|
| 206 |
+
"""
|
| 207 |
+
Shift probability mass toward a target diagnosis.
|
| 208 |
+
|
| 209 |
+
Simple model: add `boost` to target, renormalize.
|
| 210 |
+
Used for EIG estimation only.
|
| 211 |
+
"""
|
| 212 |
+
result = dict(distribution)
|
| 213 |
+
|
| 214 |
+
# Find best matching key (case-insensitive)
|
| 215 |
+
matched_key = None
|
| 216 |
+
target_lower = target.lower().strip()
|
| 217 |
+
for key in result:
|
| 218 |
+
if target_lower in key.lower() or key.lower() in target_lower:
|
| 219 |
+
matched_key = key
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
if matched_key is None:
|
| 223 |
+
return result
|
| 224 |
+
|
| 225 |
+
result[matched_key] = result.get(matched_key, 0.0) + boost
|
| 226 |
+
|
| 227 |
+
# Renormalize
|
| 228 |
+
total = sum(result.values())
|
| 229 |
+
if total > 0:
|
| 230 |
+
result = {k: v / total for k, v in result.items()}
|
| 231 |
+
|
| 232 |
+
return result
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ============================================================
|
| 236 |
+
# Stopping Criterion: When Has the Agent Gathered Enough?
|
| 237 |
+
# ============================================================
|
| 238 |
+
|
| 239 |
+
def should_commit(
    trajectory: BeliefTrajectory,
    available_channels: list[str],
    min_steps: int = 0,
) -> tuple[bool, str]:
    """
    Decide whether the agent has gathered enough evidence to commit.

    DESIGN PRINCIPLE: never trust a raw VLM probability from a single
    observation — weaker models routinely report 0.85 for a wrong diagnosis
    after one image. Every stop condition below is grounded in OBSERVED
    belief dynamics (how beliefs changed after seeing evidence), not in
    first-impression confidence.

    Conditions, checked in order:
      1. CONVERGENCE — the last acquisition gained < 0.05 bits
         (needs >= 2 belief states and >= 1 acquisition): if new evidence
         didn't change the agent's mind, more probably won't either.
      2. CONFIRMED DOMINANCE — top-1 probability >= 0.90 with a >= 0.40 gap
         to #2, only after >= 2 acquisitions have tested the belief.
      3. DIMINISHING RETURNS — the last two acquisitions each gained
         < 0.1 bits (needs >= 3 states and >= 2 acquisitions): a plateau.
      4. EXHAUSTION — no channels left to acquire.

    Args:
        trajectory: Belief states recorded so far.
        available_channels: Channels the agent may still request.
        min_steps: Minimum number of belief states before stopping is allowed.

    Returns:
        (should_commit: bool, reason: str)
    """
    num_states = len(trajectory.states)

    if num_states < max(1, min_steps):
        return False, "min_steps not reached"

    if not trajectory.states:
        return False, "no belief states yet"

    # Only states produced by an actual acquisition count as evidence updates.
    num_acquired = sum(
        1 for state in trajectory.states if state.channel_acquired is not None
    )

    current = trajectory.states[-1].distribution
    if not current:
        return False, "empty distribution"

    # Normalize before ranking the candidate probabilities.
    mass = sum(current.values())
    if mass < 1e-10:
        return False, "zero distribution"
    ranked = sorted((value / mass for value in current.values()), reverse=True)

    top1_prob = ranked[0] if ranked else 0
    top2_prob = ranked[1] if len(ranked) > 1 else 0
    gap = top1_prob - top2_prob

    # Condition 1: CONVERGENCE — the newest evidence barely moved the belief.
    if num_states >= 2:
        last_ig = (
            trajectory.states[-2].entropy - trajectory.states[-1].entropy
        )
        if last_ig < 0.05 and num_acquired >= 1:
            return True, (
                f"convergence: last IG={last_ig:.3f} bits < 0.05 threshold "
                f"(after {num_acquired} acquisition(s))"
            )

    # Condition 2: CONFIRMED DOMINANCE — high confidence that has survived
    # at least two evidence updates (first-impression 0.85s don't count).
    if num_acquired >= 2 and top1_prob >= 0.90 and gap >= 0.40:
        return True, (
            f"confirmed dominance: top1={top1_prob:.2f}, gap={gap:.2f} "
            f"(after {num_acquired} acquisitions)"
        )

    # Condition 3: DIMINISHING RETURNS — two consecutive low-IG acquisitions.
    if num_states >= 3:
        ig_n1 = trajectory.states[-3].entropy - trajectory.states[-2].entropy
        ig_n2 = trajectory.states[-2].entropy - trajectory.states[-1].entropy
        if ig_n1 < 0.1 and ig_n2 < 0.1 and num_acquired >= 2:
            return True, (
                f"diminishing returns: last 2 IGs={ig_n1:.3f}, {ig_n2:.3f} "
                f"(after {num_acquired} acquisitions)"
            )

    # Condition 4: nothing left to request.
    if not available_channels:
        return True, "no channels remaining"

    return False, "continue acquiring"
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def compute_value_of_information(
    trajectory: BeliefTrajectory,
    n_remaining_channels: int,
) -> float:
    """
    Estimate the value of continuing to acquire information.

    Extrapolates from the trajectory's information-gain history whether the
    next acquisition is likely to be worthwhile. Returns a score in [0, 1]:
    near 0 means commit now, near 1 means keep acquiring.

    Method: exponentially-weighted average of recent per-step IG (most recent
    steps weighted highest), normalized by the initial entropy, then mildly
    discounted by acquisition progress to model diminishing marginal returns.
    """
    if not trajectory.states or n_remaining_channels == 0:
        return 0.0

    history = trajectory.per_step_information_gain
    if not history:
        # No history yet — uncertain, lean toward acquiring.
        return 0.5

    initial_h = trajectory.initial_entropy
    if initial_h < 1e-10:
        # Already certain; further acquisition has no value.
        return 0.0

    # Exponentially-weighted recent IG: most recent step gets weight 1,
    # each older step half the weight of the one after it.
    weights = [0.5 ** i for i in range(len(history))]
    weights.reverse()
    weighted_ig = sum(w * ig for w, ig in zip(weights, history))
    weighted_ig /= sum(weights)

    # Normalize by initial entropy so the score is scale-free.
    normalized_ig = weighted_ig / initial_h

    # Mild discount as acquisition progresses (diminishing returns).
    total_channels = len(trajectory.states) + n_remaining_channels
    progress = len(trajectory.states) / total_channels
    discount = 1.0 - (progress * 0.5)

    return max(0.0, min(1.0, normalized_ig * discount))
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ============================================================
|
| 386 |
+
# Aggregate Information-Theoretic Metrics
|
| 387 |
+
# ============================================================
|
| 388 |
+
|
| 389 |
+
def compute_information_metrics(trajectories: list[BeliefTrajectory]) -> dict:
    """
    Compute aggregate information-theoretic metrics across cases.

    Args:
        trajectories: One belief trajectory per completed case.

    Returns:
        Empty dict when `trajectories` is empty; otherwise a dict with:
          - mean_initial_entropy: average starting uncertainty (bits)
          - mean_final_entropy: average ending uncertainty (bits)
          - mean_total_ig: average total information gain (bits)
          - mean_info_efficiency: average IG / initial entropy
          - per_channel_mean_ig: average IG contributed by each channel
          - entropy_reduction_curve: mean entropy at each step
          - n_cases: number of trajectories aggregated
    """
    if not trajectories:
        return {}

    # Group the observed per-channel gains so each channel can be averaged.
    gains_by_channel: dict[str, list[float]] = {}
    for traj in trajectories:
        for channel, gain in traj.get_channel_information_values().items():
            gains_by_channel.setdefault(channel, []).append(gain)

    per_channel_mean_ig = {
        channel: float(np.mean(gains))
        for channel, gains in gains_by_channel.items()
    }

    # Average entropy curve: pad shorter trajectories with their final
    # entropy so every curve matches the longest trajectory's length.
    longest = max(len(traj.states) for traj in trajectories)
    padded_curves = []
    for traj in trajectories:
        curve = traj.entropy_trajectory
        curve += [curve[-1]] * (longest - len(curve))
        padded_curves.append(curve)

    return {
        "mean_initial_entropy": float(np.mean([t.initial_entropy for t in trajectories])),
        "mean_final_entropy": float(np.mean([t.final_entropy for t in trajectories])),
        "mean_total_ig": float(np.mean([t.total_information_gain for t in trajectories])),
        "mean_info_efficiency": float(np.mean([t.information_efficiency for t in trajectories])),
        "per_channel_mean_ig": per_channel_mean_ig,
        "entropy_reduction_curve": list(np.mean(padded_curves, axis=0)),
        "n_cases": len(trajectories),
    }
|
policy.py
ADDED
|
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Acquisition Policy Learning for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Three learned policies, all API-based or CPU-only:
|
| 5 |
+
|
| 6 |
+
1. RewardWeightedICL: Select the best past trajectories as in-context
|
| 7 |
+
examples for the VLM. The VLM sees "here's what worked before on
|
| 8 |
+
similar cases" and makes better acquisition decisions.
|
| 9 |
+
|
| 10 |
+
2. PolicyNetwork: A small MLP trained on CPU that predicts which channel
|
| 11 |
+
to request given a featurized state. Cheap, fast, interpretable.
|
| 12 |
+
|
| 13 |
+
3. SelfReflectivePolicy: The VLM critiques its own past failures
|
| 14 |
+
and generates an improved acquisition strategy.
|
| 15 |
+
|
| 16 |
+
All three produce an acquisition policy that replaces the zero-shot
|
| 17 |
+
decision in agent.py.
|
| 18 |
+
"""
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import random
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
import config
|
| 29 |
+
from api_client import BaseVLMClient
|
| 30 |
+
from datasets.base import MedicalCase
|
| 31 |
+
from trajectory import Trajectory, TrajectoryStep
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ================================================================
|
| 37 |
+
# Approach 1: Reward-Weighted In-Context Learning (ICL)
|
| 38 |
+
# ================================================================
|
| 39 |
+
|
| 40 |
+
class RewardWeightedICL:
    """
    Acquisition policy learned through reward-weighted few-shot prompting.

    The constructor indexes every non-COMMIT step from the supplied
    trajectories, splitting them into "good" decisions (utility reward at
    or above ``min_reward``) and "bad" ones. At inference time,
    ``get_few_shot_examples`` retrieves the most similar good decisions
    (same dataset preferred, ranked by acquisition-state overlap, ties
    broken by reward) and renders them as few-shot examples the VLM can
    imitate — offline policy improvement via in-context learning.
    """

    def __init__(
        self,
        trajectories: list[Trajectory],
        n_examples: int = 3,
        min_reward: float = 0.05,
    ):
        self.n_examples = n_examples
        self.min_reward = min_reward

        # Partition every non-COMMIT step into good vs. bad decisions.
        self.good_decisions: list[dict] = []
        self.bad_decisions: list[dict] = []

        for trajectory in trajectories:
            for st in trajectory.steps:
                if st.action == "COMMIT":
                    continue
                record = {
                    "case_id": trajectory.case_id,
                    "dataset": trajectory.dataset,
                    "acquired_before": st.acquired_so_far,
                    "action": st.action,
                    "uncertainty": st.uncertainty_text,
                    "reward": st.utility_reward,
                    "mrr_reward": st.reward,
                    "cost": st.acquisition_cost,
                    "diagnosis_changed": st.diagnosis_changed,
                    "diagnosis_improved": st.diagnosis_improved,
                    "mrr_before": st.mrr_before,
                    "mrr_after": st.mrr_after,
                }
                bucket = (
                    self.good_decisions
                    if st.utility_reward >= min_reward
                    else self.bad_decisions
                )
                bucket.append(record)

        logger.info(
            f"RewardWeightedICL: {len(self.good_decisions)} good, "
            f"{len(self.bad_decisions)} bad decisions indexed"
        )

    def get_few_shot_examples(
        self,
        case: MedicalCase,
        acquired_so_far: list[str],
    ) -> str:
        """
        Retrieve the best few-shot examples for the current case state.

        Returns formatted text to prepend to the acquisition prompt.
        """
        # Prefer decisions from the same dataset; fall back to everything.
        pool = [d for d in self.good_decisions if d["dataset"] == case.dataset]
        if not pool:
            pool = self.good_decisions

        # Rank by state similarity (descending), then by reward (descending).
        ranked = sorted(
            pool,
            key=lambda d: (-self._compute_similarity(d, acquired_so_far), -d["reward"]),
        )
        chosen = ranked[: self.n_examples]
        if not chosen:
            return ""

        # Render the selected decisions as few-shot examples.
        lines = [
            "Here are examples of helpful acquisition decisions from similar past cases:\n"
        ]
        for i, d in enumerate(chosen):
            lines.append(f"Example {i + 1}:")
            lines.append(f" Already acquired: {d['acquired_before'] or ['(nothing)']}")
            lines.append(f" Uncertainty: {d['uncertainty'][:150]}")
            lines.append(f" Decision: REQUEST {d['action']}")
            lines.append(
                f" Outcome: MRR improved from {d['mrr_before']:.2f} to {d['mrr_after']:.2f} "
                f"(reward: {d['reward']:+.3f})"
            )
            lines.append("")

        lines.append(
            "Learn from these examples. Prioritize channels that resolved similar uncertainties.\n"
        )
        return "\n".join(lines)

    def _compute_similarity(self, decision: dict, acquired_so_far: list[str]) -> float:
        """
        Compute similarity between a past decision and current state.
        Based on acquisition stage overlap.
        """
        past = set(decision["acquired_before"])
        current = set(acquired_so_far)

        if not (past or current):
            return 1.0  # Both at start

        # Jaccard overlap of the acquired-channel sets.
        jaccard = len(past & current) / max(len(past | current), 1)
        # Bonus when both states are at the same acquisition stage.
        same_stage = 1.0 if len(past) == len(current) else 0.5
        return jaccard * 0.5 + same_stage * 0.5
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ================================================================
|
| 169 |
+
# Approach 2: Lightweight Policy Network (CPU-only)
|
| 170 |
+
# ================================================================
|
| 171 |
+
|
| 172 |
+
class PolicyNetwork:
    """
    Small MLP that predicts which channel to request.

    State features (input):
      - One-hot: which channels have been acquired
      - One-hot: which dataset this is
      - Scalar: current top-1 confidence
      - Scalar: confidence gap (top1 - top2)
      - Scalar: acquisition step index (0, 1, 2)

    Output: probability distribution over requestable channels.

    Trained with cross-entropy loss weighted by trajectory reward.
    Runs entirely on CPU — no GPU needed. This is a <1000 parameter model.
    """

    def __init__(
        self,
        all_channels: list[str],
        all_datasets: list[str],
        hidden_dim: int = 32,
    ):
        # Sort vocabularies so feature/output indices are deterministic
        # regardless of the order callers pass them in.
        self.all_channels = sorted(all_channels)
        self.all_datasets = sorted(all_datasets)
        self.channel_to_idx = {c: i for i, c in enumerate(self.all_channels)}
        self.dataset_to_idx = {d: i for i, d in enumerate(self.all_datasets)}
        self.n_channels = len(self.all_channels)
        self.n_datasets = len(self.all_datasets)

        # Feature dimension: acquired_mask + dataset_onehot + confidence + gap + step
        self.input_dim = self.n_channels + self.n_datasets + 3
        self.hidden_dim = hidden_dim
        self.output_dim = self.n_channels

        # He-style initialization (scale ~ sqrt(2/fan_in)), seeded via
        # config.SEED for reproducibility. CPU numpy only.
        rng = np.random.RandomState(config.SEED)
        scale1 = np.sqrt(2.0 / self.input_dim)
        scale2 = np.sqrt(2.0 / hidden_dim)

        self.W1 = rng.randn(self.input_dim, hidden_dim).astype(np.float32) * scale1
        self.b1 = np.zeros(hidden_dim, dtype=np.float32)
        self.W2 = rng.randn(hidden_dim, self.output_dim).astype(np.float32) * scale2
        self.b2 = np.zeros(self.output_dim, dtype=np.float32)

        # Set True by train()/load(); note get_action() does not check it.
        self.trained = False

    def featurize(
        self,
        dataset: str,
        acquired: list[str],
        top1_confidence: float,
        top2_confidence: float,
        step_idx: int,
    ) -> np.ndarray:
        """Convert state to feature vector."""
        features = np.zeros(self.input_dim, dtype=np.float32)

        # Acquired channels mask (channels unknown to the vocabulary are
        # silently ignored)
        for ch in acquired:
            if ch in self.channel_to_idx:
                features[self.channel_to_idx[ch]] = 1.0

        # Dataset one-hot
        offset = self.n_channels
        if dataset in self.dataset_to_idx:
            features[offset + self.dataset_to_idx[dataset]] = 1.0

        # Scalars
        offset += self.n_datasets
        features[offset] = top1_confidence
        features[offset + 1] = top1_confidence - top2_confidence  # Confidence gap
        features[offset + 2] = step_idx / 3.0  # Normalized step

        return features

    def predict(
        self,
        features: np.ndarray,
        available_channels: list[str],
    ) -> dict[str, float]:
        """
        Forward pass: predict channel selection probabilities.

        Returns dict mapping channel_name → probability.
        Only available (not yet acquired) channels get nonzero probability.
        """
        # Forward pass: input → ReLU → softmax (masked)
        h = np.maximum(0, features @ self.W1 + self.b1)  # ReLU
        logits = h @ self.W2 + self.b2

        # Mask unavailable channels to -inf (additive -1e9 before softmax)
        mask = np.full(self.output_dim, -1e9, dtype=np.float32)
        for ch in available_channels:
            if ch in self.channel_to_idx:
                mask[self.channel_to_idx[ch]] = 0.0
        logits = logits + mask

        # Numerically stable softmax (shift by max before exponentiating)
        logits = logits - logits.max()
        exp_logits = np.exp(logits)
        probs = exp_logits / (exp_logits.sum() + 1e-8)

        return {ch: float(probs[self.channel_to_idx[ch]])
                for ch in available_channels if ch in self.channel_to_idx}

    def train(
        self,
        trajectories: list[Trajectory],
        lr: float = 0.01,
        n_epochs: int = 100,
        reward_temperature: float = 1.0,
    ):
        """
        Train the policy network on collected trajectories.

        Uses reward-weighted cross-entropy:
            loss = -sum(reward * log(P(action|state)))

        Positive rewards encourage the action; negative discourage it.
        """
        # Build training data: one (features, action, reward, mask) sample
        # per non-COMMIT step whose action is in the channel vocabulary.
        X = []
        actions = []
        rewards = []
        available_masks = []

        for traj in trajectories:
            for step in traj.steps:
                if step.action == "COMMIT":
                    continue
                if step.action not in self.channel_to_idx:
                    continue

                # Extract features from the step's state
                # (0.5 / 0.0 fallbacks when the differential is missing/short)
                top1_conf = step.differential_before[0]["confidence"] if step.differential_before else 0.5
                top2_conf = step.differential_before[1]["confidence"] if len(step.differential_before) > 1 else 0.0

                feat = self.featurize(
                    dataset=traj.dataset,
                    acquired=step.acquired_so_far,
                    top1_confidence=top1_conf,
                    top2_confidence=top2_conf,
                    step_idx=step.step_idx,
                )
                X.append(feat)
                actions.append(self.channel_to_idx[step.action])

                # Reward shaping: normalize across trajectories
                rewards.append(step.utility_reward)

                # Available channels mask
                mask = np.zeros(self.output_dim, dtype=np.float32)
                for ch in step.available_channels:
                    if ch in self.channel_to_idx:
                        mask[self.channel_to_idx[ch]] = 1.0
                available_masks.append(mask)

        if not X:
            logger.warning("No training data available for policy network")
            return

        X = np.array(X)
        actions = np.array(actions)
        rewards = np.array(rewards)
        available_masks = np.array(available_masks)

        # Normalize rewards to zero mean / unit variance (skipped when all
        # rewards are identical, which would divide by zero)
        if rewards.std() > 0:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Apply temperature: exponentiate so higher-reward samples dominate
        weights = np.exp(rewards * reward_temperature)
        weights = weights / weights.sum() * len(weights)  # Normalize to mean=1

        n = len(X)
        logger.info(f"Training policy network on {n} state-action pairs for {n_epochs} epochs")

        # Full-batch gradient descent with a manually derived backward pass.
        for epoch in range(n_epochs):
            # Forward pass
            h = np.maximum(0, X @ self.W1 + self.b1)
            logits = h @ self.W2 + self.b2

            # Mask unavailable channels
            logits = logits + (1 - available_masks) * (-1e9)

            # Softmax (row-wise, max-shifted for stability)
            logits_shifted = logits - logits.max(axis=1, keepdims=True)
            exp_logits = np.exp(logits_shifted)
            probs = exp_logits / (exp_logits.sum(axis=1, keepdims=True) + 1e-8)

            # Cross-entropy loss (reward-weighted)
            action_probs = probs[np.arange(n), actions]
            loss = -np.mean(weights * np.log(action_probs + 1e-8))

            # Backward pass (manual gradient)
            # dL/d_logits = probs - one_hot(action), weighted by reward
            grad_logits = probs.copy()
            grad_logits[np.arange(n), actions] -= 1.0
            grad_logits *= weights[:, np.newaxis] / n

            # Gradient for W2, b2
            grad_W2 = h.T @ grad_logits
            grad_b2 = grad_logits.sum(axis=0)

            # Gradient for W1, b1 (through ReLU)
            grad_h = grad_logits @ self.W2.T
            grad_h *= (h > 0).astype(np.float32)  # ReLU derivative
            grad_W1 = X.T @ grad_h
            grad_b1 = grad_h.sum(axis=0)

            # Update (vanilla SGD, no momentum/decay)
            self.W1 -= lr * grad_W1
            self.b1 -= lr * grad_b1
            self.W2 -= lr * grad_W2
            self.b2 -= lr * grad_b2

            if (epoch + 1) % 20 == 0:
                # Compute accuracy (argmax of last forward pass vs. taken action)
                predicted = np.argmax(probs, axis=1)
                accuracy = np.mean(predicted == actions)
                logger.info(f" Epoch {epoch + 1}: loss={loss:.4f}, accuracy={accuracy:.3f}")

        self.trained = True
        logger.info("Policy network training complete")

    def get_action(
        self,
        case: MedicalCase,
        acquired: list[str],
        differential: list[dict],
        step_idx: int,
    ) -> str:
        """Select the best channel to request using the learned policy.

        Returns "COMMIT" when every requestable channel has already been
        acquired; otherwise the argmax-probability available channel (with
        a uniform random fallback if none of the available channels are in
        the vocabulary).
        """
        available = [ch for ch in case.requestable_names if ch not in acquired]
        if not available:
            return "COMMIT"

        top1_conf = differential[0]["confidence"] if differential else 0.5
        top2_conf = differential[1]["confidence"] if len(differential) > 1 else 0.0

        features = self.featurize(
            dataset=case.dataset,
            acquired=acquired,
            top1_confidence=top1_conf,
            top2_confidence=top2_conf,
            step_idx=step_idx,
        )

        probs = self.predict(features, available)

        if not probs:
            return random.choice(available)

        # Select highest probability channel (greedy, no sampling)
        best_channel = max(probs, key=probs.get)
        return best_channel

    def save(self, path: Path):
        """Save model weights (plus channel/dataset vocabularies) via np.savez."""
        np.savez(
            path,
            W1=self.W1, b1=self.b1,
            W2=self.W2, b2=self.b2,
            channels=self.all_channels,
            datasets=self.all_datasets,
        )
        logger.info(f"Saved policy network to {path}")

    def load(self, path: Path):
        """Load model weights.

        NOTE(review): only the weight matrices are restored — the saved
        ``channels``/``datasets`` arrays are not read back, so the caller
        must construct this instance with the same vocabularies used at
        save time for the indices to line up; verify at call sites.
        """
        data = np.load(path, allow_pickle=True)
        self.W1 = data["W1"]
        self.b1 = data["b1"]
        self.W2 = data["W2"]
        self.b2 = data["b2"]
        self.trained = True
        logger.info(f"Loaded policy network from {path}")
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ================================================================
|
| 453 |
+
# Approach 3: Self-Reflective Refinement
|
| 454 |
+
# ================================================================
|
| 455 |
+
|
| 456 |
+
class SelfReflectivePolicy:
    """
    The VLM critiques its own past failures and generates improved strategies.

    Pipeline:
      1. Collect cases where zero-shot acquisition was suboptimal
         (the agent requested info that didn't help, or missed info that would have)
      2. Show the VLM its own failure traces and ask it to generate
         "acquisition rules" — structured if-then policies
      3. Inject these self-generated rules into the system prompt
      4. Re-run with the improved prompt

    This is a form of self-play / self-improvement via reflection.
    """

    def __init__(self, client: BaseVLMClient, dataset_name: str):
        # VLM client used for the one-shot rule-generation call;
        # dataset_name scopes which trajectories are analyzed.
        self.client = client
        self.dataset_name = dataset_name
        # Parsed "RULE N: IF ... THEN ... BECAUSE ..." strings.
        self.rules: list[str] = []

    def generate_rules_from_failures(
        self,
        trajectories: list[Trajectory],
        n_failure_examples: int = 10,
    ) -> list[str]:
        """
        Analyze failures and generate acquisition rules.

        A "failure" is a case where:
          - Agent requested a channel with zero or negative utility
          - Agent didn't request a channel that would have helped
          - Agent committed too early (final MRR << oracle MRR)

        Side effects: makes one text-only VLM call, stores the parsed
        rules on ``self.rules``, and returns them (empty list when no
        failures are found).
        """
        # Collect failure examples
        failures = []

        for traj in trajectories:
            # Only learn from trajectories in this policy's dataset.
            if traj.dataset != self.dataset_name:
                continue

            # Type 1: Unhelpful acquisitions — non-COMMIT steps whose
            # utility reward was zero or negative.
            for step in traj.steps:
                if step.action != "COMMIT" and step.utility_reward <= 0:
                    failures.append({
                        "type": "unhelpful_acquisition",
                        "case_id": traj.case_id,
                        "action": step.action,
                        "uncertainty": step.uncertainty_text[:200],
                        "utility_reward": step.utility_reward,
                        "mrr_reward": step.reward,
                        "cost": step.acquisition_cost,
                        "available": step.available_channels,
                    })

            # Type 2: Premature commitment — final MRR trails the oracle
            # MRR by more than 0.2.
            if traj.final_mrr < traj.oracle_mrr - 0.2:
                failures.append({
                    "type": "premature_commit",
                    "case_id": traj.case_id,
                    "acquired": [s.action for s in traj.steps if s.action != "COMMIT"],
                    "final_mrr": traj.final_mrr,
                    "oracle_mrr": traj.oracle_mrr,
                    "gap": traj.oracle_mrr - traj.final_mrr,
                })

        if not failures:
            logger.info("No failures found — zero-shot policy may already be strong")
            return []

        # Sample failures (in-place shuffle uses the module-global `random`
        # state, so results are reproducible under a fixed seed)
        random.shuffle(failures)
        sampled = failures[:n_failure_examples]

        # Ask the VLM to analyze and generate rules
        # (default=str keeps json.dumps from failing on non-JSON types)
        failure_text = json.dumps(sampled, indent=2, default=str)

        prompt = f"""You are analyzing an AI medical diagnostic agent's acquisition failures on {self.dataset_name} cases.
The agent must decide what additional information to request (imaging modalities, clinical data, etc.) before making a diagnosis.

Here are examples of FAILED acquisition decisions:

{failure_text}

Based on these failures, generate 5-8 specific, actionable ACQUISITION RULES that would improve future decisions.

Format each rule as:
RULE N: IF [condition about the current state/uncertainty] THEN [specific acquisition action] BECAUSE [reasoning]

Rules should be specific to the {self.dataset_name} dataset and its available channels.
Focus on patterns across failures, not individual cases.
Be concrete — "request OCT when uncertain about subretinal fluid" is better than "request more information when uncertain."

Respond ONLY with the rules, no preamble."""

        response = self.client.call_with_retry(
            system_prompt="You are an expert in medical diagnostic AI systems.",
            user_text=prompt,
            images=None,
            temperature=0.3,
            max_tokens=2048,
        )

        # Parse rules: a line starting with "RULE"/"Rule" opens a new rule;
        # any later non-empty line is folded into the current rule.
        rules = []
        for line in response.text.split("\n"):
            line = line.strip()
            if line.startswith("RULE") or line.startswith("Rule"):
                rules.append(line)
            elif rules and line and not line.startswith("RULE"):
                # Continuation of previous rule
                rules[-1] += " " + line

        self.rules = rules
        logger.info(f"Generated {len(rules)} acquisition rules from {len(sampled)} failures")
        for r in rules:
            logger.info(f" {r[:120]}...")

        return rules

    def get_enhanced_system_prompt(self, base_prompt: str) -> str:
        """
        Inject learned rules into the system prompt.

        This is the key mechanism: the VLM's behavior is modified
        by giving it its own self-generated rules as instructions.
        Returns ``base_prompt`` unchanged when no rules are loaded.
        """
        if not self.rules:
            return base_prompt

        rules_text = "\n".join(self.rules)
        injection = f"""

LEARNED ACQUISITION STRATEGY (from analyzing past diagnostic cases):
The following rules have been learned from analyzing cases where acquisition
decisions were suboptimal. Apply these rules when deciding what information to request:

{rules_text}

Apply these rules in addition to your general diagnostic reasoning."""

        return base_prompt + injection

    def save_rules(self, path: Path):
        """Save generated rules as JSON ({"dataset": ..., "rules": [...]})."""
        with open(path, "w") as f:
            json.dump({"dataset": self.dataset_name, "rules": self.rules}, f, indent=2)

    def load_rules(self, path: Path):
        """Load previously generated rules (the "rules" key of the saved JSON)."""
        with open(path) as f:
            data = json.load(f)
        self.rules = data["rules"]
        logger.info(f"Loaded {len(self.rules)} rules for {self.dataset_name}")
|
prompts.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt templates for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Three semantically equivalent but lexically different variants (A/B/C)
|
| 5 |
+
for prompt sensitivity analysis.
|
| 6 |
+
|
| 7 |
+
Each prompt has:
|
| 8 |
+
- system_prompt: Sets the agent's role and reasoning format
|
| 9 |
+
- acquisition_prompt: Asks the agent to decide what to request next
|
| 10 |
+
- diagnosis_prompt: Asks the agent to commit to a ranked differential
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
# ============================================================
|
| 14 |
+
# Channel description formatters
|
| 15 |
+
# ============================================================
|
| 16 |
+
|
| 17 |
+
def format_available_channels(channels: dict, already_acquired: list[str]) -> str:
    """Format the list of requestable channels for the prompt.

    Channels marked ``always_given`` or already acquired are skipped; the
    remainder are listed sorted by (order, cost, name).

    Args:
        channels: Mapping of channel name -> info dict. ``description`` is
            required; ``always_given``, ``order``, ``cost``, ``tier`` are
            optional (defaults: False, 999, 0.0, "unknown").
        already_acquired: Channel names already requested this episode.

    Returns:
        A newline-joined bullet list, or a placeholder line when nothing
        is left to request.
    """
    sortable = [
        # Cast cost once up front so sorting never compares mixed types.
        (info.get("order", 999), float(info.get("cost", 0.0)), name, info)
        for name, info in channels.items()
        if not info.get("always_given") and name not in already_acquired
    ]
    # Sort only on (order, cost, name): the trailing info dict is payload
    # and must never participate in comparisons (dicts are unorderable).
    sortable.sort(key=lambda item: item[:3])

    lines = []
    for _, cost, name, info in sortable:
        tier = info.get("tier", "unknown")
        lines.append(
            f" - [{name}]: {info['description']} "
            f"(tier: {tier}, cost: ${cost:,.0f})"
        )
    if not lines:
        return " (No additional information available to request.)"
    return "\n".join(lines)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def format_acquired_info(acquired_data: dict) -> str:
    """Format all previously acquired information for context."""
    if not acquired_data:
        return "(No additional information acquired yet.)"
    rendered = []
    for name, payload in acquired_data.items():
        kind = payload["type"]
        if kind == "text":
            rendered.append(f"[{name}]: {payload['value']}")
        elif kind == "image":
            rendered.append(f"[{name}]: (image provided)")
    return "\n".join(rendered)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ============================================================
|
| 53 |
+
# Prompt Variant A — Clinical Framing
|
| 54 |
+
# ============================================================
|
| 55 |
+
|
| 56 |
+
# Variant A — clinical framing. Forces the agent to spend its entire
# acquisition budget (no early COMMIT path in the acquisition prompt).
VARIANT_A = {
    # Identifier recorded alongside results.
    "name": "clinical",

    # Physician role + OBSERVATION/DIFFERENTIAL/UNCERTAINTY/ACTION format.
    "system_prompt": """You are an experienced physician performing a diagnostic evaluation. \
You will be shown a medical image and possibly additional clinical information. \
Your goal is to arrive at the most accurate diagnosis by strategically requesting \
the most informative additional data.

You reason through cases using a structured clinical approach:
1. OBSERVATION: Describe what you see in the available image(s) and data.
2. DIFFERENTIAL: List your top 3-5 candidate diagnoses ranked by likelihood, with confidence estimates (0-1).
3. UNCERTAINTY: Identify specifically what you are uncertain about — which diagnoses cannot be distinguished with current information and WHY.
4. ACTION: You MUST request one additional piece of information. Choose the one that would best disambiguate your top differential diagnoses.

CRITICAL: You must ALWAYS use your remaining budget to request information. \
Do NOT commit early — additional information almost always improves diagnostic accuracy. \
Always respond in this exact structured format.""",

    # Format placeholders: {remaining_budget}, {available_channels},
    # {acquired_info}. Mandates "ACTION: REQUEST [channel_name]".
    "acquisition_prompt": """You have {remaining_budget} request(s) remaining. You MUST use them.

Available information you can request:
{available_channels}

Previously acquired information:
{acquired_info}

Think carefully: which available channel would MOST help distinguish between your top diagnoses?

Respond in EXACTLY this format:
OBSERVATION: [What you observe from all currently available information]
DIFFERENTIAL: [Ranked list — format each as "N. DiagnosisName (confidence: X.XX)"]
UNCERTAINTY: [Which two diagnoses are hardest to tell apart, and what specific information would resolve it]
ACTION: REQUEST [channel_name]

IMPORTANT: Replace [channel_name] with exactly one of the available channel names listed above. \
You MUST request a channel — do not skip or commit early.""",

    # Format placeholders: {acquired_info}, {candidates}.
    "diagnosis_prompt": """You strategically gathered the most relevant clinical information. \
Now provide your final diagnosis. Focus on the evidence you acquired — it was selected \
specifically to resolve diagnostic uncertainty.

Information you gathered:
{acquired_info}

Candidate diagnoses to rank:
{candidates}

Respond in the structured format:
OBSERVATION: [Synthesis of the key findings from your acquired information]
DIFFERENTIAL: [Ranked candidates — format each as "N. DiagnosisName (confidence: X.XX)"]
REASONING: [Key evidence from your acquired data supporting your top diagnosis and ruling out alternatives]""",
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ============================================================
|
| 111 |
+
# Prompt Variant B — Information-Theoretic Framing
|
| 112 |
+
# ============================================================
|
| 113 |
+
|
| 114 |
+
# Variant B — information-theoretic framing. Unlike Variant A, this variant
# allows early stopping via "ACQUISITION: FINALIZE" when top probability > 0.8.
VARIANT_B = {
    # Identifier recorded alongside results.
    "name": "information_theoretic",

    # EVIDENCE/HYPOTHESES/INFORMATION GAP/ACQUISITION format.
    "system_prompt": """You are an AI diagnostic system analyzing medical data under \
conditions of incomplete information. You process available evidence and estimate which \
additional data sources would most reduce your diagnostic uncertainty.

Your reasoning follows a structured protocol:
1. EVIDENCE: Catalog the findings from all available inputs.
2. HYPOTHESES: Rank candidate diagnoses by posterior probability (0-1, must sum to ≤1).
3. INFORMATION GAP: Identify the highest-uncertainty region of your hypothesis space.
4. ACQUISITION: Select the data source with highest expected information gain, or finalize.

Always respond in this exact structured format. Be precise with probabilities.""",

    # Format placeholders: {remaining_budget}, {available_channels},
    # {acquired_info}.
    "acquisition_prompt": """Analyze your current diagnostic uncertainty and determine the \
optimal next data acquisition. You have {remaining_budget} acquisition(s) remaining.

Requestable data sources:
{available_channels}

Previously acquired data:
{acquired_info}

Respond in the structured format:
EVIDENCE: [Findings extracted from all currently available data]
HYPOTHESES: [Ranked list — format each as "N. DiagnosisName (probability: X.XX)"]
INFORMATION GAP: [Which distinction between top hypotheses cannot be resolved with current data, and why]
ACQUISITION: REQUEST [channel_name] — [expected information gain explanation]

If your top hypothesis probability exceeds 0.8 and is well-separated from alternatives:
ACQUISITION: FINALIZE""",

    # Format placeholders: {acquired_info}, {candidates}.
    "diagnosis_prompt": """All data acquisition is complete. Produce your final ranked \
hypothesis set.

Accumulated data:
{acquired_info}

Candidate diagnoses to rank:
{candidates}

Respond in the structured format:
EVIDENCE: [Complete synthesis of all acquired data]
HYPOTHESES: [Final ranked candidates — format each as "N. DiagnosisName (probability: X.XX)"]
JUSTIFICATION: [Evidence chain supporting top hypothesis; contradicting evidence for alternatives]""",
}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ============================================================
|
| 164 |
+
# Prompt Variant C — Neutral/Minimal Framing
|
| 165 |
+
# ============================================================
|
| 166 |
+
|
| 167 |
+
# Variant C — neutral/minimal framing. Shortest instructions; allows early
# stopping via "DECISION: COMMIT".
VARIANT_C = {
    # Identifier recorded alongside results.
    "name": "neutral",

    # FINDINGS/RANKING/GAPS/DECISION format.
    "system_prompt": """You are assisting with medical image analysis. Given a medical image \
and possibly additional information, identify the most likely diagnosis from a set of candidates.

You may request additional information before making your final decision. Structure your \
response as follows:
1. FINDINGS: What you observe.
2. RANKING: Candidate diagnoses ranked with confidence scores (0-1).
3. GAPS: What you don't know that would help.
4. DECISION: Request more info or commit.""",

    # Format placeholders: {remaining_budget}, {available_channels},
    # {acquired_info}.
    "acquisition_prompt": """You may request one more piece of information. \
{remaining_budget} request(s) left.

Options:
{available_channels}

Information so far:
{acquired_info}

Respond:
FINDINGS: [Current observations]
RANKING: [Format: "N. DiagnosisName (confidence: X.XX)"]
GAPS: [What's missing]
DECISION: REQUEST [channel_name] — [reason]

Or if ready:
DECISION: COMMIT""",

    # Format placeholders: {acquired_info}, {candidates}.
    "diagnosis_prompt": """Provide your final diagnosis ranking.

All information:
{acquired_info}

Candidates:
{candidates}

Respond:
FINDINGS: [Summary]
RANKING: [Format: "N. DiagnosisName (confidence: X.XX)"]
REASONING: [Brief justification]""",
}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ============================================================
|
| 214 |
+
# Variant Registry
|
| 215 |
+
# ============================================================
|
| 216 |
+
|
| 217 |
+
# Registry of prompt variants for sensitivity analysis, keyed by the
# single-letter ID used throughout experiment configuration.
PROMPT_VARIANTS = {
    "A": VARIANT_A,  # clinical framing
    "B": VARIANT_B,  # information-theoretic framing
    "C": VARIANT_C,  # neutral/minimal framing
}
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_prompt_variant(variant_id: str) -> dict:
    """Return the prompt-template dict registered under *variant_id*.

    Raises:
        ValueError: if *variant_id* is not a key in PROMPT_VARIANTS.
    """
    try:
        return PROMPT_VARIANTS[variant_id]
    except KeyError:
        raise ValueError(f"Unknown prompt variant: {variant_id}. Choose from {list(PROMPT_VARIANTS.keys())}") from None
|
reasoning_analysis.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reasoning Faithfulness & Acquisition Pattern Analysis.
|
| 3 |
+
|
| 4 |
+
Key analyses for ACL/EMNLP submission:
|
| 5 |
+
|
| 6 |
+
1. Reasoning Faithfulness: Does the agent's stated reasoning match
|
| 7 |
+
actual information gain? When it says "I need X to distinguish
|
| 8 |
+
A from B", does X actually shift probability between A and B?
|
| 9 |
+
|
| 10 |
+
2. Acquisition Order Patterns: What ordering strategies do different
|
| 11 |
+
models learn? Are they consistent? Do they match clinical guidelines?
|
| 12 |
+
|
| 13 |
+
3. Error Analysis: When the agent commits early and is wrong, what
|
| 14 |
+
went wrong in the reasoning chain?
|
| 15 |
+
|
| 16 |
+
4. Stopping Decision Quality: Are the agent's commit decisions well-timed?
|
| 17 |
+
"""
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import re
|
| 21 |
+
from collections import Counter, defaultdict
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
from scipy.stats import spearmanr, kendalltau
|
| 27 |
+
|
| 28 |
+
from agent import AgentResult, AcquisitionStep
|
| 29 |
+
from datasets.base import MedicalCase
|
| 30 |
+
from information_gain import compute_entropy, compute_kl_divergence
|
| 31 |
+
from evaluation import evaluate_single_case, compute_reciprocal_rank
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ================================================================
|
| 37 |
+
# 1. Reasoning Faithfulness
|
| 38 |
+
# ================================================================
|
| 39 |
+
|
| 40 |
+
@dataclass
class FaithfulnessMetrics:
    """Per-step reasoning faithfulness measurement.

    Records, for one acquisition step, what the agent *claimed* a channel
    would do (its ``expected_impact``) alongside the probability shift
    actually observed once the channel was acquired.
    """
    # Identity of the measured step
    case_id: str
    step: int
    channel_requested: str
    # What the agent said
    stated_reasoning: str
    stated_if_positive: str
    stated_if_negative: str
    # What actually happened
    target_diagnosis_before: float  # Probability of stated target before
    target_diagnosis_after: float  # Probability of stated target after
    actual_shift: float  # Change in target probability
    shift_direction_correct: bool  # Did it shift the way the agent predicted?
    # Information metrics
    entropy_before: float
    entropy_after: float
    actual_ig: float  # entropy_before - entropy_after (bits)
    predicted_useful: bool  # Agent thought this would help
    actually_useful: bool  # IG > 0.05 bits
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compute_reasoning_faithfulness(
    results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Measure whether the agent's stated reasoning matches what actually
    happens when information is acquired.

    For each acquisition step where the agent states expected_impact:
    - Extract the target diagnosis (if_positive/if_negative)
    - Compare probability of that diagnosis before and after acquisition
    - Check if the shift matches the agent's prediction

    Returns aggregate faithfulness metrics.

    NOTE(review): ``total_with_impact`` is incremented before the
    next-step-differential check, so ``n_steps_analyzed`` can be smaller
    than ``n_with_expected_impact`` (steps lacking a follow-up
    differential are counted but not analyzed).
    """
    per_step_metrics: list[FaithfulnessMetrics] = []
    direction_correct_count = 0
    useful_when_predicted = 0
    total_with_impact = 0

    for result, case in zip(results, cases):
        for i, step in enumerate(result.steps):
            # Only acquisition steps with a stated expected_impact are scored.
            if step.committed or not step.expected_impact:
                continue
            if not step.differential:
                continue

            total_with_impact += 1

            # Current distribution (before receiving the info)
            current_dist = {
                d.get("name", ""): d.get("confidence", 0)
                for d in step.differential
            }

            # Find next step's distribution (after receiving the info)
            next_dist = None
            if i + 1 < len(result.steps) and result.steps[i + 1].differential:
                next_dist = {
                    d.get("name", ""): d.get("confidence", 0)
                    for d in result.steps[i + 1].differential
                }

            if next_dist is None:
                continue

            # Get the target diagnosis from expected_impact
            pos_target = step.expected_impact.get("if_positive", "")
            neg_target = step.expected_impact.get("if_negative", "")

            # Find probability of positive target before and after
            pos_before = _fuzzy_lookup(current_dist, pos_target)
            pos_after = _fuzzy_lookup(next_dist, pos_target)
            neg_before = _fuzzy_lookup(current_dist, neg_target)
            neg_after = _fuzzy_lookup(next_dist, neg_target)

            # The agent predicted that this channel would help distinguish
            # between pos_target and neg_target. Did the gap widen?
            gap_before = abs(pos_before - neg_before)
            gap_after = abs(pos_after - neg_after)
            gap_widened = gap_after > gap_before

            # Did probability shift in the stated direction?
            actual_shift = pos_after - pos_before
            shift_correct = gap_widened  # More discriminating = correct prediction

            if shift_correct:
                direction_correct_count += 1

            # Was the channel actually useful?
            entropy_before = compute_entropy(current_dist)
            entropy_after = compute_entropy(next_dist)
            actual_ig = entropy_before - entropy_after
            actually_useful = actual_ig > 0.05

            if actually_useful:
                useful_when_predicted += 1

            # NOTE(review): predicted_useful is hardcoded True — every
            # analyzed request is treated as an implicit usefulness claim,
            # so utility_precision below measures "requesting implies useful".
            metrics = FaithfulnessMetrics(
                case_id=result.case_id,
                step=step.step,
                channel_requested=step.requested_channel or "",
                stated_reasoning=step.reasoning[:200],
                stated_if_positive=pos_target,
                stated_if_negative=neg_target,
                target_diagnosis_before=pos_before,
                target_diagnosis_after=pos_after,
                actual_shift=actual_shift,
                shift_direction_correct=shift_correct,
                entropy_before=entropy_before,
                entropy_after=entropy_after,
                actual_ig=actual_ig,
                predicted_useful=True,
                actually_useful=actually_useful,
            )
            per_step_metrics.append(metrics)

    n = len(per_step_metrics)
    return {
        "n_steps_analyzed": n,
        "n_with_expected_impact": total_with_impact,
        "direction_accuracy": direction_correct_count / n if n > 0 else 0,
        "utility_precision": useful_when_predicted / n if n > 0 else 0,
        "mean_actual_ig": float(np.mean([m.actual_ig for m in per_step_metrics])) if per_step_metrics else 0,
        "mean_absolute_shift": float(np.mean([abs(m.actual_shift) for m in per_step_metrics])) if per_step_metrics else 0,
        "per_step_details": [
            {
                "case_id": m.case_id,
                "step": m.step,
                "channel": m.channel_requested,
                "direction_correct": m.shift_direction_correct,
                "actual_ig": round(m.actual_ig, 4),
                "actually_useful": m.actually_useful,
                "stated_reasoning": m.stated_reasoning[:100],
            }
            for m in per_step_metrics
        ],
    }
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ================================================================
|
| 184 |
+
# 2. Acquisition Order Patterns
|
| 185 |
+
# ================================================================
|
| 186 |
+
|
| 187 |
+
def analyze_acquisition_orders(
    results: list[AgentResult],
    cases: list[MedicalCase],
    clinical_order: dict[str, list[str]] | None = None,
) -> dict:
    """
    Analyze what acquisition ordering strategies the agent uses.

    Returns:
        - Most common first/second/third channel requests
        - Order consistency across cases
        - Correlation with clinical guideline order
        - Correlation between acquisition order and case difficulty
    """
    # Imported inside the function — presumably to avoid a circular import
    # with baselines; confirm before hoisting to module level.
    from baselines import CLINICAL_GUIDELINE_ORDER
    if clinical_order is None:
        clinical_order = CLINICAL_GUIDELINE_ORDER

    # Collect all acquisition sequences
    sequences = []
    first_requests = Counter()
    second_requests = Counter()
    full_sequences = Counter()

    for result in results:
        seq = tuple(result.acquired_channels)
        sequences.append(seq)
        full_sequences[seq] += 1

        if len(seq) >= 1:
            first_requests[seq[0]] += 1
        if len(seq) >= 2:
            second_requests[seq[1]] += 1

    n = len(sequences)

    # Consistency: what fraction of cases share the most common first request?
    most_common_first = first_requests.most_common(1)
    first_consistency = most_common_first[0][1] / n if most_common_first and n > 0 else 0

    # Unique sequences
    n_unique = len(full_sequences)

    # Correlation with clinical guideline order
    guideline_correlations = []
    for result, case in zip(results, cases):
        ds = case.dataset
        if ds not in clinical_order:
            continue

        gl_order = clinical_order[ds]
        agent_order = result.acquired_channels

        if len(agent_order) < 2:
            continue

        # Compute rank correlation
        # Map channels to their guideline rank
        gl_ranks = {ch: i for i, ch in enumerate(gl_order)}
        # NOTE(review): agent_ranks is computed but never used below.
        agent_ranks = {ch: i for i, ch in enumerate(agent_order)}

        common = set(agent_order) & set(gl_order)
        if len(common) < 2:
            continue

        # agent_order is already in acquisition order, so agent positions
        # are implicitly 0..k-1 (ag_r); compare those against guideline ranks.
        gl_r = [gl_ranks.get(ch, len(gl_order)) for ch in agent_order if ch in common]
        ag_r = list(range(len(gl_r)))

        if len(gl_r) >= 2:
            corr, pval = spearmanr(gl_r, ag_r)  # pval intentionally unused
            if not np.isnan(corr):
                guideline_correlations.append(corr)

    # Cost efficiency: does the agent prefer cheaper channels first?
    cost_order_correlations = []
    for result, case in zip(results, cases):
        if len(result.acquired_channels) < 2:
            continue

        costs = [case.get_channel_cost(ch) for ch in result.acquired_channels]
        positions = list(range(len(costs)))

        # Rank correlation is undefined when every cost ties, so skip those.
        if len(set(costs)) > 1:
            corr, _ = spearmanr(costs, positions)
            if not np.isnan(corr):
                cost_order_correlations.append(corr)

    return {
        "n_cases": n,
        "n_unique_sequences": n_unique,
        "sequence_entropy": _sequence_entropy(full_sequences, n),
        "first_request_distribution": dict(first_requests.most_common()),
        "first_request_consistency": first_consistency,
        "second_request_distribution": dict(second_requests.most_common()),
        "most_common_sequences": [
            {"sequence": list(seq), "count": count}
            for seq, count in full_sequences.most_common(5)
        ],
        "guideline_correlation": {
            "mean": float(np.mean(guideline_correlations)) if guideline_correlations else None,
            "std": float(np.std(guideline_correlations)) if guideline_correlations else None,
            "n_comparable": len(guideline_correlations),
        },
        "cost_order_correlation": {
            "mean": float(np.mean(cost_order_correlations)) if cost_order_correlations else None,
            "std": float(np.std(cost_order_correlations)) if cost_order_correlations else None,
            "interpretation": (
                "positive = cheaper first, negative = expensive first"
            ),
        },
        "mean_channels_acquired": float(np.mean([len(s) for s in sequences])),
    }
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# ================================================================
|
| 302 |
+
# 3. Error Analysis
|
| 303 |
+
# ================================================================
|
| 304 |
+
|
| 305 |
+
@dataclass
class ErrorCase:
    """Detailed analysis of a single error case (agent top-1 != ground truth)."""
    case_id: str
    ground_truth: str
    agent_top1: str  # Name of the agent's top-ranked diagnosis
    agent_confidence: float  # Confidence attached to that top-1 diagnosis
    n_acquired: int  # Number of channels the agent requested
    acquired_channels: list[str]
    committed_early: bool  # Agent committed before exhausting channels
    missed_channels: list[str]  # Requestable channels never acquired
    error_type: str  # "overconfident_early", "wrong_after_all", "insufficient_info"
    reasoning_chain: list[str]  # Truncated per-step reasoning snippets
    entropy_at_commit: float  # Differential entropy at the last recorded step
    final_ig_trend: str  # "increasing", "decreasing", "plateau"
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def analyze_errors(
    results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Detailed error analysis: when and why the agent gets cases wrong.

    Categorizes errors into:
    1. Overconfident early commit — committed before gathering enough info
    2. Wrong after all info — had all info but still wrong (reasoning failure)
    3. Insufficient info — didn't have the right channels (missing key evidence)

    NOTE(review): results with an empty final_ranking are skipped entirely,
    yet still appear in the ``accuracy`` denominator (``total``), so they
    are implicitly counted as incorrect without appearing in the error list.
    """
    errors = []
    correct_count = 0
    total = len(results)

    for result, case in zip(results, cases):
        if not result.final_ranking:
            continue

        top = result.final_ranking[0]
        top_name = top.get("name", "").strip().lower()
        gt = case.ground_truth.strip().lower()
        # Bidirectional substring containment serves as fuzzy name matching.
        correct = top_name == gt or top_name in gt or gt in top_name

        if correct:
            correct_count += 1
            continue

        # Classify error type
        all_requestable = set(case.requestable_channels.keys())
        acquired = set(result.acquired_channels)
        missed = list(all_requestable - acquired)

        if result.committed_early and missed:
            error_type = "overconfident_early"
        elif not missed:
            error_type = "wrong_after_all"
        else:
            error_type = "insufficient_info"

        # Extract reasoning chain
        reasoning_chain = []
        for step in result.steps:
            if step.reasoning:
                reasoning_chain.append(
                    f"Step {step.step}: {step.reasoning[:150]}"
                )

        # Entropy trend — zero entropies are treated as missing and dropped.
        entropies = [s.entropy for s in result.steps if s.entropy > 0]
        if len(entropies) >= 2:
            diffs = [entropies[i+1] - entropies[i] for i in range(len(entropies)-1)]
            if all(d <= 0 for d in diffs):
                trend = "decreasing"
            elif all(d >= 0 for d in diffs):
                trend = "increasing"
            else:
                trend = "non_monotonic"
        else:
            trend = "insufficient_data"

        entropy_at_commit = entropies[-1] if entropies else 0.0

        error = ErrorCase(
            case_id=result.case_id,
            ground_truth=case.ground_truth,
            agent_top1=top.get("name", ""),
            agent_confidence=top.get("confidence", 0),
            n_acquired=len(result.acquired_channels),
            acquired_channels=result.acquired_channels,
            committed_early=result.committed_early,
            missed_channels=missed,
            error_type=error_type,
            reasoning_chain=reasoning_chain,
            entropy_at_commit=entropy_at_commit,
            final_ig_trend=trend,
        )
        errors.append(error)

    # Aggregate by error type
    type_counts = Counter(e.error_type for e in errors)
    n_errors = len(errors)

    # Confidence distribution for errors vs correct
    error_confidences = [e.agent_confidence for e in errors]

    return {
        "n_total": total,
        "n_correct": correct_count,
        "n_errors": n_errors,
        "accuracy": correct_count / total if total > 0 else 0,
        "error_type_distribution": {
            "overconfident_early": type_counts.get("overconfident_early", 0),
            "wrong_after_all": type_counts.get("wrong_after_all", 0),
            "insufficient_info": type_counts.get("insufficient_info", 0),
        },
        "error_type_rates": {
            etype: count / n_errors if n_errors > 0 else 0
            for etype, count in type_counts.items()
        },
        "mean_error_confidence": float(np.mean(error_confidences)) if error_confidences else 0,
        "mean_error_channels_acquired": float(np.mean([e.n_acquired for e in errors])) if errors else 0,
        "entropy_at_commit": {
            "mean": float(np.mean([e.entropy_at_commit for e in errors])) if errors else 0,
            "std": float(np.std([e.entropy_at_commit for e in errors])) if errors else 0,
        },
        "ig_trend_distribution": dict(Counter(e.final_ig_trend for e in errors)),
        "per_case_errors": [
            {
                "case_id": e.case_id,
                "ground_truth": e.ground_truth,
                "predicted": e.agent_top1,
                "confidence": e.agent_confidence,
                "error_type": e.error_type,
                "n_acquired": e.n_acquired,
                "missed": e.missed_channels,
                "committed_early": e.committed_early,
                "entropy_at_commit": round(e.entropy_at_commit, 3),
            }
            for e in errors
        ],
    }
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
# ================================================================
|
| 448 |
+
# 4. Stopping Decision Quality
|
| 449 |
+
# ================================================================
|
| 450 |
+
|
| 451 |
+
def analyze_stopping_decisions(
    results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Analyze whether the agent's commit decisions are well-timed.

    Compares:
    - Cases where agent committed early and was correct (good early stop)
    - Cases where agent committed early and was wrong (premature stop)
    - Cases that used all channels (necessary thoroughness vs wasted budget)
    """
    # Four-way split: (early vs full budget) x (correct vs wrong top-1)
    early_correct = []
    early_wrong = []
    full_correct = []
    full_wrong = []

    for result, case in zip(results, cases):
        if not result.final_ranking:
            continue

        top = result.final_ranking[0]
        top_name = top.get("name", "").strip().lower()
        gt = case.ground_truth.strip().lower()
        # Same bidirectional substring match used elsewhere in this module.
        correct = top_name == gt or top_name in gt or gt in top_name
        n_requestable = len(case.requestable_channels)
        n_acquired = len(result.acquired_channels)

        entry = {
            "case_id": result.case_id,
            "confidence": top.get("confidence", 0),
            "n_acquired": n_acquired,
            "n_available": n_requestable,
            "fraction_used": n_acquired / n_requestable if n_requestable > 0 else 1,
            "cost": result.acquisition_cost,
        }

        if result.committed_early:
            if correct:
                early_correct.append(entry)
            else:
                early_wrong.append(entry)
        else:
            if correct:
                full_correct.append(entry)
            else:
                full_wrong.append(entry)

    def _summarize(entries):
        # Mean statistics for one quadrant; {"count": 0} when empty.
        if not entries:
            return {"count": 0}
        return {
            "count": len(entries),
            "mean_confidence": float(np.mean([e["confidence"] for e in entries])),
            "mean_channels": float(np.mean([e["n_acquired"] for e in entries])),
            "mean_fraction_used": float(np.mean([e["fraction_used"] for e in entries])),
            "mean_cost": float(np.mean([e["cost"] for e in entries])),
        }

    total = len(results)
    early_rate = (len(early_correct) + len(early_wrong)) / total if total > 0 else 0
    # Precision of the early-commit decision: fraction of early commits
    # whose top-1 diagnosis was correct.
    early_precision = (
        len(early_correct) / (len(early_correct) + len(early_wrong))
        if (len(early_correct) + len(early_wrong)) > 0 else 0
    )

    return {
        "n_total": total,
        "early_commit_rate": early_rate,
        "early_commit_precision": early_precision,
        "early_correct": _summarize(early_correct),
        "early_wrong": _summarize(early_wrong),
        "full_correct": _summarize(full_correct),
        "full_wrong": _summarize(full_wrong),
        "cost_savings_from_early_commit": {
            "mean_cost_early": float(np.mean(
                [e["cost"] for e in early_correct + early_wrong]
            )) if (early_correct or early_wrong) else 0,
            "mean_cost_full": float(np.mean(
                [e["cost"] for e in full_correct + full_wrong]
            )) if (full_correct or full_wrong) else 0,
        },
    }
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
# ================================================================
|
| 537 |
+
# Full Analysis Pipeline
|
| 538 |
+
# ================================================================
|
| 539 |
+
|
| 540 |
+
def run_reasoning_analysis(
    results: list[AgentResult],
    cases: list[MedicalCase],
    save_dir: Path | None = None,
) -> dict:
    """Run all reasoning analyses and return combined results.

    Executes the four analyses in this module (faithfulness, acquisition
    orders, errors, stopping decisions), logging a one-line summary of
    each, and optionally writes the combined dict to
    ``save_dir / "reasoning_analysis.json"``.
    """
    logger.info("Running reasoning analysis...")

    faithfulness = compute_reasoning_faithfulness(results, cases)
    logger.info(
        f"  Faithfulness: direction_accuracy={faithfulness['direction_accuracy']:.3f}, "
        f"utility_precision={faithfulness['utility_precision']:.3f}"
    )

    orders = analyze_acquisition_orders(results, cases)
    logger.info(
        f"  Order patterns: {orders['n_unique_sequences']} unique sequences, "
        f"first_consistency={orders['first_request_consistency']:.3f}"
    )

    errors = analyze_errors(results, cases)
    logger.info(
        f"  Errors: {errors['n_errors']}/{errors['n_total']} "
        f"({errors['error_type_distribution']})"
    )

    stopping = analyze_stopping_decisions(results, cases)
    logger.info(
        f"  Stopping: early_rate={stopping['early_commit_rate']:.3f}, "
        f"early_precision={stopping['early_commit_precision']:.3f}"
    )

    output = {
        "reasoning_faithfulness": faithfulness,
        "acquisition_orders": orders,
        "error_analysis": errors,
        "stopping_decisions": stopping,
    }

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        # Remove non-serializable details for compact save
        # (round-trip with default=str coerces anything non-JSON to its str()).
        compact = json.loads(json.dumps(output, default=str))
        with open(save_dir / "reasoning_analysis.json", "w") as f:
            json.dump(compact, f, indent=2)
        logger.info(f"  Saved to {save_dir / 'reasoning_analysis.json'}")

    return output
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# ================================================================
|
| 591 |
+
# Helpers
|
| 592 |
+
# ================================================================
|
| 593 |
+
|
| 594 |
+
def _fuzzy_lookup(dist: dict, target: str) -> float:
|
| 595 |
+
"""Look up a diagnosis probability with fuzzy name matching."""
|
| 596 |
+
target_lower = target.lower().strip()
|
| 597 |
+
for name, prob in dist.items():
|
| 598 |
+
if target_lower in name.lower() or name.lower() in target_lower:
|
| 599 |
+
return prob
|
| 600 |
+
return 0.0
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def _sequence_entropy(counter: Counter, total: int) -> float:
|
| 604 |
+
"""Shannon entropy of sequence distribution (diversity measure)."""
|
| 605 |
+
if total == 0:
|
| 606 |
+
return 0.0
|
| 607 |
+
entropy = 0.0
|
| 608 |
+
for count in counter.values():
|
| 609 |
+
p = count / total
|
| 610 |
+
if p > 0:
|
| 611 |
+
entropy -= p * np.log2(p)
|
| 612 |
+
return float(entropy)
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=6.0.0
|
| 2 |
+
numpy
|
| 3 |
+
Pillow
|
| 4 |
+
scipy
|
| 5 |
+
openai
|
| 6 |
+
anthropic
|
| 7 |
+
together
|
| 8 |
+
python-dotenv
|
tools.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tool definitions for the ActiveMedAgent tool-use architecture.
|
| 3 |
+
|
| 4 |
+
Instead of parsing free-form text with regex, the agent makes structured
|
| 5 |
+
tool calls through the VLM's native function-calling interface. This:
|
| 6 |
+
1. Eliminates brittle parsing heuristics
|
| 7 |
+
2. Makes the agent a genuine tool-using system (not text completion + post-hoc extraction)
|
| 8 |
+
3. Provides formally verifiable action traces
|
| 9 |
+
4. Enables grounded information-theoretic analysis via structured probability reports
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ============================================================
|
| 18 |
+
# Tool Call Data Structures
|
| 19 |
+
# ============================================================
|
| 20 |
+
|
| 21 |
+
@dataclass
class ToolCall:
    """A single tool call extracted from a VLM response."""
    tool_name: str  # Name of the invoked tool (e.g. "request_information")
    arguments: dict[str, Any]  # Structured arguments supplied by the model
    call_id: str = ""  # Identifier used to correlate with a ToolResult.call_id
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ToolResult:
    """Result returned to the VLM after executing a tool."""
    call_id: str  # Matches the originating ToolCall.call_id
    content: str  # Textual result payload sent back to the model
    images: list[str] | None = None  # base64-encoded images
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ============================================================
|
| 38 |
+
# Tool Definitions (canonical format — translated per backend)
|
| 39 |
+
# ============================================================
|
| 40 |
+
|
| 41 |
+
# Canonical, backend-neutral tool schemas. The adapters below
# (to_openai_tools / to_anthropic_tools) translate these into each
# provider's wire format — edit the schemas here, not per-backend.
AGENT_TOOLS = [
    # Tool 1: request one additional information channel (costs budget).
    {
        "name": "request_information",
        "description": (
            "Request one additional information channel to reduce diagnostic uncertainty "
            "while avoiding unnecessary resource use. Call this when you need more data "
            "to distinguish between competing diagnoses and the expected benefit justifies "
            "the channel's cost. "
            "You must specify which channel to acquire and why it would resolve your "
            "current uncertainty."
        ),
        # JSON Schema for the tool arguments.
        "parameters": {
            "type": "object",
            "properties": {
                "channel_name": {
                    "type": "string",
                    "description": "Exact name of the channel to request (from the available list)",
                },
                "reasoning": {
                    "type": "string",
                    "description": "Why this channel best resolves your current diagnostic uncertainty",
                },
                # Structured probability report — enables information-theoretic
                # analysis of the agent's belief state at request time.
                "current_differential": {
                    "type": "array",
                    "description": "Your current ranked differential diagnosis with calibrated probabilities (must sum to 1.0)",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Diagnosis name"},
                            "probability": {
                                "type": "number",
                                "description": "Posterior probability (0-1), all must sum to 1.0",
                            },
                        },
                        "required": ["name", "probability"],
                    },
                },
                # Pre-registered expectations: what a positive vs. negative
                # finding on this channel would imply for the differential.
                "expected_impact": {
                    "type": "object",
                    "description": "What you expect this information to reveal",
                    "properties": {
                        "if_positive": {
                            "type": "string",
                            "description": "Which diagnosis becomes most likely if this channel shows positive/abnormal findings",
                        },
                        "if_negative": {
                            "type": "string",
                            "description": "Which diagnosis becomes most likely if this channel shows negative/normal findings",
                        },
                    },
                    "required": ["if_positive", "if_negative"],
                },
            },
            "required": ["channel_name", "reasoning", "current_differential", "expected_impact"],
        },
    },
    # Tool 2: stop acquiring and commit a final ranked diagnosis.
    {
        "name": "commit_diagnosis",
        "description": (
            "Commit to a final ranked diagnosis. Call this ONLY when you have exhausted "
            "the clinically useful information OR when your top diagnosis has probability "
            ">= 0.85 and is well-separated from alternatives. Prefer committing when "
            "remaining channels are unlikely to change management enough to justify cost."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "ranked_diagnoses": {
                    "type": "array",
                    "description": "Final ranked list of all candidate diagnoses with calibrated probabilities summing to 1.0",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "confidence": {
                                "type": "number",
                                "description": "Posterior probability (0-1)",
                            },
                            "key_evidence": {
                                "type": "string",
                                "description": "Most important evidence supporting or refuting this diagnosis",
                            },
                        },
                        "required": ["name", "confidence", "key_evidence"],
                    },
                },
                "reasoning": {
                    "type": "string",
                    "description": "Final diagnostic reasoning chain",
                },
            },
            "required": ["ranked_diagnoses", "reasoning"],
        },
    },
]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ============================================================
|
| 139 |
+
# Schema Translation
|
| 140 |
+
# ============================================================
|
| 141 |
+
|
| 142 |
+
def to_openai_tools(tools: list[dict] = None) -> list[dict]:
|
| 143 |
+
"""Convert canonical tool definitions to OpenAI function-calling format."""
|
| 144 |
+
if tools is None:
|
| 145 |
+
tools = AGENT_TOOLS
|
| 146 |
+
return [
|
| 147 |
+
{
|
| 148 |
+
"type": "function",
|
| 149 |
+
"function": {
|
| 150 |
+
"name": t["name"],
|
| 151 |
+
"description": t["description"],
|
| 152 |
+
"parameters": t["parameters"],
|
| 153 |
+
},
|
| 154 |
+
}
|
| 155 |
+
for t in tools
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def to_anthropic_tools(tools: list[dict] = None) -> list[dict]:
|
| 160 |
+
"""Convert canonical tool definitions to Anthropic tool-use format."""
|
| 161 |
+
if tools is None:
|
| 162 |
+
tools = AGENT_TOOLS
|
| 163 |
+
return [
|
| 164 |
+
{
|
| 165 |
+
"name": t["name"],
|
| 166 |
+
"description": t["description"],
|
| 167 |
+
"input_schema": t["parameters"],
|
| 168 |
+
}
|
| 169 |
+
for t in tools
|
| 170 |
+
]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def constrain_tools_for_step(budget_remaining: int, allow_commit: bool = True) -> list[dict]:
    """
    Return the tool subset the agent may call at the current step.

    - Budget exhausted (<= 0): only ``commit_diagnosis`` is offered.
    - Budget remaining and ``allow_commit``: the full tool set.
    - Budget remaining but commit disallowed: everything except
      ``commit_diagnosis`` (forces at least one more acquisition).
    """
    if budget_remaining <= 0:
        # No acquisitions left — the agent must commit now.
        return [t for t in AGENT_TOOLS if t["name"] == "commit_diagnosis"]
    if allow_commit:
        # Fresh list so callers can mutate without touching the canonical set.
        return list(AGENT_TOOLS)
    return [t for t in AGENT_TOOLS if t["name"] != "commit_diagnosis"]
|
trajectory.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trajectory Collection for ActiveMedAgent.
|
| 3 |
+
|
| 4 |
+
Phase 1 of the training pipeline:
|
| 5 |
+
1. Run zero-shot agent on all cases
|
| 6 |
+
2. Record full (state, action, reward) trajectories
|
| 7 |
+
3. Compute per-step rewards: did the acquisition improve the diagnosis?
|
| 8 |
+
4. Save trajectory dataset for Phase 2 policy learning
|
| 9 |
+
|
| 10 |
+
Each trajectory step records:
|
| 11 |
+
- state: current uncertainty, differential, acquired channels so far
|
| 12 |
+
- action: which channel was requested
|
| 13 |
+
- reward: MRR improvement after receiving the requested info
|
| 14 |
+
- outcome: final diagnosis correctness
|
| 15 |
+
"""
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import random
|
| 19 |
+
from dataclasses import dataclass, field, asdict
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
import config
|
| 26 |
+
from api_client import BaseVLMClient, create_client
|
| 27 |
+
from agent import ActiveMedAgent, AgentResult
|
| 28 |
+
from datasets.base import MedicalCase
|
| 29 |
+
from evaluation import compute_reciprocal_rank
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class TrajectoryStep:
    """A single step in an acquisition trajectory."""
    # Zero-based index of this step within the trajectory.
    step_idx: int
    # State representation
    # Channels already acquired before this step was taken.
    acquired_so_far: list[str]
    # Channels that were still requestable at this step.
    available_channels: list[str]
    # Agent's free-text reasoning at this step (may be empty).
    uncertainty_text: str
    differential_before: list[dict]  # Ranking before this acquisition
    mrr_before: float

    # Action
    action: str  # Channel name requested (or "COMMIT")

    # Outcome (computed after the action)
    differential_after: list[dict]  # Ranking after receiving the info
    mrr_after: float
    reward: float  # MRR improvement: mrr_after - mrr_before
    # Raw cost of the acquired channel (0.0 for COMMIT steps).
    acquisition_cost: float = 0.0
    # Cost divided by the case's max requestable cost, so it lies in [0, 1].
    normalized_cost: float = 0.0
    utility_reward: float = 0.0  # Cost-aware reward used for policy learning
    diagnosis_changed: bool = False  # Did top-1 change?
    diagnosis_improved: bool = False  # Did it change to the correct answer?
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
class Trajectory:
    """Complete trajectory for one case."""
    # Identifiers copied from the source MedicalCase.
    case_id: str
    dataset: str
    ground_truth: str
    candidates: list[str]
    # Ordered acquisition steps (last step may be a "COMMIT").
    steps: list[TrajectoryStep] = field(default_factory=list)
    # MRR with no acquisitions (image-only baseline).
    passive_mrr: float = 0.0
    # MRR with all channels provided (ceiling).
    oracle_mrr: float = 0.0
    # MRR of the state the agent actually ended in.
    final_mrr: float = 0.0
    # Sums over steps of `reward` and `utility_reward` respectively.
    total_reward: float = 0.0
    total_utility_reward: float = 0.0
    success: bool = False  # Did the agent get top-1 correct?
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TrajectoryCollector:
    """
    Collect acquisition trajectories with per-step rewards.

    Unlike the basic agent.diagnose(), this method runs the agent
    step-by-step, evaluating the diagnosis after EACH acquisition
    to compute fine-grained reward signals.

    Uses the tool-use agent architecture: runs the full agent for
    acquisition decisions, then evaluates intermediate states via
    the agent's get_diagnosis_at_state() helper.
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = 3,
    ):
        """Store the VLM client, prompt variant, and acquisition budget."""
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget

    def collect_trajectory(self, case: MedicalCase) -> Trajectory:
        """
        Collect a full trajectory with per-step rewards for one case.

        Strategy:
        1. Get passive baseline (image-only diagnosis)
        2. Get oracle ceiling (all-info diagnosis)
        3. Run the active agent and record its decisions
        4. For each acquisition step, evaluate the intermediate
           diagnosis to compute per-step MRR reward
        """
        traj = Trajectory(
            case_id=case.case_id,
            dataset=case.dataset,
            ground_truth=case.ground_truth,
            candidates=case.candidates,
        )

        # ---- Evaluation agent (budget=0, just for scoring) ----
        # With budget=0 this agent never acquires; it is reused below to
        # re-score the diagnosis at each intermediate acquisition state.
        eval_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=0
        )

        # ---- Get passive baseline (MRR with no acquisition) ----
        passive_result = eval_agent.diagnose_passive(case)
        passive_mrr = compute_reciprocal_rank(
            passive_result.final_ranking, case.ground_truth, case.candidates
        )
        traj.passive_mrr = passive_mrr

        # ---- Get oracle ceiling (MRR with all info) ----
        oracle_result = eval_agent.diagnose_oracle(case)
        oracle_mrr = compute_reciprocal_rank(
            oracle_result.final_ranking, case.ground_truth, case.candidates
        )
        traj.oracle_mrr = oracle_mrr

        # ---- Run the active agent to get its acquisition decisions ----
        active_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=self.budget
        )
        active_result = active_agent.diagnose(case)

        # ---- Evaluate each intermediate state ----
        # Replay the agent's decisions, scoring before/after each acquisition.
        current_mrr = passive_mrr
        current_ranking = passive_result.final_ranking
        acquired_so_far = []

        for step_idx, step in enumerate(active_result.steps):
            if step.committed:
                # Agent committed early — record a terminal zero-reward
                # "COMMIT" step and stop replaying.
                traj_step = TrajectoryStep(
                    step_idx=step_idx,
                    acquired_so_far=list(acquired_so_far),
                    available_channels=[
                        n for n in case.requestable_names
                        if n not in acquired_so_far
                    ],
                    uncertainty_text=step.reasoning or "",
                    differential_before=current_ranking,
                    mrr_before=current_mrr,
                    action="COMMIT",
                    differential_after=current_ranking,
                    mrr_after=current_mrr,
                    reward=0.0,
                    acquisition_cost=0.0,
                    normalized_cost=0.0,
                    utility_reward=0.0,
                    diagnosis_changed=False,
                    diagnosis_improved=False,
                )
                traj.steps.append(traj_step)
                break

            channel = step.requested_channel
            if not channel:
                # Step without a channel request (e.g. malformed tool call) —
                # nothing to score; skip it.
                continue

            available = [
                n for n in case.requestable_names
                if n not in acquired_so_far
            ]

            # Record the state BEFORE this acquisition
            before_ranking = current_ranking
            before_mrr = current_mrr

            # Execute the acquisition
            acquired_so_far.append(channel)

            # Evaluate the diagnosis AFTER this acquisition.
            # get_diagnosis_at_state returns a 2-tuple; only the ranking
            # (first element) is needed here.
            after_ranking, _ = eval_agent.get_diagnosis_at_state(
                case, list(acquired_so_far)
            )
            after_mrr = compute_reciprocal_rank(
                after_ranking, case.ground_truth, case.candidates
            )

            # Compute reward: per-step MRR improvement, then the cost-aware
            # utility used for policy learning. Cost is normalized by the
            # case's max requestable cost (floored at 1.0 to avoid div-by-0).
            reward = after_mrr - before_mrr
            channel_cost = case.get_channel_cost(channel)
            max_requestable_cost = max(case.get_max_requestable_cost(), 1.0)
            normalized_cost = channel_cost / max_requestable_cost
            utility_reward = reward - (
                config.COST_PENALTY_LAMBDA * normalized_cost
            )

            # Did diagnosis change? (case-insensitive top-1 comparison)
            top1_before = (
                before_ranking[0]["name"] if before_ranking else ""
            )
            top1_after = (
                after_ranking[0]["name"] if after_ranking else ""
            )
            diagnosis_changed = (
                top1_before.lower() != top1_after.lower()
            )

            # "Improved" = changed AND the new top-1 fuzzy-matches the ground
            # truth (substring match in either direction, case-insensitive).
            gt_lower = case.ground_truth.lower()
            diagnosis_improved = (
                diagnosis_changed
                and (
                    gt_lower in top1_after.lower()
                    or top1_after.lower() in gt_lower
                )
            )

            # acquired_so_far already includes this step's channel, so the
            # pre-action state is everything except the last element.
            traj_step = TrajectoryStep(
                step_idx=step_idx,
                acquired_so_far=list(acquired_so_far[:-1]),
                available_channels=available,
                uncertainty_text=step.reasoning or "",
                differential_before=before_ranking,
                mrr_before=before_mrr,
                action=channel,
                differential_after=after_ranking,
                mrr_after=after_mrr,
                reward=reward,
                acquisition_cost=channel_cost,
                normalized_cost=normalized_cost,
                utility_reward=utility_reward,
                diagnosis_changed=diagnosis_changed,
                diagnosis_improved=diagnosis_improved,
            )
            traj.steps.append(traj_step)

            # Update state for next step
            current_mrr = after_mrr
            current_ranking = after_ranking

        # ---- Finalize trajectory ----
        traj.final_mrr = current_mrr
        traj.total_reward = sum(s.reward for s in traj.steps)
        traj.total_utility_reward = sum(s.utility_reward for s in traj.steps)
        # Reciprocal rank is exactly 1.0 iff the ground truth is ranked first.
        traj.success = (current_mrr == 1.0)

        return traj

    def collect_dataset(
        self,
        cases: list[MedicalCase],
        max_cases: int | None = None,
        save_path: Path | None = None,
    ) -> list[Trajectory]:
        """Collect trajectories for all cases.

        Args:
            cases: Cases to process, in order.
            max_cases: If truthy, only the first ``max_cases`` cases are used.
            save_path: If given, the trajectories are dumped as JSON here
                (parent directories are created as needed).

        Returns:
            The successfully collected trajectories (failures are logged
            and skipped, so the result may be shorter than ``cases``).
        """
        if max_cases:
            cases = cases[:max_cases]

        trajectories = []
        for case in tqdm(cases, desc="Collecting trajectories", ncols=80):
            try:
                traj = self.collect_trajectory(case)
                trajectories.append(traj)
            # NOTE: deliberately broad — one failing case (API error, parse
            # error) must not abort the whole collection run.
            except Exception as e:
                logger.error(f"Failed on {case.case_id}: {e}")
                continue

        # Save
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w") as f:
                # default=str makes non-JSON-native values (e.g. Paths)
                # serializable rather than raising.
                json.dump(
                    [asdict(t) for t in trajectories],
                    f, indent=2, default=str,
                )
            logger.info(f"Saved {len(trajectories)} trajectories to {save_path}")

        # Report statistics
        self._report_stats(trajectories)

        return trajectories

    def _report_stats(self, trajectories: list[Trajectory]) -> None:
        """Log summary statistics of collected trajectories."""
        n = len(trajectories)
        if n == 0:
            return

        logger.info(f"\n{'='*50}")
        logger.info(f"Trajectory Collection Summary (n={n})")
        logger.info(f"{'='*50}")

        # Aggregate means across all trajectories.
        success_rate = np.mean([t.success for t in trajectories])
        avg_steps = np.mean([len(t.steps) for t in trajectories])
        avg_reward = np.mean([t.total_reward for t in trajectories])
        avg_utility = np.mean([t.total_utility_reward for t in trajectories])
        avg_passive_mrr = np.mean([t.passive_mrr for t in trajectories])
        avg_final_mrr = np.mean([t.final_mrr for t in trajectories])
        avg_oracle_mrr = np.mean([t.oracle_mrr for t in trajectories])

        logger.info(f"  Success rate: {success_rate:.3f}")
        logger.info(f"  Avg steps taken: {avg_steps:.1f}")
        logger.info(f"  Avg total reward: {avg_reward:.3f}")
        logger.info(f"  Avg utility reward: {avg_utility:.3f}")
        logger.info(
            f"  MRR: passive={avg_passive_mrr:.3f} -> "
            f"active={avg_final_mrr:.3f} -> oracle={avg_oracle_mrr:.3f}"
        )

        # Per-action reward statistics (COMMIT steps carry no acquisition).
        all_steps = [
            s for t in trajectories for s in t.steps
            if s.action != "COMMIT"
        ]
        if all_steps:
            # Group utility rewards by requested channel.
            action_rewards = {}
            for s in all_steps:
                if s.action not in action_rewards:
                    action_rewards[s.action] = []
                action_rewards[s.action].append(s.utility_reward)

            logger.info(f"\n  Per-channel utility statistics:")
            # Sort channels by descending mean utility.
            for action, rewards in sorted(
                action_rewards.items(), key=lambda x: -np.mean(x[1])
            ):
                logger.info(
                    f"    {action:<25} mean_utility={np.mean(rewards):+.3f} "
                    f"n={len(rewards)} "
                    f"positive_rate={np.mean([r > 0 for r in rewards]):.2f}"
                )
|