yuxbox committed
Commit a1aaf30 · verified · 1 Parent(s): dbdae02

Upload folder using huggingface_hub

.env.template ADDED
@@ -0,0 +1,3 @@
+ # Set your API keys as HF Space secrets
+ # OPENAI_API_KEY=sk-...
+ # ANTHROPIC_API_KEY=sk-ant-...
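
The template above only documents the variable names; the keys themselves are injected as HF Space secrets at runtime. A minimal sketch of how an app might choose a backend from them (`select_backend` is illustrative, not part of this repo):

```python
def select_backend(env: dict) -> str:
    """Pick a VLM backend from whichever API key is present in the
    environment; fall back to simulated mode when neither key is set.
    Key names match .env.template above."""
    if env.get("OPENAI_API_KEY"):
        return "openai"
    if env.get("ANTHROPIC_API_KEY"):
        return "anthropic"
    return "simulated"
```

In practice this would be called with `dict(os.environ)` so the demo still runs, in simulated mode, when no secret is configured.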
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo_cases/chest_xray_ipf.png filter=lfs diff=lfs merge=lfs -text
+ demo_cases/ct_pulmonary_pe.png filter=lfs diff=lfs merge=lfs -text
+ demo_cases/fundus_dme.png filter=lfs diff=lfs merge=lfs -text
+ demo_cases/oct_bscan_dme.png filter=lfs diff=lfs merge=lfs -text
+ demo_cases/skin_lesion_dermoscopy.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,25 @@
  ---
- title: Activemedagent Demo
- emoji: 🌖
- colorFrom: gray
- colorTo: yellow
+ title: ActiveMedAgent Demo
+ emoji: 🏥
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 6.11.0
+ sdk_version: 6.10.0
  app_file: app.py
  pinned: false
+ license: mit
  ---
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis
+
+ Interactive demo for the ActiveMedAgent framework. Watch the agent reason step-by-step,
+ acquire information channels strategically, and track entropy reduction in real time.
+
+ **No budget constraint** — the agent decides when to stop based on information-theoretic criteria.
+
+ ## Features
+ - Pre-built demo cases (NEJM, MIDAS, OLIVES)
+ - Custom case builder with image upload
+ - Step-by-step reasoning trace with probability bars
+ - Entropy trajectory and information gain tracking
+ - Simulated mode (no API key needed) + real VLM backends
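
The "entropy trajectory and information gain tracking" feature operates on the diagnosis probability distributions the agent reports at each step. A minimal sketch of the underlying arithmetic, assuming simple dict-based distributions (the function name `entropy` is illustrative, not this repo's API):

```python
import math

def entropy(dist: dict[str, float]) -> float:
    """Shannon entropy in bits of a diagnosis probability distribution."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)

# Hypothetical belief before and after acquiring one channel (e.g. dermoscopy).
prior = {"Melanoma": 0.40, "BCC": 0.35, "SCC": 0.25}
posterior = {"Melanoma": 0.75, "BCC": 0.15, "SCC": 0.10}

# Information gain = bits of uncertainty removed by the acquisition.
info_gain = entropy(prior) - entropy(posterior)
```

A falling entropy trajectory with diminishing per-step gain is exactly the signal the agent uses to decide it is time to commit a diagnosis.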
agent.py ADDED
@@ -0,0 +1,1170 @@
+ """
+ Core ActiveMedAgent — Tool-Use Architecture with Adaptive Context Management.
+
+ Two modes of operation:
+
+ 1. FULL MODE (for capable models: GPT-4o, Claude, Qwen-72B):
+    Multi-turn conversation with full history. The VLM sees its own prior
+    reasoning and can build on it.
+
+ 2. CONDENSED MODE (for weaker models: GPT-4o-mini, etc.):
+    Each acquisition step is a fresh single-turn call. The VLM receives:
+      - Initial image(s)
+      - A compact structured acquisition log of all prior steps
+      - The latest channel data
+      - Available channels + tools
+    This keeps context size O(1) per step instead of O(n), preventing
+    weaker models from losing track of their own reasoning.
+
+ In both modes:
+ - There is NO fixed budget. The agent acquires as many channels as it
+   needs (0 to all available). If a case needs all 5 NEJM vignettes, it
+   gets all 5. If the image alone is sufficient, it commits immediately.
+ - The agent decides when to stop via the commit_diagnosis tool.
+ - Probability distributions from tool calls are tracked for information-
+   theoretic analysis (entropy, IG, KL divergence).
+ """
+ import json
+ import logging
+ from dataclasses import dataclass, field
+
+ from api_client import BaseVLMClient, VLMResponse
+ from datasets.base import MedicalCase, ChannelData
+ from tools import (
+     ToolCall, ToolResult, AGENT_TOOLS,
+     to_openai_tools, to_anthropic_tools,
+     constrain_tools_for_step,
+ )
+ from information_gain import (
+     BeliefState, BeliefTrajectory,
+     compute_entropy, compute_kl_divergence,
+     estimate_expected_information_gain,
+     should_commit, compute_value_of_information,
+ )
+ from prompts import format_available_channels, format_acquired_info
+ import config
+
+ logger = logging.getLogger(__name__)
+
+
+ # ============================================================
+ # Data Structures
+ # ============================================================
+
+ @dataclass
+ class AcquisitionStep:
+     """Record of a single acquisition decision made via tool call."""
+     step: int
+     tool_call: ToolCall | None
+     requested_channel: str | None  # None if agent committed
+     reasoning: str
+     differential: list[dict]  # [{name, confidence}, ...]
+     committed: bool
+     raw_response: str
+     latency_ms: float
+     entropy: float = 0.0
+     information_gain: float = 0.0
+     kl_divergence: float = 0.0
+     expected_impact: dict = field(default_factory=dict)
+
+
+ @dataclass
+ class AgentResult:
+     """Complete result of an agent's diagnostic process on one case."""
+     case_id: str
+     dataset: str
+     prompt_variant: str
+     backend: str
+     budget: int  # max channels available (not a hard limit)
+     steps: list[AcquisitionStep] = field(default_factory=list)
+     final_ranking: list[dict] = field(default_factory=list)
+     acquired_channels: list[str] = field(default_factory=list)
+     total_latency_ms: float = 0.0
+     total_input_tokens: int = 0
+     total_output_tokens: int = 0
+     committed_early: bool = False
+     final_raw_response: str = ""
+     belief_trajectory: BeliefTrajectory | None = None
+     total_case_cost: float = 0.0
+     acquisition_cost: float = 0.0
+
+
+ # ============================================================
+ # System Prompts
+ # ============================================================
+
+ SYSTEM_PROMPT_FULL = """\
+ You are a medical diagnostic agent. You may receive free clinical context, exam data, \
+ labs, and images in a tiered pathway. Your job is to determine the correct diagnosis \
+ from a set of candidates while avoiding unnecessary resource use.
+
+ You have two tools:
+
+ 1. request_information — Request one additional data channel when its expected \
+ clinical value justifies its cost. You MUST provide:
+    - channel_name: exactly one of the available channel names
+    - reasoning: why this channel best resolves your current uncertainty
+    - current_differential: your FULL ranked differential with calibrated \
+ probabilities that sum to 1.0 across ALL candidates
+    - expected_impact: what you expect (if_positive and if_negative)
+
+ 2. commit_diagnosis — Submit your final ranked diagnosis. Provide:
+    - ranked_diagnoses: ALL candidates with calibrated probabilities summing to 1.0, \
+ each with key_evidence
+    - reasoning: your complete diagnostic reasoning chain
+
+ Strategy:
+ - Start with whatever information is already available for free at presentation.
+ - Escalate only when the currently available information cannot safely distinguish the top diagnoses.
+ - If demographics, chief complaint, history, exam, or existing evidence are sufficient, commit without requesting imaging.
+ - Commit when your top diagnosis has high probability and is well-separated from \
+ alternatives, OR when no remaining channel would meaningfully change your differential.
+ - Your probability estimates MUST sum to 1.0 and reflect genuine calibrated uncertainty."""
+
+ SYSTEM_PROMPT_CONDENSED = """\
+ You are a medical diagnostic agent. Examine all currently available clinical \
+ information below, then decide your next action.
+
+ You have two tools:
+
+ 1. request_information — Request one more data channel only if its expected value \
+ justifies its cost. Provide:
+    - channel_name: one of the available channels listed below
+    - reasoning: why this channel would help
+    - current_differential: ranked diagnoses with probabilities summing to 1.0
+    - expected_impact: if_positive and if_negative predictions
+
+ 2. commit_diagnosis — Submit final diagnosis. Provide:
+    - ranked_diagnoses: ALL candidates with probabilities summing to 1.0 and key_evidence
+    - reasoning: complete reasoning chain
+
+ Decide: if remaining channels would meaningfully change your differential enough to \
+ justify their cost, request the best one. Otherwise, commit your diagnosis."""
+
+ SYSTEM_PROMPT_FINAL = """\
+ You are a medical diagnostic agent. You have gathered information. Now provide \
+ your final ranked diagnosis using the commit_diagnosis tool.
+
+ You MUST:
+ - Include ALL candidate diagnoses in ranked_diagnoses
+ - Assign calibrated probabilities summing to 1.0
+ - Provide specific key_evidence for EACH diagnosis
+ - Write a thorough reasoning chain synthesizing all acquired evidence
+ - Favor the least resource-intensive pathway that still supports the diagnosis"""
+
+
+ # ============================================================
+ # Context Mode Detection
+ # ============================================================
+
+ def _should_use_condensed(model_name: str) -> bool:
+     """Determine if a model should use condensed context mode."""
+     if config.CONTEXT_MODE == "full":
+         return False
+     if config.CONTEXT_MODE == "condensed":
+         return True
+     # adaptive — check model name
+     for pattern in config.CONDENSED_MODELS:
+         if pattern in model_name:
+             return True
+     return False
+
+
+ # ============================================================
+ # Condensed Acquisition Log Builder
+ # ============================================================
+
+ def _build_acquisition_log(
+     steps: list[AcquisitionStep],
+     acquired_data: dict[str, str],
+ ) -> str:
+     """
+     Build a compact structured summary of all prior acquisition steps.
+
+     This replaces the full multi-turn conversation for condensed mode.
+     Each step is ~50-80 tokens instead of 300-500 for the full tool call
+     response, keeping context manageable for weaker models.
+
+     Example output:
+         === ACQUISITION LOG (2 channels acquired) ===
+         Step 1: Acquired [dermoscopy]
+          Reasoning: Need subsurface structures to distinguish melanoma from BCC
+          Data received: [dermoscopy]: (image — see above)
+          Updated differential: Melanoma(0.55), BCC(0.30), SCC(0.15)
+          Entropy: 1.37 bits | Information gain: 0.19 bits
+
+         Step 2: Acquired [patient_demographics]
+          Reasoning: Age and skin type are critical for melanoma risk
+          Data received: [demographics]: 34M, Fitzpatrick II
+          Updated differential: Melanoma(0.75), BCC(0.15), SCC(0.10)
+          Entropy: 1.06 bits | Information gain: 0.31 bits
+         === END LOG ===
+     """
+     if not steps:
+         return "(No information acquired yet.)"
+
+     lines = [f"=== ACQUISITION LOG ({len(steps)} channel(s) acquired) ==="]
+     for step in steps:
+         if step.committed:
+             continue
+         ch = step.requested_channel or "unknown"
+         lines.append(f"Step {step.step + 1}: Acquired [{ch}]")
+         if step.reasoning:
+             # Truncate reasoning to key point
+             reasoning = step.reasoning
+             if len(reasoning) > 200:
+                 reasoning = reasoning[:197] + "..."
+             lines.append(f" Reasoning: {reasoning}")
+
+         # Include the actual data received
+         data = acquired_data.get(ch, "")
+         if data:
+             if len(data) > 300:
+                 data = data[:297] + "..."
+             lines.append(f" Data received: {data}")
+
+         # Compact differential
+         if step.differential:
+             diff_str = ", ".join(
+                 f"{d['name']}({d['confidence']:.2f})"
+                 for d in step.differential[:5]
+             )
+             lines.append(f" Updated differential: {diff_str}")
+
+         lines.append(
+             f" Entropy: {step.entropy:.2f} bits"
+             + (f" | IG: {step.information_gain:.2f} bits" if step.information_gain else "")
+         )
+         lines.append("")
+
+     lines.append("=== END LOG ===")
+     return "\n".join(lines)
+
+
+ # ============================================================
+ # Main Agent Class
+ # ============================================================
+
+ class ActiveMedAgent:
+     """
+     Tool-use active acquisition agent with adaptive context management.
+
+     No fixed budget — the agent acquires as many or as few channels as it
+     needs. Supports two context modes:
+       - full: multi-turn conversation (capable models)
+       - condensed: single-turn with compressed state (weaker models)
+     """
+
+     def __init__(
+         self,
+         client: BaseVLMClient,
+         prompt_variant: str = "A",
+         budget: int | None = None,
+         context_mode: str | None = None,
+     ):
+         """
+         Args:
+             client: VLM API client.
+             prompt_variant: Prompt variant ID (for tracking, not used in tool mode).
+             budget: Max acquisitions. None = unlimited (use all channels if needed).
+             context_mode: "full", "condensed", or None (auto-detect from model).
+         """
+         self.client = client
+         self.prompt_variant = prompt_variant
+         self.budget = budget  # None means unlimited
+         self._commit_hint = ""
+         if context_mode is not None:
+             self.condensed = (context_mode == "condensed")
+         else:
+             self.condensed = _should_use_condensed(client.model)
+
+         if self.condensed:
+             logger.info(
+                 f"Using CONDENSED context mode for {client.model} "
+                 f"(single-turn with compressed state)"
+             )
+
+     # ============================================================
+     # Main Acquisition Loop
+     # ============================================================
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         """
+         Run the full tool-use acquisition loop.
+
+         The agent has NO fixed budget. It keeps requesting channels until:
+           1. It calls commit_diagnosis (confident enough), or
+           2. All available channels are exhausted, or
+           3. The safety limit is hit (max_steps = number of requestable channels)
+
+         Context mode determines how conversation history is managed:
+           - full: growing multi-turn conversation
+           - condensed: fresh single-turn call each step with compressed state
+         """
+         max_steps = len(case.requestable_names)
+         if self.budget is not None:
+             max_steps = min(max_steps, self.budget)
+
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=max_steps,
+         )
+
+         acquired = []
+         acquired_data = {}  # channel_name -> data string (for condensed log)
+         dataset_channel_config = config.CHANNEL_CONFIGS.get(case.dataset, {})
+         channel_config = {
+             name: info for name, info in dataset_channel_config.items()
+             if name in case.initial_channels or name in case.requestable_channels
+         }
+         conversation = []  # only used in full mode
+         trajectory = BeliefTrajectory(case_id=case.case_id)
+         prev_distribution = None
+         initial_images = case.get_initial_images()
+         initial_context_str = format_acquired_info(case.get_text_context([]))
+
+         # ---- Build initial message (shared by both modes) ----
+         available_str = format_available_channels(channel_config, acquired)
+         candidates_str = "\n".join(
+             f"  {i + 1}. {c}" for i, c in enumerate(case.candidates)
+         )
+
+         if not self.condensed:
+             # FULL MODE: build initial user message for multi-turn
+             initial_content = self._build_image_content(initial_images)
+             initial_text = (
+                 f"Review the currently available clinical information and determine the diagnosis.\n\n"
+                 f"Information already available at presentation:\n{initial_context_str}\n\n"
+                 f"Candidate diagnoses (you must rank ALL of these):\n"
+                 f"{candidates_str}\n\n"
+                 f"Prefer the least costly pathway that still supports a safe diagnosis.\n"
+                 f"If the current information is sufficient, commit immediately without requesting more.\n"
+                 f"You can request as many additional channels as you need "
+                 f"(or none if already confident).\n"
+                 f"Available information channels:\n{available_str}\n\n"
+                 f"Use request_information to acquire the most informative "
+                 f"channel for the cost, or commit_diagnosis if already confident."
+             )
+             initial_content.append({"type": "text", "text": initial_text})
+             conversation.append({"role": "user", "content": initial_content})
+
+         # ---- Acquisition Loop ----
+         for step_idx in range(max_steps):
+             available = [
+                 n for n in case.requestable_names if n not in acquired
+             ]
+             if not available:
+                 logger.debug(
+                     f"[{case.case_id}] All channels exhausted at step {step_idx}"
+                 )
+                 break
+
+             # Force acquisition on first step; allow commit after that
+             step_tools = constrain_tools_for_step(
+                 budget_remaining=max_steps - step_idx,
+                 allow_commit=(step_idx > 0),
+             )
+
+             if self.condensed:
+                 # CONDENSED MODE: build a fresh single-turn call each step
+                 response = self._call_condensed(
+                     case=case,
+                     initial_images=initial_images,
+                     acquired=acquired,
+                     acquired_data=acquired_data,
+                     steps=result.steps,
+                     available=available,
+                     candidates_str=candidates_str,
+                     channel_config=channel_config,
+                     step_tools=step_tools,
+                 )
+             else:
+                 # FULL MODE: multi-turn call with complete history
+                 response = self.client.call_with_retry(
+                     system_prompt=SYSTEM_PROMPT_FULL,
+                     messages=conversation,
+                     temperature=config.TEMPERATURE,
+                     max_tokens=config.MAX_TOKENS,
+                     tools=step_tools,
+                 )
+
+             result.total_latency_ms += response.latency_ms
+             result.total_input_tokens += response.input_tokens
+             result.total_output_tokens += response.output_tokens
+
+             # ---- Process tool call ----
+             tool_call = response.tool_call
+
+             if tool_call is None:
+                 # No tool call — fallback
+                 logger.warning(
+                     f"[{case.case_id}] Step {step_idx}: no tool call returned"
+                 )
+                 if not available:
+                     break
+                 fallback = available[0]
+                 acquired.append(fallback)
+                 result.acquired_channels.append(fallback)
+
+                 ch = case.get_channel(fallback)
+                 if ch and ch.channel_type == "text":
+                     acquired_data[fallback] = f"[{fallback}]: {ch.value}"
+                 else:
+                     acquired_data[fallback] = f"[{fallback}]: (image)"
+
+                 step = AcquisitionStep(
+                     step=step_idx, tool_call=None,
+                     requested_channel=fallback,
+                     reasoning="(fallback — no tool call produced)",
+                     differential=[], committed=False,
+                     raw_response=response.text,
+                     latency_ms=response.latency_ms,
+                 )
+                 result.steps.append(step)
+
+                 if not self.condensed:
+                     self._deliver_channel_data_as_user_message(
+                         case, fallback, conversation, available, acquired,
+                         channel_config,
+                     )
+                 continue
+
+             # Add assistant message to conversation (full mode only)
+             if not self.condensed:
+                 conversation.append({
+                     "role": "assistant",
+                     "content": response.text,
+                     "tool_calls": [tool_call],
+                 })
+
+             # ---- Handle commit_diagnosis ----
+             if tool_call.tool_name == "commit_diagnosis":
+                 args = tool_call.arguments
+                 ranking = self._extract_ranking_from_commit(args)
+                 distribution = {d["name"]: d["confidence"] for d in ranking}
+
+                 belief = BeliefState(
+                     step=step_idx,
+                     distribution=distribution,
+                     channel_acquired=None,
+                 )
+                 trajectory.states.append(belief)
+
+                 ig = 0.0
+                 kl = 0.0
+                 if prev_distribution is not None:
+                     ig = compute_entropy(prev_distribution) - compute_entropy(distribution)
+                     kl = compute_kl_divergence(distribution, prev_distribution)
+
+                 step = AcquisitionStep(
+                     step=step_idx, tool_call=tool_call,
+                     requested_channel=None,
+                     reasoning=args.get("reasoning", ""),
+                     differential=ranking, committed=True,
+                     raw_response=response.text,
+                     latency_ms=response.latency_ms,
+                     entropy=belief.entropy,
+                     information_gain=ig, kl_divergence=kl,
+                 )
+                 result.steps.append(step)
+                 result.committed_early = True
+                 result.final_ranking = ranking
+                 logger.debug(
+                     f"[{case.case_id}] Committed at step {step_idx} "
+                     f"after acquiring {len(acquired)} channels "
+                     f"(entropy={belief.entropy:.3f} bits)"
+                 )
+                 break
+
+             # ---- Handle request_information ----
+             elif tool_call.tool_name == "request_information":
+                 args = tool_call.arguments
+                 requested = args.get("channel_name", "")
+                 differential = args.get("current_differential", [])
+                 expected_impact = args.get("expected_impact", {})
+                 reasoning = args.get("reasoning", "")
+
+                 matched = self._match_channel(requested, available)
+                 if matched is None:
+                     matched = available[0]
+                     logger.warning(
+                         f"[{case.case_id}] Step {step_idx}: '{requested}' "
+                         f"not in {available}, falling back to '{matched}'"
+                     )
+
+                 # Build distribution from tool call
+                 distribution = {}
+                 for d in differential:
+                     distribution[d.get("name", "")] = d.get("probability", 0.0)
+
+                 # Information-theoretic metrics
+                 ig = 0.0
+                 kl = 0.0
+                 if prev_distribution is not None:
+                     ig = compute_entropy(prev_distribution) - compute_entropy(distribution)
+                     kl = compute_kl_divergence(distribution, prev_distribution)
+
+                 belief = BeliefState(
+                     step=step_idx,
+                     distribution=distribution,
+                     channel_acquired=matched,
+                 )
+                 trajectory.states.append(belief)
+
+                 eig = estimate_expected_information_gain(
+                     distribution, matched, expected_impact, case.candidates,
+                 )
+
+                 logger.debug(
+                     f"[{case.case_id}] Step {step_idx}: requesting '{matched}' "
+                     f"(H={belief.entropy:.3f}, IG={ig:.3f}, EIG={eig:.3f})"
+                 )
+
+                 step = AcquisitionStep(
+                     step=step_idx, tool_call=tool_call,
+                     requested_channel=matched, reasoning=reasoning,
+                     differential=[
+                         {"name": d.get("name", ""),
+                          "confidence": d.get("probability", 0.0),
+                          "rank": i + 1}
+                         for i, d in enumerate(differential)
+                     ],
+                     committed=False,
+                     raw_response=response.text,
+                     latency_ms=response.latency_ms,
+                     entropy=belief.entropy,
+                     information_gain=ig, kl_divergence=kl,
+                     expected_impact=expected_impact,
+                 )
+                 result.steps.append(step)
+                 prev_distribution = distribution
+
+                 acquired.append(matched)
+                 result.acquired_channels.append(matched)
+
+                 # Store acquired data for condensed log
+                 ch = case.get_channel(matched)
+                 if ch and ch.channel_type == "text":
+                     acquired_data[matched] = f"[{matched}]: {ch.value}"
+                 elif ch and ch.channel_type == "image":
+                     acquired_data[matched] = f"[{matched}]: (image provided)"
+                 else:
+                     acquired_data[matched] = f"[{matched}]: No data available."
+
+                 # ---- Check stopping criterion ----
+                 # After recording the new belief state, evaluate whether
+                 # the agent should stop acquiring. This is a principled
+                 # information-theoretic check, not just a prompt heuristic.
+                 remaining_channels = [
+                     n for n in case.requestable_names if n not in acquired
+                 ]
+                 commit_recommended, commit_reason = should_commit(
+                     trajectory=trajectory,
+                     available_channels=remaining_channels,
+                     min_steps=0,  # agent decides — no forced minimum
+                 )
+                 voi = compute_value_of_information(
+                     trajectory, len(remaining_channels),
+                 )
+
+                 if commit_recommended and remaining_channels:
+                     logger.info(
+                         f"[{case.case_id}] Stopping criterion triggered at "
+                         f"step {step_idx}: {commit_reason} (VoI={voi:.3f})"
+                     )
+                     # Don't break yet — let the VLM make the decision on
+                     # the next iteration. But inject a hint into the
+                     # follow-up context.
+                     self._commit_hint = (
+                         f"\n\nNote: Based on your belief trajectory, additional "
+                         f"acquisition has low expected value (VoI={voi:.2f}). "
+                         f"The last channel provided only {ig:.3f} bits of "
+                         f"information gain. Consider committing your diagnosis."
+                     )
+                 else:
+                     self._commit_hint = ""
+
+                 # Deliver tool result (full mode only — condensed rebuilds
+                 # the full state each call)
+                 if not self.condensed:
+                     self._deliver_tool_result(
+                         case=case, channel_name=matched,
+                         tool_call=tool_call,
+                         conversation=conversation,
+                         acquired=acquired,
+                         channel_config=channel_config,
+                     )
+
+         # ---- Final Diagnosis ----
+         if not result.committed_early or not result.final_ranking:
+             if self.condensed:
+                 final_ranking, final_response, final_belief = (
+                     self._get_final_diagnosis_condensed(
+                         case, acquired, acquired_data, result.steps,
+                     )
+                 )
+             else:
+                 final_ranking, final_response, final_belief = (
+                     self._get_final_diagnosis_tooluse(
+                         case, acquired, conversation,
+                     )
+                 )
+             result.final_ranking = final_ranking
+             result.final_raw_response = final_response.text
+             result.total_latency_ms += final_response.latency_ms
+             result.total_input_tokens += final_response.input_tokens
+             result.total_output_tokens += final_response.output_tokens
+             if final_belief:
+                 trajectory.states.append(final_belief)
+
+         result.acquired_channels = acquired
+         result.belief_trajectory = trajectory
+         result.acquisition_cost = case.get_acquisition_cost(acquired)
+         result.total_case_cost = case.get_total_cost(acquired)
+         return result
+
+     # ============================================================
+     # Condensed Mode: Single-Turn Call Builder
+     # ============================================================
+
+     def _call_condensed(
+         self,
+         case: MedicalCase,
+         initial_images: list[str],
+         acquired: list[str],
+         acquired_data: dict[str, str],
+         steps: list[AcquisitionStep],
+         available: list[str],
+         candidates_str: str,
+         channel_config: dict,
+         step_tools: list[dict],
+     ) -> VLMResponse:
+         """
+         Build and execute a single-turn call for condensed mode.
+
+         Each call gets a complete, self-contained context:
+           1. Initial image(s)
+           2. Any acquired images
+           3. Structured acquisition log (compact summary of all prior steps)
+           4. All acquired text data
+           5. Available channels
+           6. Tools
+
+         This keeps context size predictable and prevents weaker models
+         from losing track of their reasoning in long multi-turn histories.
+         """
+         content = []
+
+         # 1. Initial image(s) — always included
+         content.extend(self._build_image_content(initial_images))
+
+         # 2. Acquired images — include all visual channels
+         for ch_name in acquired:
+             ch = case.get_channel(ch_name)
+             if ch and ch.channel_type == "image" and ch.value:
+                 if isinstance(ch.value, list):
+                     for img_b64 in ch.value:
+                         content.append({
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{img_b64}",
+                             },
+                         })
+                 else:
+                     content.append({
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/jpeg;base64,{ch.value}",
+                         },
+                     })
+
+         # 3. Build the text prompt
+         available_str = format_available_channels(channel_config, acquired)
+         log_str = _build_acquisition_log(steps, acquired_data)
+
+         # 4. Collect all currently available context (initial + acquired)
+         current_context = format_acquired_info(case.get_text_context(acquired))
+
+         prompt = (
+             f"Review all currently available clinical information below.\n\n"
+             f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
+             f"Current available evidence:\n{current_context}\n\n"
+         )
+
+         if steps:
+             prompt += (
+                 f"Your prior acquisition history:\n{log_str}\n\n"
+             )
+
+         commit_hint = getattr(self, '_commit_hint', '')
+
+         if available:
+             prompt += (
+                 f"Remaining channels you can request:\n{available_str}\n\n"
+                 f"Decide: Would any remaining channel meaningfully change your "
+                 f"differential enough to justify its cost? If yes, use "
+                 f"request_information. If no, use commit_diagnosis with your final ranking."
+                 f"{commit_hint}"
+             )
+         else:
+             prompt += (
+                 f"All channels have been acquired. Use commit_diagnosis to "
+                 f"submit your final ranked diagnosis."
+             )
+
+         content.append({"type": "text", "text": prompt})
+
+         return self.client.call_with_retry(
+             system_prompt=SYSTEM_PROMPT_CONDENSED,
+             user_text=None,
+             images=None,
+             temperature=config.TEMPERATURE,
+             max_tokens=config.MAX_TOKENS,
+             tools=step_tools,
+             messages=[{"role": "user", "content": content}],
+         )
+
730
+    # ============================================================
+    # Full Mode: Tool Result Delivery
+    # ============================================================
+
+    def _deliver_tool_result(
+        self,
+        case: MedicalCase,
+        channel_name: str,
+        tool_call: ToolCall,
+        conversation: list[dict],
+        acquired: list[str],
+        channel_config: dict,
+    ):
+        """Deliver requested channel data as a tool_result message (full mode)."""
+        ch = case.get_channel(channel_name)
+
+        result_images = []
+        if ch and ch.channel_type == "image" and ch.value:
+            if isinstance(ch.value, list):
+                result_images.extend(ch.value)
+            else:
+                result_images.append(ch.value)
+
+        if ch and ch.channel_type == "text":
+            data_str = f"[{channel_name}]: {ch.value}"
+        elif ch and ch.channel_type == "image":
+            data_str = f"[{channel_name}]: (image provided — see attached)"
+        else:
+            data_str = f"[{channel_name}]: No data available for this channel."
+
+        available_after = [
+            n for n in case.requestable_names if n not in acquired
+        ]
+        available_after_str = format_available_channels(channel_config, acquired)
+
+        # Include commit hint from stopping criterion (if triggered)
+        commit_hint = getattr(self, '_commit_hint', '')
+
+        if available_after:
+            follow_up = (
+                f"Here is the information you requested:\n{data_str}\n\n"
+                f"Integrate this evidence with your prior observations.\n\n"
+                f"Remaining channels you can request:\n{available_after_str}\n\n"
+                f"Use request_information if another channel would meaningfully "
+                f"change your differential enough to justify its cost, or "
+                f"commit_diagnosis if confident."
+                f"{commit_hint}"
+            )
+        else:
+            follow_up = (
+                f"Here is the information you requested:\n{data_str}\n\n"
+                f"All channels have been acquired. Use commit_diagnosis to "
+                f"submit your final ranked diagnosis."
+            )
+
+        conversation.append({
+            "role": "tool_result",
+            "tool_call_id": tool_call.call_id,
+            "content": data_str,
+            "images": result_images,
+            "follow_up": follow_up,
+        })
+
+    def _deliver_channel_data_as_user_message(
+        self,
+        case: MedicalCase,
+        channel_name: str,
+        conversation: list[dict],
+        available_before: list[str],
+        acquired: list[str],
+        channel_config: dict,
+    ):
+        """Deliver channel data as a plain user message (fallback, full mode)."""
+        ch = case.get_channel(channel_name)
+        content = []
+
+        if ch and ch.channel_type == "image" and ch.value:
+            if isinstance(ch.value, list):
+                for img_b64 in ch.value:
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
+                    })
+            else:
+                content.append({
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpeg;base64,{ch.value}"},
+                })
+
+        if ch and ch.channel_type == "text":
+            data_str = f"[{channel_name}]: {ch.value}"
+        elif ch and ch.channel_type == "image":
+            data_str = f"[{channel_name}]: (image provided above)"
+        else:
+            data_str = f"[{channel_name}]: No data available."
+
+        available_after = [n for n in case.requestable_names if n not in acquired]
+        available_after_str = format_available_channels(channel_config, acquired)
+
+        if available_after:
+            text = (
+                f"Data received:\n{data_str}\n\n"
+                f"Remaining channels:\n{available_after_str}\n\n"
+                f"Use request_information only if another channel is worth its cost, or commit_diagnosis."
+            )
+        else:
+            text = (
+                f"Data received:\n{data_str}\n\n"
+                f"All channels acquired. Use commit_diagnosis."
+            )
+
+        content.append({"type": "text", "text": text})
+        conversation.append({"role": "user", "content": content})
+
+    # ============================================================
+    # Final Diagnosis
+    # ============================================================
+
+    def _get_final_diagnosis_tooluse(
+        self,
+        case: MedicalCase,
+        acquired: list[str],
+        conversation: list[dict],
+    ) -> tuple[list[dict], VLMResponse, BeliefState | None]:
+        """Get final diagnosis via tool call (full mode)."""
+        text_context = case.get_text_context(acquired)
+        acquired_str = format_acquired_info(text_context)
+        candidates_str = "\n".join(
+            f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
+        )
+
+        final_prompt = (
+            f"All information has been gathered. Submit your final diagnosis.\n\n"
+            f"Information acquired:\n{acquired_str}\n\n"
+            f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
+            f"Use commit_diagnosis with calibrated probabilities summing to 1.0 "
+            f"and key_evidence for each diagnosis. Favor the least resource-intensive "
+            f"pathway supported by the evidence."
+        )
+        conversation.append({"role": "user", "content": final_prompt})
+
+        commit_tools = constrain_tools_for_step(budget_remaining=0)
+
+        response = self.client.call_with_retry(
+            system_prompt=SYSTEM_PROMPT_FINAL,
+            messages=conversation,
+            temperature=config.TEMPERATURE,
+            max_tokens=config.MAX_TOKENS,
+            tools=commit_tools,
+        )
+
+        return self._parse_final_response(response, case, acquired)
+
+    def _get_final_diagnosis_condensed(
+        self,
+        case: MedicalCase,
+        acquired: list[str],
+        acquired_data: dict[str, str],
+        steps: list[AcquisitionStep],
+    ) -> tuple[list[dict], VLMResponse, BeliefState | None]:
+        """Get final diagnosis via single-turn call (condensed mode)."""
+        content = []
+
+        # Include all images
+        content.extend(self._build_image_content(case.get_initial_images()))
+        for ch_name in acquired:
+            ch = case.get_channel(ch_name)
+            if ch and ch.channel_type == "image" and ch.value:
+                if isinstance(ch.value, list):
+                    for img_b64 in ch.value:
+                        content.append({
+                            "type": "image_url",
+                            "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
+                        })
+                else:
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/jpeg;base64,{ch.value}"},
+                    })
+
+        # Build text
+        candidates_str = "\n".join(
+            f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
+        )
+        log_str = _build_acquisition_log(steps, acquired_data)
+        current_context = format_acquired_info(case.get_text_context(acquired))
+
+        prompt = (
+            f"Submit your final diagnosis based on all gathered information.\n\n"
+            f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
+            f"Acquisition history:\n{log_str}\n\n"
+            f"All currently available evidence:\n{current_context}\n\n"
+            f"Use commit_diagnosis with calibrated probabilities summing to 1.0 "
+            f"and key_evidence for each diagnosis. Favor the least resource-intensive "
+            f"pathway supported by the evidence."
+        )
+        content.append({"type": "text", "text": prompt})
+
+        commit_tools = constrain_tools_for_step(budget_remaining=0)
+
+        response = self.client.call_with_retry(
+            system_prompt=SYSTEM_PROMPT_FINAL,
+            messages=[{"role": "user", "content": content}],
+            temperature=config.TEMPERATURE,
+            max_tokens=config.MAX_TOKENS,
+            tools=commit_tools,
+        )
+
+        return self._parse_final_response(response, case, acquired)
+
+    def _parse_final_response(
+        self,
+        response: VLMResponse,
+        case: MedicalCase,
+        acquired: list[str],
+    ) -> tuple[list[dict], VLMResponse, BeliefState | None]:
+        """Parse the final diagnosis response (shared by both modes)."""
+        tool_call = response.tool_call
+        if tool_call and tool_call.tool_name == "commit_diagnosis":
+            ranking = self._extract_ranking_from_commit(tool_call.arguments)
+            distribution = {d["name"]: d["confidence"] for d in ranking}
+            belief = BeliefState(
+                step=len(acquired),
+                distribution=distribution,
+                channel_acquired=None,
+            )
+            return ranking, response, belief
+
+        logger.warning(
+            f"[{case.case_id}] Final diagnosis: no tool call, "
+            f"falling back to text extraction"
+        )
+        ranking = self._extract_ranking_from_text(response.text, case.candidates)
+        return ranking, response, None
+
+    # ============================================================
+    # Baseline Conditions
+    # ============================================================
+
+    def get_diagnosis_at_state(
+        self, case: MedicalCase, acquired: list[str]
+    ) -> tuple[list[dict], VLMResponse]:
+        """
+        Public helper: get a diagnosis given a set of acquired channels.
+
+        Used by TrajectoryCollector to evaluate intermediate states.
+        Returns (ranking, response).
+        """
+        return self._get_final_diagnosis_single(case, acquired)
+
+    def diagnose_passive(self, case: MedicalCase) -> AgentResult:
+        """Passive baseline: initial available context only, no acquisition."""
+        result = AgentResult(
+            case_id=case.case_id, dataset=case.dataset,
+            prompt_variant=self.prompt_variant,
+            backend=self.client.model, budget=0,
+        )
+        final_ranking, final_response = self._get_final_diagnosis_single(
+            case, acquired=[],
+        )
+        result.final_ranking = final_ranking
+        result.final_raw_response = final_response.text
+        result.total_latency_ms = final_response.latency_ms
+        result.total_input_tokens = final_response.input_tokens
+        result.total_output_tokens = final_response.output_tokens
+        result.total_case_cost = case.get_total_cost([])
+        return result
+
+    def diagnose_oracle(self, case: MedicalCase) -> AgentResult:
+        """Oracle baseline: ALL information given upfront."""
+        all_channels = list(case.requestable_channels.keys())
+        result = AgentResult(
+            case_id=case.case_id, dataset=case.dataset,
+            prompt_variant=self.prompt_variant,
+            backend=self.client.model,
+            budget=len(all_channels),
+            acquired_channels=all_channels,
+        )
+        final_ranking, final_response = self._get_final_diagnosis_single(
+            case, acquired=all_channels,
+        )
+        result.final_ranking = final_ranking
+        result.final_raw_response = final_response.text
+        result.total_latency_ms = final_response.latency_ms
+        result.total_input_tokens = final_response.input_tokens
+        result.total_output_tokens = final_response.output_tokens
+        result.acquisition_cost = case.get_acquisition_cost(all_channels)
+        result.total_case_cost = case.get_total_cost(all_channels)
+        return result
+
+    def diagnose_fixed_order(
+        self, case: MedicalCase, order: list[str] | None = None
+    ) -> AgentResult:
+        """Fixed-order baseline: acquire channels in predetermined order."""
+        if order is None:
+            order = list(case.requestable_channels.keys())
+        max_acq = self.budget if self.budget is not None else len(order)
+        acquired = order[:max_acq]
+        result = AgentResult(
+            case_id=case.case_id, dataset=case.dataset,
+            prompt_variant=self.prompt_variant,
+            backend=self.client.model,
+            budget=max_acq,
+            acquired_channels=acquired,
+        )
+        final_ranking, final_response = self._get_final_diagnosis_single(
+            case, acquired=acquired,
+        )
+        result.final_ranking = final_ranking
+        result.final_raw_response = final_response.text
+        result.total_latency_ms = final_response.latency_ms
+        result.total_input_tokens = final_response.input_tokens
+        result.total_output_tokens = final_response.output_tokens
+        result.acquisition_cost = case.get_acquisition_cost(acquired)
+        result.total_case_cost = case.get_total_cost(acquired)
+        return result
+
+    def _get_final_diagnosis_single(
+        self, case: MedicalCase, acquired: list[str]
+    ) -> tuple[list[dict], VLMResponse]:
+        """Single-turn final diagnosis (for baselines)."""
+        images = case.get_all_images_up_to(acquired)
+        text_context = case.get_text_context(acquired)
+        acquired_str = format_acquired_info(text_context)
+        candidates_str = "\n".join(
+            f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
+        )
+
+        user_text = (
+            f"Provide your diagnosis using the currently available clinical information.\n\n"
+            f"Available information:\n{acquired_str}\n\n"
+            f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
+            f"Use commit_diagnosis with calibrated probabilities summing "
+            f"to 1.0 and key_evidence for each diagnosis. Prefer the least costly "
+            f"explanation supported by the evidence."
+        )
+
+        commit_tools = constrain_tools_for_step(budget_remaining=0)
+
+        response = self.client.call_with_retry(
+            system_prompt=SYSTEM_PROMPT_FINAL,
+            user_text=user_text,
+            images=images,
+            temperature=config.TEMPERATURE,
+            max_tokens=config.MAX_TOKENS,
+            tools=commit_tools,
+        )
+
+        tool_call = response.tool_call
+        if tool_call and tool_call.tool_name == "commit_diagnosis":
+            ranking = self._extract_ranking_from_commit(tool_call.arguments)
+            return ranking, response
+
+        ranking = self._extract_ranking_from_text(response.text, case.candidates)
+        return ranking, response
+
+    # ============================================================
+    # Helpers
+    # ============================================================
+
+    def _build_image_content(self, images: list[str]) -> list[dict]:
+        """Build image content blocks for API messages."""
+        content = []
+        for img_b64 in images:
+            content.append({
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{img_b64}",
+                    "detail": "high",
+                },
+            })
+        return content
+
+    def _extract_ranking_from_commit(self, args: dict) -> list[dict]:
+        """Extract ranking from commit_diagnosis tool call (structured JSON)."""
+        ranked = args.get("ranked_diagnoses", [])
+        ranking = []
+        for i, entry in enumerate(ranked):
+            ranking.append({
+                "name": entry.get("name", ""),
+                "confidence": entry.get("confidence", 0.0),
+                "rank": i + 1,
+                "key_evidence": entry.get("key_evidence", ""),
+            })
+        ranking.sort(key=lambda x: x["confidence"], reverse=True)
+        for i, entry in enumerate(ranking):
+            entry["rank"] = i + 1
+        return ranking
+
+    def _extract_ranking_from_text(
+        self, text: str, candidates: list[str]
+    ) -> list[dict]:
+        """Last-resort fallback: extract ranking from free text."""
+        import re
+        ranking = []
+        pattern = (
+            r"(\d+)\.\s*(.+?)\s*"
+            r"\((?:confidence|probability|prob|conf):\s*([\d.]+)\)"
+        )
+        matches = re.findall(pattern, text, re.IGNORECASE)
+        if matches:
+            for rank_str, name, conf_str in matches:
+                try:
+                    ranking.append({
+                        "name": name.strip(),
+                        "confidence": float(conf_str),
+                        "rank": int(rank_str),
+                    })
+                except ValueError:
+                    continue
+        if not ranking and candidates:
+            for i, candidate in enumerate(candidates):
+                if candidate.lower() in text.lower():
+                    ranking.append({
+                        "name": candidate,
+                        "confidence": max(0.1, 1.0 - i * 0.2),
+                        "rank": len(ranking) + 1,
+                    })
+        ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
+        for i, entry in enumerate(ranking):
+            entry["rank"] = i + 1
+        return ranking
+
+    def _match_channel(
+        self, requested: str, available: list[str]
+    ) -> str | None:
+        """Match requested channel name to available channels."""
+        requested = requested.lower().strip().replace(" ", "_")
+        if requested in available:
+            return requested
+        for ch in available:
+            if requested in ch or ch in requested:
+                return ch
+        req_words = set(requested.split("_"))
+        best_match, best_overlap = None, 0
+        for ch in available:
+            overlap = len(req_words & set(ch.split("_")))
+            if overlap > best_overlap:
+                best_overlap = overlap
+                best_match = ch
+        return best_match if best_overlap > 0 else None
api_client.py ADDED
@@ -0,0 +1,707 @@
+"""
+Unified multi-backend VLM API client with tool-use support.
+
+Supports OpenAI (GPT-4o), Anthropic (Claude), and Together (Qwen2.5-VL).
+Handles image encoding, rate limiting, retries, response normalization,
+and native function/tool calling across all backends.
+"""
+import base64
+import io
+import json
+import time
+import logging
+from collections import deque
+from pathlib import Path
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+
+from PIL import Image
+
+import config
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class VLMResponse:
+    """Normalized response from any VLM backend, including tool calls."""
+    text: str
+    model: str
+    backend: str
+    input_tokens: int
+    output_tokens: int
+    latency_ms: float
+    tool_call: object | None = None  # tools.ToolCall if a tool was called
+
+
+def _normalize_image_mode(img: Image.Image) -> Image.Image:
+    """Normalize medical image modes to RGB-compatible formats for JPEG encoding."""
+    if img.mode in ("RGB",):
+        return img
+    if img.mode == "RGBA":
+        background = Image.new("RGB", img.size, (255, 255, 255))
+        background.paste(img, mask=img.split()[3])
+        return background
+    if img.mode == "L":
+        return img.convert("RGB")
+    if img.mode in ("I", "I;16", "I;16B", "I;16L"):
+        import numpy as np
+        arr = np.array(img, dtype=np.float64)
+        if arr.max() > arr.min():
+            arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255.0
+        else:
+            arr = np.zeros_like(arr)
+        return Image.fromarray(arr.astype(np.uint8)).convert("RGB")
+    if img.mode == "F":
+        import numpy as np
+        arr = np.array(img, dtype=np.float64)
+        if arr.max() > arr.min():
+            arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255.0
+        else:
+            arr = np.zeros_like(arr)
+        return Image.fromarray(arr.astype(np.uint8)).convert("RGB")
+    return img.convert("RGB")
+
+
+def encode_image_to_base64(image_path: str | Path, max_size: int = 1024) -> str:
+    """Load and encode an image to base64, resizing if needed."""
+    img = Image.open(image_path)
+    if max(img.size) > max_size:
+        ratio = max_size / max(img.size)
+        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
+        img = img.resize(new_size, Image.LANCZOS)
+    img = _normalize_image_mode(img)
+    buf = io.BytesIO()
+    img.save(buf, format="JPEG", quality=90)
+    return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+
+def encode_pil_image_to_base64(img: Image.Image, max_size: int = 1024) -> str:
+    """Encode a PIL Image object to base64."""
+    if max(img.size) > max_size:
+        ratio = max_size / max(img.size)
+        new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
+        img = img.resize(new_size, Image.LANCZOS)
+    img = _normalize_image_mode(img)
+    buf = io.BytesIO()
+    img.save(buf, format="JPEG", quality=90)
+    return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+
+class BaseVLMClient(ABC):
+    """Abstract base class for VLM API clients with tool-use support."""
+
+    def __init__(self, model: str, api_key: str, rate_limit: int = 30):
+        self.model = model
+        self.api_key = api_key
+        self.rate_limit = rate_limit
+        self._call_timestamps: deque[float] = deque()
+
+    def _rate_limit_wait(self):
+        """Enforce rate limiting using a sliding window over the last 60 seconds."""
+        now = time.time()
+        while self._call_timestamps and now - self._call_timestamps[0] >= 60.0:
+            self._call_timestamps.popleft()
+        if len(self._call_timestamps) >= self.rate_limit:
+            sleep_time = 60.0 - (now - self._call_timestamps[0])
+            if sleep_time > 0:
+                time.sleep(sleep_time)
+            self._call_timestamps.popleft()
+        self._call_timestamps.append(time.time())
+
+    @abstractmethod
+    def call(
+        self,
+        system_prompt: str,
+        user_text: str,
+        images: list[str] | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        """Make a VLM API call, optionally with tools."""
+        pass
+
+    def call_multiturn(
+        self,
+        system_prompt: str,
+        messages: list[dict],
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        """Multi-turn conversation call with tool support. Override in subclasses."""
+        last_user = ""
+        last_images = []
+        for msg in reversed(messages):
+            if msg["role"] == "user":
+                if isinstance(msg["content"], str):
+                    last_user = msg["content"]
+                elif isinstance(msg["content"], list):
+                    for block in msg["content"]:
+                        if block.get("type") == "text":
+                            last_user = block["text"]
+                        elif block.get("type") == "image_url":
+                            last_images.append(block["image_url"]["url"].split(",", 1)[-1])
+                break
+        return self.call(system_prompt, last_user, last_images or None, temperature, max_tokens, tools)
+
+    def call_with_retry(
+        self,
+        system_prompt: str,
+        user_text: str | None = None,
+        images: list[str] | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        max_retries: int = 3,
+        messages: list[dict] | None = None,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        """Call with exponential backoff retry. Supports single-turn, multi-turn, and tools."""
+        for attempt in range(max_retries):
+            try:
+                self._rate_limit_wait()
+                if messages is not None:
+                    return self.call_multiturn(system_prompt, messages, temperature, max_tokens, tools)
+                return self.call(system_prompt, user_text, images, temperature, max_tokens, tools)
+            except Exception as e:
+                if attempt == max_retries - 1:
+                    raise
+                wait_time = 2 ** attempt * 5
+                logger.warning(
+                    f"API call failed (attempt {attempt + 1}/{max_retries}): {e}. "
+                    f"Retrying in {wait_time}s..."
+                )
+                time.sleep(wait_time)
+
+
+def _parse_tool_call_openai(response_message) -> object | None:
+    """Extract a ToolCall from an OpenAI response message."""
+    from tools import ToolCall
+
+    tool_calls = getattr(response_message, "tool_calls", None)
+    if not tool_calls:
+        return None
+
+    tc = tool_calls[0]  # Take the first tool call
+    try:
+        arguments = json.loads(tc.function.arguments)
+    except (json.JSONDecodeError, AttributeError):
+        arguments = {}
+
+    return ToolCall(
+        tool_name=tc.function.name,
+        arguments=arguments,
+        call_id=tc.id,
+    )
+
+
+def _parse_tool_call_anthropic(response) -> object | None:
+    """Extract a ToolCall from an Anthropic response."""
+    from tools import ToolCall
+
+    for block in response.content:
+        if block.type == "tool_use":
+            return ToolCall(
+                tool_name=block.name,
+                arguments=block.input,
+                call_id=block.id,
+            )
+    return None
+
+
+# ============================================================
+# OpenAI Backend (GPT-4o) — with tool calling
+# ============================================================
+
+class OpenAIClient(BaseVLMClient):
+    """OpenAI GPT-4o API client with native function calling."""
+
+    def __init__(self, model: str | None = None, api_key: str | None = None, rate_limit: int | None = None):
+        super().__init__(
+            model=model or config.MODELS["openai"],
+            api_key=api_key or config.OPENAI_API_KEY,
+            rate_limit=rate_limit or config.RATE_LIMITS["openai"],
+        )
+        from openai import OpenAI
+        self.client = OpenAI(api_key=self.api_key)
+
+    def call(
+        self,
+        system_prompt: str,
+        user_text: str,
+        images: list[str] | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        content = []
+        if images:
+            for img_b64 in images:
+                content.append({
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpeg;base64,{img_b64}", "detail": "high"},
+                })
+        content.append({"type": "text", "text": user_text})
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": content},
+        ]
+
+        kwargs = {
+            "model": self.model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+        if tools:
+            from tools import to_openai_tools
+            kwargs["tools"] = to_openai_tools(tools)
+            kwargs["tool_choice"] = "required"
+
+        t0 = time.time()
+        response = self.client.chat.completions.create(**kwargs)
+        latency = (time.time() - t0) * 1000
+
+        msg = response.choices[0].message
+        tool_call = _parse_tool_call_openai(msg) if tools else None
+
+        return VLMResponse(
+            text=msg.content or "",
+            model=self.model,
+            backend="openai",
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+            latency_ms=latency,
+            tool_call=tool_call,
+        )
+
+    def call_multiturn(
+        self,
+        system_prompt: str,
+        messages: list[dict],
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        """
+        Multi-turn OpenAI call with full tool-calling protocol.
+
+        Translates our internal message format to OpenAI's API format:
+        - "user" → role:"user" (passed through)
+        - "assistant" → role:"assistant" with tool_calls array
+        - "tool_result" → role:"tool" (text result) + role:"user" (images + follow-up)
+
+        OpenAI requires: after an assistant message with tool_calls, the next
+        message MUST be role:"tool" with the matching tool_call_id. Images
+        cannot go in tool messages, so we send them in a separate user message.
+        """
+        api_messages = [{"role": "system", "content": system_prompt}]
+
+        for msg in messages:
+            role = msg["role"]
+
+            if role == "user":
+                api_messages.append({
+                    "role": "user",
+                    "content": msg["content"],
+                })
+
+            elif role == "assistant":
+                api_msg = {"role": "assistant"}
+                if msg.get("tool_calls"):
+                    tc = msg["tool_calls"][0]
+                    api_msg["tool_calls"] = [{
+                        "id": tc.call_id,
+                        "type": "function",
+                        "function": {
+                            "name": tc.tool_name,
+                            "arguments": json.dumps(tc.arguments),
+                        },
+                    }]
+                    # OpenAI requires content to be null when tool_calls present
+                    api_msg["content"] = msg.get("content") or None
+                else:
+                    api_msg["content"] = msg.get("content", "")
+                api_messages.append(api_msg)
+
+            elif role == "tool_result":
+                # Step 1: Send the tool result as role:"tool"
+                api_messages.append({
+                    "role": "tool",
+                    "tool_call_id": msg["tool_call_id"],
+                    "content": msg.get("content", ""),
+                })
+
+                # Step 2: Send images + follow-up as a user message
+                # (OpenAI tool messages don't support image content blocks)
+                follow_up_content = []
+                for img_b64 in msg.get("images", []):
+                    follow_up_content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{img_b64}",
+                        },
+                    })
+                follow_up = msg.get("follow_up", "")
+                if follow_up:
+                    follow_up_content.append({
+                        "type": "text",
+                        "text": follow_up,
+                    })
+                if follow_up_content:
+                    api_messages.append({
+                        "role": "user",
+                        "content": follow_up_content,
+                    })
+
+        kwargs = {
+            "model": self.model,
+            "messages": api_messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+        if tools:
+            from tools import to_openai_tools
+            kwargs["tools"] = to_openai_tools(tools)
+            kwargs["tool_choice"] = "required"
+
+        t0 = time.time()
+        response = self.client.chat.completions.create(**kwargs)
+        latency = (time.time() - t0) * 1000
+
+        msg = response.choices[0].message
+        tool_call = _parse_tool_call_openai(msg) if tools else None
+
+        return VLMResponse(
+            text=msg.content or "",
+            model=self.model,
+            backend="openai",
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+            latency_ms=latency,
+            tool_call=tool_call,
+        )
+
+
+ # ============================================================
389
+ # Anthropic Backend (Claude) — with tool use
390
+ # ============================================================
391
+
392
+ class AnthropicClient(BaseVLMClient):
393
+ """Anthropic Claude API client with native tool use."""
394
+
395
+ def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
396
+ super().__init__(
397
+ model=model or config.MODELS["anthropic"],
398
+ api_key=api_key or config.ANTHROPIC_API_KEY,
399
+ rate_limit=rate_limit or config.RATE_LIMITS["anthropic"],
400
+ )
401
+ from anthropic import Anthropic
402
+ self.client = Anthropic(api_key=self.api_key)
403
+
404
+ def call(
405
+ self,
406
+ system_prompt: str,
407
+ user_text: str,
408
+ images: list[str] | None = None,
409
+ temperature: float = 0.1,
410
+ max_tokens: int = 2048,
411
+ tools: list[dict] | None = None,
412
+ ) -> VLMResponse:
413
+ content = []
414
+ if images:
415
+ for img_b64 in images:
416
+ content.append({
417
+ "type": "image",
418
+ "source": {
419
+ "type": "base64",
420
+ "media_type": "image/jpeg",
421
+ "data": img_b64,
422
+ },
423
+ })
424
+ content.append({"type": "text", "text": user_text})
425
+
426
+ kwargs = {
427
+ "model": self.model,
428
+ "system": system_prompt,
429
+ "messages": [{"role": "user", "content": content}],
430
+ "temperature": temperature,
431
+ "max_tokens": max_tokens,
432
+ }
433
+ if tools:
434
+ from tools import to_anthropic_tools
435
+ kwargs["tools"] = to_anthropic_tools(tools)
436
+ kwargs["tool_choice"] = {"type": "any"}
437
+
438
+ t0 = time.time()
439
+ response = self.client.messages.create(**kwargs)
440
+ latency = (time.time() - t0) * 1000
441
+
442
+ # Extract text from response (may have both text and tool_use blocks)
443
+ text_parts = []
444
+ for block in response.content:
445
+ if hasattr(block, "text"):
446
+ text_parts.append(block.text)
447
+
448
+ tool_call = _parse_tool_call_anthropic(response) if tools else None
449
+
450
+ return VLMResponse(
451
+ text="\n".join(text_parts),
452
+ model=self.model,
453
+ backend="anthropic",
454
+ input_tokens=response.usage.input_tokens,
455
+ output_tokens=response.usage.output_tokens,
456
+ latency_ms=latency,
457
+ tool_call=tool_call,
458
+ )
459
+
+    def call_multiturn(
+        self,
+        system_prompt: str,
+        messages: list[dict],
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        """
+        Multi-turn Anthropic call with full tool-use protocol.
+
+        Translates our internal message format to Anthropic's API format:
+        - "user" → role:"user" (passed through)
+        - "assistant" → role:"assistant" with tool_use content blocks
+        - "tool_result" → role:"user" with tool_result block + image blocks
+
+        Anthropic's protocol: after an assistant message with a tool_use block,
+        the next message MUST be role:"user" containing a tool_result block
+        with the matching tool_use_id. Images and follow-up text can be
+        included in the same user message as additional content blocks.
+        """
+        api_messages = []
+
+        for msg in messages:
+            role = msg["role"]
+
+            if role == "user":
+                content = msg["content"]
+                # Convert image_url format to Anthropic's image format
+                if isinstance(content, list):
+                    anthropic_content = []
+                    for block in content:
+                        if block.get("type") == "image_url":
+                            url = block["image_url"]["url"]
+                            # Extract base64 data from data URL
+                            if url.startswith("data:"):
+                                b64_data = url.split(",", 1)[-1]
+                            else:
+                                b64_data = url
+                            anthropic_content.append({
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": "image/jpeg",
+                                    "data": b64_data,
+                                },
+                            })
+                        else:
+                            # Text and any other block types pass through unchanged
+                            anthropic_content.append(block)
+                    api_messages.append({
+                        "role": "user",
+                        "content": anthropic_content,
+                    })
+                else:
+                    api_messages.append({
+                        "role": "user",
+                        "content": content,
+                    })
+
+            elif role == "assistant":
+                content_blocks = []
+                if msg.get("content"):
+                    content_blocks.append({
+                        "type": "text",
+                        "text": msg["content"],
+                    })
+                if msg.get("tool_calls"):
+                    tc = msg["tool_calls"][0]
+                    content_blocks.append({
+                        "type": "tool_use",
+                        "id": tc.call_id,
+                        "name": tc.tool_name,
+                        "input": tc.arguments,
+                    })
+                api_messages.append({
+                    "role": "assistant",
+                    "content": content_blocks,
+                })
+
+            elif role == "tool_result":
+                # Anthropic: tool_result goes in a user message alongside
+                # any images and follow-up text
+                user_content = []
+
+                # The tool_result block
+                user_content.append({
+                    "type": "tool_result",
+                    "tool_use_id": msg["tool_call_id"],
+                    "content": msg.get("content", ""),
+                })
+
+                # Images from the channel data
+                for img_b64 in msg.get("images", []):
+                    user_content.append({
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": img_b64,
+                        },
+                    })
+
+                # Follow-up text (next-step instructions)
+                follow_up = msg.get("follow_up", "")
+                if follow_up:
+                    user_content.append({
+                        "type": "text",
+                        "text": follow_up,
+                    })
+
+                api_messages.append({
+                    "role": "user",
+                    "content": user_content,
+                })
+
+        kwargs = {
+            "model": self.model,
+            "system": system_prompt,
+            "messages": api_messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+        if tools:
+            from tools import to_anthropic_tools
+            kwargs["tools"] = to_anthropic_tools(tools)
+            kwargs["tool_choice"] = {"type": "any"}
+
+        t0 = time.time()
+        response = self.client.messages.create(**kwargs)
+        latency = (time.time() - t0) * 1000
+
+        text_parts = []
+        for block in response.content:
+            if hasattr(block, "text"):
+                text_parts.append(block.text)
+
+        tool_call = _parse_tool_call_anthropic(response) if tools else None
+
+        return VLMResponse(
+            text="\n".join(text_parts),
+            model=self.model,
+            backend="anthropic",
+            input_tokens=response.usage.input_tokens,
+            output_tokens=response.usage.output_tokens,
+            latency_ms=latency,
+            tool_call=tool_call,
+        )
+
+
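The `tool_result` branch is the subtle part of the protocol translation described in the docstring. A self-contained sketch of just that mapping (the function name and the sample IDs here are hypothetical, not part of the client):

```python
def tool_result_to_user_turn(msg: dict) -> dict:
    """Map an internal 'tool_result' message to the Anthropic user turn
    that must follow an assistant tool_use block."""
    content = [{
        "type": "tool_result",
        "tool_use_id": msg["tool_call_id"],
        "content": msg.get("content", ""),
    }]
    for img_b64 in msg.get("images", []):
        content.append({
            "type": "image",
            "source": {"type": "base64", "media_type": "image/jpeg", "data": img_b64},
        })
    if msg.get("follow_up"):
        content.append({"type": "text", "text": msg["follow_up"]})
    return {"role": "user", "content": content}

turn = tool_result_to_user_turn({
    "role": "tool_result",
    "tool_call_id": "toolu_01",  # hypothetical ID
    "content": "exam_findings: bibasilar crackles, SpO2 92%",
    "images": ["<base64-jpeg>"],
    "follow_up": "Decide your next action.",
})
```

The tool_result block comes first and carries the matching `tool_use_id`; images and follow-up text ride along as additional blocks in the same user message.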
+
+# ============================================================
+# Together Backend (Qwen2.5-VL) — with tool calling
+# ============================================================
+
+class TogetherClient(BaseVLMClient):
+    """Together AI client with function calling support."""
+
+    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
+        super().__init__(
+            model=model or config.MODELS["together"],
+            api_key=api_key or config.TOGETHER_API_KEY,
+            rate_limit=rate_limit or config.RATE_LIMITS["together"],
+        )
+        from together import Together
+        self.client = Together(api_key=self.api_key)
+
+    def call(
+        self,
+        system_prompt: str,
+        user_text: str,
+        images: list[str] | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 2048,
+        tools: list[dict] | None = None,
+    ) -> VLMResponse:
+        content = []
+        if images:
+            for img_b64 in images:
+                content.append({
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
+                })
+        content.append({"type": "text", "text": user_text})
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": content},
+        ]
+
+        kwargs = {
+            "model": self.model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+        if tools:
+            from tools import to_openai_tools
+            kwargs["tools"] = to_openai_tools(tools)
+
+        t0 = time.time()
+        response = self.client.chat.completions.create(**kwargs)
+        latency = (time.time() - t0) * 1000
+
+        msg = response.choices[0].message
+        usage = response.usage
+        tool_call = _parse_tool_call_openai(msg) if tools else None
+
+        return VLMResponse(
+            text=msg.content or "",
+            model=self.model,
+            backend="together",
+            input_tokens=getattr(usage, "prompt_tokens", 0),
+            output_tokens=getattr(usage, "completion_tokens", 0),
+            latency_ms=latency,
+            tool_call=tool_call,
+        )
+
+
+# ============================================================
+# Client Factory
+# ============================================================
+
+class OpenAIMiniClient(OpenAIClient):
+    """OpenAI GPT-4o-mini client."""
+
+    def __init__(self, model: str = None, api_key: str = None, rate_limit: int = None):
+        BaseVLMClient.__init__(
+            self,
+            model=model or config.MODELS["openai_mini"],
+            api_key=api_key or config.OPENAI_API_KEY,
+            rate_limit=rate_limit or config.RATE_LIMITS["openai_mini"],
+        )
+        from openai import OpenAI
+        self.client = OpenAI(api_key=self.api_key)
+
+
+def create_client(backend: str, **kwargs) -> BaseVLMClient:
+    """Factory function to create a VLM client by backend name."""
+    clients = {
+        "openai": OpenAIClient,
+        "openai_mini": OpenAIMiniClient,
+        "anthropic": AnthropicClient,
+        "together": TogetherClient,
+    }
+    if backend not in clients:
+        raise ValueError(f"Unknown backend: {backend}. Choose from {list(clients.keys())}")
+    return clients[backend](**kwargs)
app.py ADDED
@@ -0,0 +1,1000 @@
+"""
+Interactive Demo for ActiveMedAgent.
+
+A Gradio-based UI that lets users:
+- Select from pre-built demo cases OR enter a custom clinical scenario
+- Upload medical images (optional)
+- Watch the agent's step-by-step reasoning, information acquisition, and
+  entropy reduction in real time
+- No budget constraint — the agent acquires as many channels as it needs
+
+Usage:
+    python app.py
+    python app.py --backend openai
+    python app.py --backend anthropic --port 7861
+"""
+import argparse
+import json
+import logging
+import sys
+import time
+import math
+from pathlib import Path
+from dataclasses import dataclass, field
+
+import numpy as np
+import gradio as gr
+from PIL import Image
+
+sys.path.insert(0, str(Path(__file__).resolve().parent))
+
+import config
+from api_client import create_client, encode_image_to_base64, encode_pil_image_to_base64
+from agent import ActiveMedAgent, AgentResult, AcquisitionStep, SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_CONDENSED, SYSTEM_PROMPT_FINAL
+from datasets.base import MedicalCase, ChannelData
+from tools import AGENT_TOOLS, constrain_tools_for_step, ToolCall
+from information_gain import (
+    BeliefState, BeliefTrajectory,
+    compute_entropy, compute_kl_divergence,
+    estimate_expected_information_gain,
+    should_commit, compute_value_of_information,
+)
+from prompts import format_available_channels, format_acquired_info
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================
+# Backend Availability Detection
+# ============================================================
+
+def _detect_available_backends() -> list[str]:
+    """Detect which backends have API keys configured."""
+    available = []
+    if config.OPENAI_API_KEY and config.OPENAI_API_KEY != "sk-...":
+        available.append("openai")
+    if config.ANTHROPIC_API_KEY and config.ANTHROPIC_API_KEY != "sk-ant-...":
+        available.append("anthropic")
+    if config.TOGETHER_API_KEY:
+        available.append("together")
+    return available
+
+
+AVAILABLE_BACKENDS = _detect_available_backends()
+
+
+# ============================================================
+# Simulation Mode — works without API keys
+# ============================================================
+
+def _simulate_agent_on_case(case: MedicalCase) -> AgentResult:
+    """
+    Run a simulated agent that demonstrates the full pipeline
+    with realistic-looking reasoning traces. No API keys needed.
+    """
+    import random
+    random.seed(42)
+
+    result = AgentResult(
+        case_id=case.case_id,
+        dataset=case.dataset,
+        prompt_variant="A",
+        backend="simulated (no API key)",
+        budget=len(case.requestable_channels),
+    )
+    trajectory = BeliefTrajectory(case_id=case.case_id)
+    acquired = []
+    n_candidates = len(case.candidates)
+
+    # Generate initial uniform-ish distribution
+    probs = np.random.dirichlet(np.ones(n_candidates) * 2.0).tolist()
+    probs.sort(reverse=True)
+    # Make ground truth likely to end up on top by the end
+    gt_idx = case.ground_truth_rank
+
+    requestable_names = list(case.requestable_channels.keys())
+    cumulative_cost = case.get_initial_cost()
+
+    for step_idx, ch_name in enumerate(requestable_names):
+        ch = case.requestable_channels[ch_name]
+
+        # Evolve the distribution — gradually concentrate on correct answer
+        progress = (step_idx + 1) / len(requestable_names)
+        new_probs = []
+        for i in range(n_candidates):
+            if i == gt_idx:
+                new_probs.append(probs[i] + 0.15 * progress + random.uniform(0, 0.05))
+            else:
+                new_probs.append(max(0.01, probs[i] - 0.04 * progress + random.uniform(-0.02, 0.02)))
+        total = sum(new_probs)
+        probs = [p / total for p in new_probs]
+
+        distribution = {case.candidates[i]: probs[i] for i in range(n_candidates)}
+        sorted_dist = sorted(distribution.items(), key=lambda x: -x[1])
+
+        prev_entropy = trajectory.states[-1].entropy if trajectory.states else compute_entropy(distribution) + 0.3
+        belief = BeliefState(
+            step=step_idx,
+            distribution=distribution,
+            channel_acquired=ch_name,
+        )
+        trajectory.states.append(belief)
+
+        ig = prev_entropy - belief.entropy
+        kl = abs(ig) * 1.2 + random.uniform(0, 0.1)
+
+        top_two = sorted_dist[:2]
+        reasoning_templates = [
+            f"Need to distinguish between {top_two[0][0]} ({top_two[0][1]:.0%}) and {top_two[1][0]} ({top_two[1][1]:.0%}). "
+            f"Requesting {ch_name} to resolve this uncertainty.",
+            f"Current top diagnosis is {top_two[0][0]} at {top_two[0][1]:.0%} but {top_two[1][0]} cannot be ruled out. "
+            f"The {ch_name} channel should provide discriminating evidence.",
+            f"Diagnostic uncertainty remains high (H={belief.entropy:.2f} bits). "
+            f"The {ch_name} data is expected to significantly narrow the differential.",
+        ]
+
+        step = AcquisitionStep(
+            step=step_idx,
+            tool_call=ToolCall(tool_name="request_information", arguments={
+                "channel_name": ch_name,
+                "reasoning": reasoning_templates[step_idx % len(reasoning_templates)],
+            }),
+            requested_channel=ch_name,
+            reasoning=reasoning_templates[step_idx % len(reasoning_templates)],
+            differential=[
+                {"name": name, "confidence": prob, "rank": i + 1}
+                for i, (name, prob) in enumerate(sorted_dist)
+            ],
+            committed=False,
+            raw_response="(simulated)",
+            latency_ms=random.uniform(800, 3000),
+            entropy=belief.entropy,
+            information_gain=ig,
+            kl_divergence=kl,
+            expected_impact={
+                "if_positive": sorted_dist[0][0],
+                "if_negative": sorted_dist[1][0],
+            },
+        )
+        result.steps.append(step)
+        acquired.append(ch_name)
+
+    # Final commit step
+    final_probs = []
+    for i in range(n_candidates):
+        if i == gt_idx:
+            final_probs.append(0.65 + random.uniform(0, 0.15))
+        else:
+            final_probs.append(random.uniform(0.02, 0.12))
+    total = sum(final_probs)
+    final_probs = [p / total for p in final_probs]
+    final_dist = {case.candidates[i]: final_probs[i] for i in range(n_candidates)}
+    sorted_final = sorted(final_dist.items(), key=lambda x: -x[1])
+
+    final_belief = BeliefState(
+        step=len(requestable_names),
+        distribution=final_dist,
+        channel_acquired=None,
+    )
+    trajectory.states.append(final_belief)
+
+    final_ranking = [
+        {
+            "name": name,
+            "confidence": prob,
+            "rank": i + 1,
+            "key_evidence": "Supported by evidence from acquired channels" if i == 0 else "Less consistent with findings",
+        }
+        for i, (name, prob) in enumerate(sorted_final)
+    ]
+
+    commit_step = AcquisitionStep(
+        step=len(requestable_names),
+        tool_call=ToolCall(tool_name="commit_diagnosis", arguments={}),
+        requested_channel=None,
+        reasoning=f"After acquiring all available channels, the evidence strongly supports {sorted_final[0][0]}. "
+                  f"Entropy reduced to {final_belief.entropy:.2f} bits. Committing diagnosis.",
+        differential=final_ranking,
+        committed=True,
+        raw_response="(simulated)",
+        latency_ms=random.uniform(500, 2000),
+        entropy=final_belief.entropy,
+        information_gain=trajectory.states[-2].entropy - final_belief.entropy if len(trajectory.states) >= 2 else 0,
+        kl_divergence=0.0,
+    )
+    result.steps.append(commit_step)
+    result.committed_early = False
+    result.final_ranking = final_ranking
+    result.acquired_channels = acquired
+    result.belief_trajectory = trajectory
+    result.acquisition_cost = case.get_acquisition_cost(acquired)
+    result.total_case_cost = case.get_total_cost(acquired)
+    result.total_latency_ms = sum(s.latency_ms for s in result.steps)
+    result.total_input_tokens = 0
+    result.total_output_tokens = 0
+
+    return result
+
+
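The entropy and information-gain numbers in the simulated trace come from `compute_entropy` in `information_gain`; a minimal stand-in consistent with those values is plain Shannon entropy in bits:

```python
import math

def entropy_bits(dist: dict[str, float]) -> float:
    """Shannon entropy H(p) = -sum(p * log2(p)), in bits."""
    return -sum(p * math.log2(p) for p in dist.values() if p > 0)

# A 4-way uniform differential carries 2.0 bits of uncertainty.
before = {"A": 0.25, "B": 0.25, "C": 0.25, "D": 0.25}
# After a discriminating channel, belief concentrates on A.
after = {"A": 0.7, "B": 0.1, "C": 0.1, "D": 0.1}
gain = entropy_bits(before) - entropy_bits(after)  # per-step information gain
```

This difference between consecutive belief states is exactly what the trace records as `information_gain`.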
+# ============================================================
+# Synthetic Demo Cases
+# ============================================================
+
+def _make_dummy_image(width=224, height=224, color=(180, 60, 60)) -> str:
+    img = Image.new("RGB", (width, height), color)
+    arr = np.array(img)
+    noise = np.random.randint(-20, 20, arr.shape, dtype=np.int16)
+    arr = np.clip(arr.astype(np.int16) + noise, 0, 255).astype(np.uint8)
+    img = Image.fromarray(arr)
+    return encode_pil_image_to_base64(img)
+
+
+DEMO_CASES = {
+    "NEJM: Pulmonary Fibrosis": {
+        "description": (
+            "A 58-year-old man with progressive dyspnea and dry cough over 3 months. "
+            "30-pack-year smoking history, takes lisinopril for hypertension."
+        ),
+        "case": lambda: MedicalCase(
+            case_id="demo_nejm_ipf",
+            dataset="nejm",
+            initial_channels={
+                "demographics": ChannelData(
+                    name="demographics", channel_type="text",
+                    description="Patient age, sex, and ethnicity",
+                    value="A 58-year-old man", always_given=True, cost=0.0, tier="free",
+                ),
+                "chief_complaint": ChannelData(
+                    name="chief_complaint", channel_type="text",
+                    description="Presenting symptoms and duration",
+                    value="Progressive dyspnea and dry cough over the past 3 months.",
+                    always_given=True, cost=0.0, tier="free",
+                ),
+                "medical_history": ChannelData(
+                    name="medical_history", channel_type="text",
+                    description="Past medical conditions, medications, family and social history",
+                    value="30-pack-year smoking history. No prior lung disease. Takes lisinopril for hypertension.",
+                    always_given=True, cost=0.0, tier="free",
+                ),
+            },
+            requestable_channels={
+                "exam_findings": ChannelData(
+                    name="exam_findings", channel_type="text",
+                    description="Physical examination results and observations",
+                    value="Bibasilar crackles on auscultation. No clubbing. Oxygen saturation 92% on room air.",
+                    cost=75.0, tier="cheap",
+                ),
+                "investigations": ChannelData(
+                    name="investigations", channel_type="text",
+                    description="Laboratory values, prior imaging results, and test outcomes",
+                    value="PFTs show restrictive pattern with reduced DLCO. CT chest shows bilateral ground-glass opacities with honeycombing in the lower lobes.",
+                    cost=250.0, tier="moderate",
+                ),
+                "image": ChannelData(
+                    name="image", channel_type="image",
+                    description="The primary diagnostic image (chest CT)",
+                    value=_make_dummy_image(300, 300, (200, 200, 210)),
+                    cost=800.0, tier="expensive",
+                ),
+            },
+            candidates=[
+                "A. Idiopathic pulmonary fibrosis",
+                "B. Hypersensitivity pneumonitis",
+                "C. Sarcoidosis",
+                "D. Lung adenocarcinoma",
+                "E. ACE-inhibitor induced cough with incidental CT findings",
+            ],
+            ground_truth="A. Idiopathic pulmonary fibrosis",
+            ground_truth_rank=0,
+        ),
+    },
+    "Dermatology: Pigmented Lesion": {
+        "description": (
+            "A 62-year-old woman presents with a pigmented lesion on her left forearm. "
+            "The lesion is 8mm x 6mm. Clinical photograph provided."
+        ),
+        "case": lambda: MedicalCase(
+            case_id="demo_midas_001",
+            dataset="midas",
+            initial_channels={
+                "clinical_30cm": ChannelData(
+                    name="clinical_30cm", channel_type="image",
+                    description="Clinical photograph at 30cm distance",
+                    value=_make_dummy_image(224, 224, (180, 120, 100)),
+                    always_given=True, cost=0.0, tier="free",
+                ),
+            },
+            requestable_channels={
+                "patient_demographics": ChannelData(
+                    name="patient_demographics", channel_type="text",
+                    description="Patient age, sex, and Fitzpatrick skin type",
+                    value="Age: 62; Sex: Female; Fitzpatrick skin type: III",
+                    cost=0.0, tier="free",
+                ),
+                "lesion_metadata": ChannelData(
+                    name="lesion_metadata", channel_type="text",
+                    description="Anatomic location, lesion length and width",
+                    value="Anatomic location: Left forearm; Lesion length: 8mm; Lesion width: 6mm",
+                    cost=25.0, tier="cheap",
+                ),
+                "clinical_15cm": ChannelData(
+                    name="clinical_15cm", channel_type="image",
+                    description="Clinical photograph at 15cm distance (closer view)",
+                    value=_make_dummy_image(224, 224, (170, 110, 90)),
+                    cost=50.0, tier="moderate",
+                ),
+                "dermoscopy": ChannelData(
+                    name="dermoscopy", channel_type="image",
+                    description="Dermoscopic image showing subsurface skin structures",
+                    value=_make_dummy_image(224, 224, (100, 80, 60)),
+                    cost=250.0, tier="expensive",
+                ),
+            },
+            candidates=[
+                "Melanoma in situ",
+                "Dysplastic nevus",
+                "Basal cell carcinoma",
+                "Seborrheic keratosis",
+                "Solar lentigo",
+            ],
+            ground_truth="Dysplastic nevus",
+            ground_truth_rank=1,
+        ),
+    },
+    "Ophthalmology: Retinal Biomarkers (OLIVES)": {
+        "description": (
+            "A patient with diabetic macular edema (DME), 4 prior anti-VEGF injections, "
+            "32 weeks in treatment. Fundus photograph provided."
+        ),
+        "case": lambda: MedicalCase(
+            case_id="demo_olives_P01",
+            dataset="olives",
+            initial_channels={
+                "disease_context": ChannelData(
+                    name="disease_context", channel_type="text",
+                    description="Disease type and treatment context",
+                    value="Disease: Diabetic Macular Edema (DME). Prior anti-VEGF injections: 4. Weeks in treatment: 32.",
+                    always_given=True, cost=0.0, tier="free",
+                ),
+            },
+            requestable_channels={
+                "clinical_measurements": ChannelData(
+                    name="clinical_measurements", channel_type="text",
+                    description="Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
+                    value="BCVA: 20/60 (logMAR 0.48); CST: 385 um",
+                    cost=20.0, tier="cheap",
+                ),
+                "biomarker_hints": ChannelData(
+                    name="biomarker_hints", channel_type="text",
+                    description="Expert-graded presence of fundus-visible retinal biomarkers",
+                    value="Hard Exudates: Present; Hemorrhage: Present; Microaneurysms: Present; Cotton Wool Spots: Not detected",
+                    cost=100.0, tier="moderate",
+                ),
+                "oct_scan": ChannelData(
+                    name="oct_scan", channel_type="image",
+                    description="OCT B-scan showing retinal cross-section",
+                    value=_make_dummy_image(512, 128, (60, 60, 60)),
+                    cost=300.0, tier="expensive",
+                ),
+                "additional_oct": ChannelData(
+                    name="additional_oct", channel_type="image",
+                    description="Additional OCT B-scans from different retinal locations",
+                    value=_make_dummy_image(512, 128, (50, 50, 55)),
+                    cost=150.0, tier="very_expensive",
+                ),
+            },
+            candidates=[
+                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
+                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Fluid Srf, Hard Exudates, Hemorrhage, Microaneurysms",
+                "Present biomarkers: Hard Exudates, Hemorrhage, Microaneurysms",
+                "Present biomarkers: Dril, Ez Disruption, Fluid Irf, Shrm",
+                "No biomarkers detected",
+            ],
+            ground_truth="Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
+            ground_truth_rank=0,
+        ),
+    },
+    "NEJM: Cardiac Case": {
+        "description": (
+            "A 45-year-old woman presents with sudden onset chest pain and shortness "
+            "of breath. She recently completed a long international flight."
+        ),
+        "case": lambda: MedicalCase(
+            case_id="demo_nejm_pe",
+            dataset="nejm",
+            initial_channels={
+                "demographics": ChannelData(
+                    name="demographics", channel_type="text",
+                    description="Patient age, sex, and ethnicity",
+                    value="A 45-year-old woman", always_given=True, cost=0.0, tier="free",
+                ),
+                "chief_complaint": ChannelData(
+                    name="chief_complaint", channel_type="text",
+                    description="Presenting symptoms and duration",
+                    value="Sudden onset chest pain and shortness of breath, started 2 hours ago after returning from a 14-hour international flight.",
+                    always_given=True, cost=0.0, tier="free",
+                ),
+                "medical_history": ChannelData(
+                    name="medical_history", channel_type="text",
+                    description="Past medical conditions, medications, family and social history",
+                    value="On oral contraceptives for 5 years. BMI 32. No prior VTE. Mother had DVT at age 50.",
+                    always_given=True, cost=0.0, tier="free",
+                ),
+            },
+            requestable_channels={
+                "exam_findings": ChannelData(
+                    name="exam_findings", channel_type="text",
+                    description="Physical examination results and observations",
+                    value="Tachycardic (HR 110), tachypneic (RR 24), SpO2 89% on room air. Right calf swollen and tender. JVP elevated. Loud P2 on cardiac auscultation.",
+                    cost=75.0, tier="cheap",
+                ),
+                "investigations": ChannelData(
+                    name="investigations", channel_type="text",
+                    description="Laboratory values, imaging results, and test outcomes",
+                    value="D-dimer: 4200 ng/mL (markedly elevated). Troponin I: 0.15 ng/mL (mildly elevated). ABG: pH 7.48, PaO2 62 mmHg, PaCO2 28 mmHg. ECG: S1Q3T3 pattern, right axis deviation. CT pulmonary angiography: bilateral pulmonary emboli with right heart strain.",
+                    cost=250.0, tier="moderate",
+                ),
+                "image": ChannelData(
+                    name="image", channel_type="image",
+                    description="CT Pulmonary Angiography image",
+                    value=_make_dummy_image(300, 300, (100, 100, 120)),
+                    cost=800.0, tier="expensive",
+                ),
+            },
+            candidates=[
+                "A. Pulmonary embolism",
+                "B. Acute myocardial infarction",
+                "C. Tension pneumothorax",
+                "D. Aortic dissection",
+                "E. Acute pericarditis",
+            ],
+            ground_truth="A. Pulmonary embolism",
+            ground_truth_rank=0,
+        ),
+    },
+}
+
+
+# ============================================================
+# Custom Case Builder
+# ============================================================
+
+def build_custom_case(
+    scenario_text: str,
+    candidates_text: str,
+    channel_1_name: str, channel_1_type: str, channel_1_value: str,
+    channel_2_name: str, channel_2_type: str, channel_2_value: str,
+    channel_3_name: str, channel_3_type: str, channel_3_value: str,
+    uploaded_image=None,
+) -> MedicalCase:
+    """Build a MedicalCase from user-provided custom inputs."""
+    candidates = [c.strip() for c in candidates_text.strip().split("\n") if c.strip()]
+    if not candidates:
+        candidates = ["Diagnosis A", "Diagnosis B", "Diagnosis C"]
+
+    initial_channels = {
+        "clinical_scenario": ChannelData(
+            name="clinical_scenario", channel_type="text",
+            description="The presenting clinical scenario",
+            value=scenario_text,
+            always_given=True, cost=0.0, tier="free",
+        ),
+    }
+
+    if uploaded_image is not None:
+        img_b64 = encode_pil_image_to_base64(Image.fromarray(uploaded_image))
+        initial_channels["uploaded_image"] = ChannelData(
+            name="uploaded_image", channel_type="image",
+            description="Uploaded medical image",
+            value=img_b64, always_given=True, cost=0.0, tier="free",
+        )
+
+    requestable = {}
+    for name, ctype, value in [
+        (channel_1_name, channel_1_type, channel_1_value),
+        (channel_2_name, channel_2_type, channel_2_value),
+        (channel_3_name, channel_3_type, channel_3_value),
+    ]:
+        name = name.strip()
+        value = value.strip()
+        if name and value:
+            key = name.lower().replace(" ", "_")
+            requestable[key] = ChannelData(
+                name=key, channel_type=ctype.lower(),
+                description=name,
+                value=value,
+                cost=100.0, tier="moderate",
+            )
+
+    # Register channel config so the agent can look it up
+    custom_config = {}
+    for name, ch in initial_channels.items():
+        custom_config[name] = {
+            "description": ch.description,
+            "type": ch.channel_type,
+            "always_given": True,
+            "tier": ch.tier,
+            "cost": ch.cost,
+            "order": 0,
+        }
+    for i, (name, ch) in enumerate(requestable.items()):
+        custom_config[name] = {
+            "description": ch.description,
+            "type": ch.channel_type,
+            "always_given": False,
+            "tier": ch.tier,
+            "cost": ch.cost,
+            "order": i + 1,
+        }
+    config.CHANNEL_CONFIGS["custom"] = custom_config
+
+    return MedicalCase(
+        case_id="custom_case",
+        dataset="custom",
+        initial_channels=initial_channels,
+        requestable_channels=requestable,
+        candidates=candidates,
+        ground_truth=candidates[0] if candidates else "",
+        ground_truth_rank=0,
+    )
+
+
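Two normalization rules in `build_custom_case` are worth noting: candidate lines are split on newlines with blanks dropped, and channel display names become snake_case dict keys. In isolation (sample inputs are hypothetical):

```python
candidates_text = "Pulmonary embolism\n\nPneumonia\n"
candidates = [c.strip() for c in candidates_text.strip().split("\n") if c.strip()]
# → ["Pulmonary embolism", "Pneumonia"]

channel_name = "Exam Findings"
key = channel_name.strip().lower().replace(" ", "_")
# → "exam_findings"
```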
+# ============================================================
+# Formatting Helpers
+# ============================================================
+
+def format_step_markdown(step_idx: int, step: AcquisitionStep, cumulative_cost: float) -> str:
+    """Format a single acquisition step as rich markdown."""
+    lines = []
+
+    if step.committed:
+        lines.append(f"### Step {step_idx + 1}: COMMITTED TO DIAGNOSIS")
+        lines.append("")
+        lines.append(f"**Reasoning:** {step.reasoning}")
+        lines.append("")
+        if step.differential:
+            lines.append("**Final Ranking:**")
+            for d in step.differential:
+                conf = d.get("confidence", 0)
+                bar = render_bar(conf)
+                evidence = d.get("key_evidence", "")
+                lines.append(f"- **{d['name']}** — {conf:.1%} {bar}")
+                if evidence:
+                    lines.append(f"  - *Evidence:* {evidence}")
+    else:
+        lines.append(f"### Step {step_idx + 1}: Requested `{step.requested_channel}`")
+        lines.append("")
+        lines.append(f"**Reasoning:** {step.reasoning}")
+        lines.append("")
+
+        if step.differential:
+            lines.append("**Current Differential:**")
+            for d in step.differential:
+                conf = d.get("confidence", 0)
+                bar = render_bar(conf)
+                lines.append(f"- {d['name']} — {conf:.1%} {bar}")
+
+        if step.expected_impact:
+            lines.append("")
+            lines.append("**Expected Impact:**")
+            pos = step.expected_impact.get("if_positive", "N/A")
+            neg = step.expected_impact.get("if_negative", "N/A")
+            lines.append(f"- If positive/abnormal: *{pos}*")
+            lines.append(f"- If negative/normal: *{neg}*")
+
+    lines.append("")
+    lines.append("**Information Metrics:**")
+    lines.append(f"- Entropy: **{step.entropy:.3f}** bits")
+    if step.information_gain:
+        lines.append(f"- Information Gain: **{step.information_gain:.3f}** bits")
+    if step.kl_divergence:
+        lines.append(f"- KL Divergence: **{step.kl_divergence:.3f}** bits")
+    lines.append(f"- Latency: {step.latency_ms:.0f}ms")
+    lines.append(f"- Cumulative Cost: ${cumulative_cost:,.0f}")
+    lines.append("")
+    lines.append("---")
+
+    return "\n".join(lines)
+
+
+def render_bar(value: float, width: int = 20) -> str:
+    """Render a text-based progress bar."""
+    filled = int(value * width)
+    return "`" + "\u2588" * filled + "\u2591" * (width - filled) + "`"
608
+
609
+
610
+ def format_entropy_table(trajectory: BeliefTrajectory) -> str:
611
+ """Format entropy trajectory as a markdown table."""
612
+ if not trajectory or not trajectory.states:
613
+ return "*No belief trajectory recorded.*"
614
+
615
+ lines = ["| Step | Channel | Entropy (bits) | Info Gain | Cumulative IG |"]
616
+ lines.append("|------|---------|---------------|-----------|---------------|")
617
+
618
+ cumulative_ig = 0.0
619
+ for i, state in enumerate(trajectory.states):
620
+ ch = state.channel_acquired or "initial/commit"
621
+ ig = 0.0
622
+ if i > 0:
623
+ ig = trajectory.states[i - 1].entropy - state.entropy
624
+ cumulative_ig += ig
625
+ lines.append(
626
+ f"| {i} | {ch} | {state.entropy:.3f} | "
627
+ f"{ig:+.3f} | {cumulative_ig:.3f} |"
628
+ )
629
+
630
+ lines.append("")
631
+ lines.append(f"**Information Efficiency:** {trajectory.information_efficiency:.1%}")
632
+ lines.append(f"**Total Information Gain:** {trajectory.total_information_gain:.3f} bits")
633
+
634
+ return "\n".join(lines)
635
+
636
+
637
+ def format_summary(result: AgentResult, case: MedicalCase) -> str:
638
+ """Format the overall result summary."""
639
+ lines = []
640
+ lines.append("## Summary")
641
+ lines.append("")
642
+
643
+ if result.final_ranking:
644
+ top = result.final_ranking[0]
645
+ top_name = top["name"].strip().lower()
646
+ gt_name = case.ground_truth.strip().lower()
647
+ # Fuzzy match: handle "Pulmonary embolism" vs "A. Pulmonary embolism"
648
+ correct = top_name == gt_name or top_name in gt_name or gt_name in top_name
649
+ icon = "correct" if correct else "incorrect"
650
+ lines.append(f"**Top Diagnosis:** {top['name']} ({top['confidence']:.1%})")
651
+ lines.append(f"**Ground Truth:** {case.ground_truth}")
652
+ lines.append(f"**Result:** {icon}")
653
+ else:
654
+ lines.append("*No diagnosis produced.*")
655
+
656
+ lines.append("")
657
+ lines.append(f"**Channels Acquired:** {len(result.acquired_channels)} / {len(case.requestable_channels)}")
658
+ if result.acquired_channels:
659
+ lines.append(f"**Acquisition Order:** {' -> '.join(result.acquired_channels)}")
660
+ lines.append(f"**Committed Early:** {'Yes' if result.committed_early else 'No'}")
661
+ lines.append(f"**Total Acquisition Cost:** ${result.acquisition_cost:,.0f}")
662
+ lines.append(f"**Total Case Cost:** ${result.total_case_cost:,.0f}")
663
+ lines.append(f"**Total Latency:** {result.total_latency_ms:,.0f}ms")
664
+ lines.append(f"**Tokens Used:** {result.total_input_tokens:,} in / {result.total_output_tokens:,} out")
665
+
666
+ return "\n".join(lines)
667
+
668
+
669
+ # ============================================================
670
+ # Main Agent Runner (for Gradio)
671
+ # ============================================================
672
+
673
+ def run_agent_on_case(
674
+ case: MedicalCase,
675
+ backend: str,
676
+ context_mode: str,
677
+ ) -> tuple[str, str, str]:
678
+ """
679
+ Run the agent on a case and return formatted markdown outputs.
680
+
681
+ Returns: (steps_markdown, entropy_table, summary_markdown)
682
+ """
683
+ if backend == "simulated (no API key)":
684
+ result = _simulate_agent_on_case(case)
685
+ model_name = "simulated"
686
+ else:
687
+ try:
688
+ client = create_client(backend)
689
+ except Exception as e:
690
+ return (
691
+ f"**Error creating {backend} client:** {e}\n\n"
692
+ "Make sure your API key is set in `.env` or environment variables. "
693
+ "Or select **simulated (no API key)** to see a demo trace.",
694
+ "", "",
695
+ )
696
+ agent = ActiveMedAgent(
697
+ client,
698
+ prompt_variant="A",
699
+ budget=None, # NO BUDGET CONSTRAINT
700
+ context_mode=context_mode if context_mode != "adaptive" else None,
701
+ )
702
+ try:
703
+ result = agent.diagnose(case)
704
+ except Exception as e:
705
+ return f"**Error running agent:** {e}", "", ""
706
+ model_name = client.model
707
+
708
+ # Format step-by-step reasoning
709
+ steps_parts = []
710
+ steps_parts.append("# Agent Reasoning Trace\n")
711
+ steps_parts.append(f"**Case:** {case.case_id} | **Dataset:** {case.dataset} | **Backend:** {model_name}\n")
712
+ steps_parts.append(f"**Candidates:** {', '.join(case.candidates)}\n")
713
+
714
+ initial_info = format_acquired_info(case.get_text_context([]))
715
+ steps_parts.append(f"**Initial Information:**\n{initial_info}\n")
716
+ steps_parts.append("---\n")
717
+
718
+ cumulative_cost = case.get_initial_cost()
719
+ for i, step in enumerate(result.steps):
720
+ if step.requested_channel:
721
+ cumulative_cost += case.get_channel_cost(step.requested_channel)
722
+ steps_parts.append(format_step_markdown(i, step, cumulative_cost))
723
+
724
+ steps_md = "\n".join(steps_parts)
725
+
726
+ # Format entropy trajectory
727
+ entropy_md = ""
728
+ if result.belief_trajectory:
729
+ entropy_md = format_entropy_table(result.belief_trajectory)
730
+
731
+ # Format summary
732
+ summary_md = format_summary(result, case)
733
+
734
+ return steps_md, entropy_md, summary_md
735
+
736
+
737
+ # ============================================================
738
+ # Gradio Event Handlers
739
+ # ============================================================
740
+
741
+ def on_demo_case_selected(case_name: str) -> tuple[str, str]:
742
+ """When a demo case is selected, show its description and candidates."""
743
+ if case_name in DEMO_CASES:
744
+ info = DEMO_CASES[case_name]
745
+ case = info["case"]()
746
+ desc = info["description"]
747
+ cands = "\n".join(case.candidates)
748
+ channels = []
749
+ for name, ch in case.requestable_channels.items():
750
+ channels.append(f"- **{name}** ({ch.tier}, ${ch.cost:,.0f}): {ch.description}")
751
+ ch_str = "\n".join(channels)
752
+ return (
753
+ f"{desc}\n\n**Available channels to acquire:**\n{ch_str}",
754
+ cands,
755
+ )
756
+ return "", ""
757
+
758
+
759
+ def run_demo_case(case_name: str, backend: str, context_mode: str):
760
+ """Run agent on a selected demo case."""
761
+ if case_name not in DEMO_CASES:
762
+ return "Please select a demo case.", "", ""
763
+
764
+ case = DEMO_CASES[case_name]["case"]()
765
+ return run_agent_on_case(case, backend, context_mode)
766
+
767
+
768
+ def run_custom_case(
769
+ scenario: str, candidates: str,
770
+ ch1_name: str, ch1_type: str, ch1_value: str,
771
+ ch2_name: str, ch2_type: str, ch2_value: str,
772
+ ch3_name: str, ch3_type: str, ch3_value: str,
773
+ uploaded_image,
774
+ backend: str, context_mode: str,
775
+ ):
776
+ """Run agent on a custom user-defined case."""
777
+ if not scenario.strip():
778
+ return "Please enter a clinical scenario.", "", ""
779
+
780
+ case = build_custom_case(
781
+ scenario, candidates,
782
+ ch1_name, ch1_type, ch1_value,
783
+ ch2_name, ch2_type, ch2_value,
784
+ ch3_name, ch3_type, ch3_value,
785
+ uploaded_image,
786
+ )
787
+ return run_agent_on_case(case, backend, context_mode)
788
+
789
+
790
+ # ============================================================
791
+ # Gradio UI
792
+ # ============================================================
793
+
794
+ def create_app():
795
+     with gr.Blocks(
+         title="ActiveMedAgent Interactive Demo",
+         # theme/css belong here on gr.Blocks(), not on launch()
+         theme=gr.themes.Soft(),
+         css=".reasoning-box { font-size: 14px; } .header-text { text-align: center; margin-bottom: 10px; }",
+     ) as app:
+         gr.Markdown(
+             """
+             # ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis
+             **Interactive Demo** — Watch the agent reason step-by-step, acquire information channels,
+             and track entropy reduction. **No budget constraint** — the agent decides when to stop.
+             """,
+             elem_classes="header-text",
+         )
+
+         # Build backend choices: simulation always available, real backends if keys exist
+         backend_choices = ["simulated (no API key)"] + AVAILABLE_BACKENDS
+         default_backend = AVAILABLE_BACKENDS[0] if AVAILABLE_BACKENDS else "simulated (no API key)"
+
+         with gr.Row():
+             backend = gr.Dropdown(
+                 choices=backend_choices,
+                 value=default_backend,
+                 label="VLM Backend",
+                 info="Select 'simulated' to see the demo without API keys",
+                 scale=1,
+             )
+             context_mode = gr.Dropdown(
+                 choices=["adaptive", "full", "condensed"],
+                 value="adaptive",
+                 label="Context Mode",
+                 info="How the agent manages conversation history",
+                 scale=1,
+             )
+
+         with gr.Tabs():
+             # ---- Tab 1: Demo Cases ----
+             with gr.TabItem("Demo Cases"):
+                 gr.Markdown("Select a pre-built clinical scenario and run the agent.")
+                 with gr.Row():
+                     case_selector = gr.Dropdown(
+                         choices=list(DEMO_CASES.keys()),
+                         label="Select Case",
+                         scale=2,
+                     )
+                     run_demo_btn = gr.Button("Run Agent", variant="primary", scale=1)
+
+                 case_description = gr.Markdown(label="Case Description")
+                 case_candidates = gr.Textbox(label="Candidate Diagnoses", lines=3, interactive=False)
+
+                 case_selector.change(
+                     fn=on_demo_case_selected,
+                     inputs=[case_selector],
+                     outputs=[case_description, case_candidates],
+                 )
+
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         demo_steps = gr.Markdown(
+                             label="Reasoning Steps",
+                             elem_classes="reasoning-box",
+                         )
+                     with gr.Column(scale=1):
+                         demo_summary = gr.Markdown(label="Summary")
+                         demo_entropy = gr.Markdown(label="Entropy Trajectory")
+
+                 run_demo_btn.click(
+                     fn=run_demo_case,
+                     inputs=[case_selector, backend, context_mode],
+                     outputs=[demo_steps, demo_entropy, demo_summary],
+                 )
+
+             # ---- Tab 2: Custom Case ----
+             with gr.TabItem("Custom Case"):
+                 gr.Markdown(
+                     "Define your own clinical scenario, candidate diagnoses, "
+                     "and information channels the agent can request."
+                 )
+
+                 with gr.Row():
+                     with gr.Column():
+                         custom_scenario = gr.Textbox(
+                             label="Clinical Scenario",
+                             placeholder="A 35-year-old woman presents with...",
+                             lines=4,
+                         )
+                         custom_candidates = gr.Textbox(
+                             label="Candidate Diagnoses (one per line)",
+                             placeholder="A. Diagnosis one\nB. Diagnosis two\nC. Diagnosis three",
+                             lines=5,
+                         )
+                         custom_image = gr.Image(
+                             label="Upload Medical Image (optional)",
+                             type="numpy",
+                         )
+
+                     with gr.Column():
+                         gr.Markdown("### Requestable Information Channels")
+                         gr.Markdown("Define up to 3 channels the agent can request.")
+
+                         with gr.Group():
+                             gr.Markdown("**Channel 1:**")
+                             ch1_name = gr.Textbox(label="Name", value="Exam Findings", scale=1)
+                             ch1_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
+                             ch1_value = gr.Textbox(label="Content (what the agent receives)", lines=2,
+                                                    placeholder="Physical exam: temperature 38.5C, ...")
+
+                         with gr.Group():
+                             gr.Markdown("**Channel 2:**")
+                             ch2_name = gr.Textbox(label="Name", value="Lab Results", scale=1)
+                             ch2_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
+                             ch2_value = gr.Textbox(label="Content", lines=2,
+                                                    placeholder="WBC 12,000, CRP elevated, ...")
+
+                         with gr.Group():
+                             gr.Markdown("**Channel 3:**")
+                             ch3_name = gr.Textbox(label="Name", value="Imaging", scale=1)
+                             ch3_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
+                             ch3_value = gr.Textbox(label="Content", lines=2,
+                                                    placeholder="CT scan shows...")
+
+                 run_custom_btn = gr.Button("Run Agent on Custom Case", variant="primary")
+
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         custom_steps = gr.Markdown(
+                             label="Reasoning Steps",
+                             elem_classes="reasoning-box",
+                         )
+                     with gr.Column(scale=1):
+                         custom_summary = gr.Markdown(label="Summary")
+                         custom_entropy = gr.Markdown(label="Entropy Trajectory")
+
+                 run_custom_btn.click(
+                     fn=run_custom_case,
+                     inputs=[
+                         custom_scenario, custom_candidates,
+                         ch1_name, ch1_type, ch1_value,
+                         ch2_name, ch2_type, ch2_value,
+                         ch3_name, ch3_type, ch3_value,
+                         custom_image,
+                         backend, context_mode,
+                     ],
+                     outputs=[custom_steps, custom_entropy, custom_summary],
+                 )
+
+             # ---- Tab 3: How It Works ----
+             with gr.TabItem("How It Works"):
+                 gr.Markdown("""
+                 ## ActiveMedAgent Architecture
+
+                 ### Tool-Use Acquisition Loop
+                 The agent uses native VLM function calling (not regex parsing) with two tools:
+                 1. **`request_information`** — Request one data channel, providing reasoning, current differential with calibrated probabilities, and expected impact
+                 2. **`commit_diagnosis`** — Submit final ranked diagnosis when confident
+
+                 ### No Budget Constraint
+                 The agent acquires as many channels as it needs (0 to all). It stops when:
+                 - It calls `commit_diagnosis` (self-determined confidence)
+                 - Information-theoretic stopping criteria trigger (convergence, confirmed dominance, or diminishing returns)
+                 - All channels are exhausted
+
+                 ### Information-Theoretic Metrics
+                 At each step, the system tracks:
+                 - **Shannon Entropy** H(p) — diagnostic uncertainty in bits
+                 - **Information Gain** — entropy reduction from each acquisition
+                 - **KL Divergence** — how much the belief distribution shifted
+                 - **Expected Information Gain (EIG)** — predicted value of the next channel
+                 - **Value of Information (VoI)** — whether continuing to acquire is worthwhile
+
+                 ### Context Management
+                 - **Full Mode**: Multi-turn conversation with complete history (for capable models)
+                 - **Condensed Mode**: Fresh single-turn call each step with compressed state log (for weaker models)
+                 - **Adaptive**: Auto-selects based on model capability
+
+                 ### Stopping Criteria
+                 1. **Convergence**: Last acquisition < 0.05 bits of IG
+                 2. **Confirmed Dominance**: Top diagnosis > 90% probability with > 40% gap (after 2+ acquisitions)
+                 3. **Diminishing Returns**: Last 2 acquisitions both < 0.1 bits IG
+                 """)
+
+     return app
+
+
+ # ============================================================
+ # Entry Point
+ # ============================================================
+
+ def main():
+     parser = argparse.ArgumentParser(description="ActiveMedAgent Interactive Demo")
+     parser.add_argument("--port", type=int, default=7860, help="Port to serve on")
+     parser.add_argument("--backend", default="openai", choices=["openai", "anthropic", "together"])
+     parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
+     args = parser.parse_args()
+
+     app = create_app()
+     # Note: theme and css are gr.Blocks() arguments, not launch() arguments.
+     app.launch(
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == "__main__":
+     main()
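The entropy bookkeeping that `format_entropy_table` performs and that the "How It Works" tab describes can be sketched in plain Python. This is a minimal illustration only: the project's actual implementations live in `information_gain.py` (`compute_entropy`, `BeliefTrajectory`), and the function bodies below are assumptions about the standard definitions, not the repository's code.

```python
import math

def compute_entropy(probs):
    """Shannon entropy H(p) in bits over a diagnostic belief distribution."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

def information_gain(prior, posterior):
    """Entropy reduction from one acquisition: H(prior) - H(posterior)."""
    return compute_entropy(prior) - compute_entropy(posterior)

def kl_divergence(posterior, prior):
    """KL(posterior || prior) in bits: how far the belief distribution shifted."""
    return sum(
        q * math.log2(q / p)
        for q, p in zip(posterior, prior)
        if q > 0 and p > 0
    )

# Four candidate diagnoses: a uniform prior is maximum uncertainty (exactly 2 bits).
prior = [0.25, 0.25, 0.25, 0.25]
# Belief after acquiring one discriminative channel (hypothetical values).
posterior = [0.70, 0.15, 0.10, 0.05]

print(f"H(prior)     = {compute_entropy(prior):.3f} bits")  # 2.000 bits
print(f"H(posterior) = {compute_entropy(posterior):.3f} bits")
print(f"info gain    = {information_gain(prior, posterior):.3f} bits")
print(f"KL shift     = {kl_divergence(posterior, prior):.3f} bits")
```

Under these definitions, the demo's "Convergence" stopping rule corresponds to the last acquisition's `information_gain` dropping below 0.05 bits.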
baselines.py ADDED
@@ -0,0 +1,694 @@
+ """
+ Additional Baselines for ACL/EMNLP Submission.
+
+ Five baselines that answer: "Does active sequential acquisition actually
+ help over simpler strategies?"
+
+ 1. AllAtOnce: Give the VLM all text channels upfront (no sequential reasoning)
+ 2. RandomOrder: Acquire channels in random order (same budget as active)
+ 3. ClinicalGuidelineOrder: Follow standard clinical workflow ordering
+ 4. ReactBaseline: Free-form ReAct-style reasoning (no structured tool calls)
+ 5. CoTSinglePass: Chain-of-thought with all info in one shot
+
+ All baselines use the same VLM and produce AgentResult objects for
+ direct comparison with the active agent.
+ """
+ import json
+ import logging
+ import random
+ import re
+ import time
+ from dataclasses import field
+
+ import numpy as np
+
+ import config
+ from api_client import BaseVLMClient, VLMResponse
+ from agent import (
+     ActiveMedAgent, AgentResult, AcquisitionStep,
+     SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_FINAL,
+ )
+ from datasets.base import MedicalCase, ChannelData
+ from tools import ToolCall, constrain_tools_for_step
+ from information_gain import BeliefState, BeliefTrajectory, compute_entropy
+ from prompts import format_acquired_info
+
+ logger = logging.getLogger(__name__)
+
+
+ # ================================================================
+ # Clinical Guideline Orderings
+ # ================================================================
+
+ CLINICAL_GUIDELINE_ORDER = {
+     "nejm": [
+         "demographics",
+         "chief_complaint",
+         "medical_history",
+         "exam_findings",
+         "investigations",
+         "image",
+     ],
+     "midas": [
+         "patient_demographics",
+         "lesion_metadata",
+         "clinical_15cm",
+         "dermoscopy",
+     ],
+     "olives": [
+         "clinical_measurements",
+         "biomarker_hints",
+         "oct_scan",
+         "additional_oct",
+     ],
+ }
+
+
+ # ================================================================
+ # Baseline 1: All-At-Once
+ # ================================================================
+
+ class AllAtOnceBaseline:
+     """
+     Give the VLM all available text/image channels at once.
+
+     Tests whether sequential reasoning matters or if the VLM can
+     handle everything in a single pass with all evidence.
+     Different from Oracle: Oracle uses the experiment evaluation
+     framework; this uses the same prompt structure as the active agent.
+     """
+
+     def __init__(self, client: BaseVLMClient, prompt_variant: str = "A"):
+         self.client = client
+         self.prompt_variant = prompt_variant
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         all_channels = list(case.requestable_channels.keys())
+
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=len(all_channels),
+             acquired_channels=all_channels,
+         )
+
+         images = case.get_all_images_up_to(all_channels)
+         text_context = case.get_text_context(all_channels)
+         acquired_str = format_acquired_info(text_context)
+         candidates_str = "\n".join(
+             f"  {i + 1}. {c}" for i, c in enumerate(case.candidates)
+         )
+
+         system_prompt = (
+             "You are a medical diagnostic agent. You are given ALL available "
+             "clinical information at once. Analyze everything and provide your "
+             "final ranked diagnosis.\n\n"
+             "You MUST use the commit_diagnosis tool to submit your answer.\n"
+             "Include ALL candidate diagnoses with calibrated probabilities "
+             "summing to 1.0 and key_evidence for each."
+         )
+
+         user_text = (
+             f"All available clinical information:\n{acquired_str}\n\n"
+             f"Candidate diagnoses (rank ALL):\n{candidates_str}\n\n"
+             f"Analyze all information and submit your final diagnosis "
+             f"using commit_diagnosis."
+         )
+
+         commit_tools = constrain_tools_for_step(budget_remaining=0)
+
+         t0 = time.time()
+         response = self.client.call_with_retry(
+             system_prompt=system_prompt,
+             user_text=user_text,
+             images=images,
+             temperature=config.TEMPERATURE,
+             max_tokens=config.MAX_TOKENS,
+             tools=commit_tools,
+         )
+
+         result.total_latency_ms = response.latency_ms
+         result.total_input_tokens = response.input_tokens
+         result.total_output_tokens = response.output_tokens
+
+         if response.tool_call and response.tool_call.tool_name == "commit_diagnosis":
+             args = response.tool_call.arguments
+             ranked = args.get("ranked_diagnoses", [])
+             ranking = []
+             for i, entry in enumerate(ranked):
+                 ranking.append({
+                     "name": entry.get("name", ""),
+                     "confidence": entry.get("confidence", 0.0),
+                     "rank": i + 1,
+                     "key_evidence": entry.get("key_evidence", ""),
+                 })
+             ranking.sort(key=lambda x: x["confidence"], reverse=True)
+             for i, entry in enumerate(ranking):
+                 entry["rank"] = i + 1
+             result.final_ranking = ranking
+         else:
+             result.final_ranking = _extract_ranking_from_text(
+                 response.text, case.candidates
+             )
+
+         result.final_raw_response = response.text
+         result.acquisition_cost = case.get_acquisition_cost(all_channels)
+         result.total_case_cost = case.get_total_cost(all_channels)
+         return result
+
+
+ # ================================================================
+ # Baseline 2: Random Order Acquisition
+ # ================================================================
+
+ class RandomOrderBaseline:
+     """
+     Acquire channels in random order, then diagnose.
+
+     Uses the same active agent architecture but overrides channel
+     selection with random choice. This isolates the value of
+     strategic ordering from the value of having more information.
+     """
+
+     def __init__(
+         self,
+         client: BaseVLMClient,
+         prompt_variant: str = "A",
+         budget: int = None,
+         n_trials: int = 3,
+         seed: int = 42,
+     ):
+         self.client = client
+         self.prompt_variant = prompt_variant
+         self.budget = budget
+         self.n_trials = n_trials
+         self.seed = seed
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         """Run with random order. If n_trials > 1, returns best trial."""
+         rng = random.Random(self.seed + hash(case.case_id))
+         requestable = list(case.requestable_channels.keys())
+         max_acq = self.budget if self.budget is not None else len(requestable)
+
+         best_result = None
+         best_mrr = -1
+
+         for trial in range(self.n_trials):
+             order = list(requestable)
+             rng.shuffle(order)
+             acquired = order[:max_acq]
+
+             agent = ActiveMedAgent(
+                 self.client, self.prompt_variant, budget=0,
+             )
+             result = AgentResult(
+                 case_id=case.case_id,
+                 dataset=case.dataset,
+                 prompt_variant=self.prompt_variant,
+                 backend=self.client.model,
+                 budget=max_acq,
+                 acquired_channels=acquired,
+             )
+
+             final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
+             result.final_ranking = final_ranking
+             result.final_raw_response = resp.text
+             result.total_latency_ms = resp.latency_ms
+             result.total_input_tokens = resp.input_tokens
+             result.total_output_tokens = resp.output_tokens
+             result.acquisition_cost = case.get_acquisition_cost(acquired)
+             result.total_case_cost = case.get_total_cost(acquired)
+
+             if self.n_trials == 1:
+                 return result
+
+             # Pick the trial with highest top-1 confidence (proxy for quality)
+             top_conf = final_ranking[0]["confidence"] if final_ranking else 0
+             if top_conf > best_mrr:
+                 best_mrr = top_conf
+                 best_result = result
+
+         return best_result
+
+     def diagnose_single_random(
+         self, case: MedicalCase, seed: int = None
+     ) -> AgentResult:
+         """Single random trial (for aggregate statistics)."""
+         rng = random.Random(seed or self.seed)
+         requestable = list(case.requestable_channels.keys())
+         max_acq = self.budget if self.budget is not None else len(requestable)
+         order = list(requestable)
+         rng.shuffle(order)
+         acquired = order[:max_acq]
+
+         agent = ActiveMedAgent(
+             self.client, self.prompt_variant, budget=0,
+         )
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=max_acq,
+             acquired_channels=acquired,
+         )
+
+         final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
+         result.final_ranking = final_ranking
+         result.final_raw_response = resp.text
+         result.total_latency_ms = resp.latency_ms
+         result.total_input_tokens = resp.input_tokens
+         result.total_output_tokens = resp.output_tokens
+         result.acquisition_cost = case.get_acquisition_cost(acquired)
+         result.total_case_cost = case.get_total_cost(acquired)
+         return result
+
+
+ # ================================================================
+ # Baseline 3: Clinical Guideline Order
+ # ================================================================
+
+ class ClinicalGuidelineBaseline:
+     """
+     Acquire channels in standard clinical workflow order.
+
+     Tests whether the VLM's learned ordering improves over the
+     conventional clinical approach (history -> exam -> labs -> imaging).
+     """
+
+     def __init__(
+         self,
+         client: BaseVLMClient,
+         prompt_variant: str = "A",
+         budget: int = None,
+     ):
+         self.client = client
+         self.prompt_variant = prompt_variant
+         self.budget = budget
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         guideline_order = CLINICAL_GUIDELINE_ORDER.get(case.dataset, [])
+
+         # Filter to channels actually available in this case
+         available = set(case.requestable_channels.keys())
+         order = [ch for ch in guideline_order if ch in available]
+         # Append any remaining channels not in the guideline
+         for ch in case.requestable_channels.keys():
+             if ch not in order:
+                 order.append(ch)
+
+         max_acq = self.budget if self.budget is not None else len(order)
+         acquired = order[:max_acq]
+
+         agent = ActiveMedAgent(
+             self.client, self.prompt_variant, budget=0,
+         )
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=max_acq,
+             acquired_channels=acquired,
+         )
+
+         final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
+         result.final_ranking = final_ranking
+         result.final_raw_response = resp.text
+         result.total_latency_ms = resp.latency_ms
+         result.total_input_tokens = resp.input_tokens
+         result.total_output_tokens = resp.output_tokens
+         result.acquisition_cost = case.get_acquisition_cost(acquired)
+         result.total_case_cost = case.get_total_cost(acquired)
+         return result
+
+
+ # ================================================================
+ # Baseline 4: ReAct-Style Free-Form Reasoning
+ # ================================================================
+
+ class ReactBaseline:
+     """
+     ReAct-style baseline: the VLM reasons in free text and requests
+     channels via text (not structured tool calls).
+
+     Tests whether the structured tool-use architecture improves over
+     free-form reasoning + regex parsing (the dominant approach in
+     prior medical agent work).
+     """
+
+     def __init__(
+         self,
+         client: BaseVLMClient,
+         prompt_variant: str = "A",
+         budget: int = None,
+     ):
+         self.client = client
+         self.prompt_variant = prompt_variant
+         self.budget = budget
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         max_steps = len(case.requestable_names)
+         if self.budget is not None:
+             max_steps = min(max_steps, self.budget)
+
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=max_steps,
+         )
+
+         acquired = []
+         dataset_channel_config = config.CHANNEL_CONFIGS.get(case.dataset, {})
+
+         system_prompt = (
+             "You are a medical diagnostic agent using a Thought-Action-Observation loop.\n\n"
+             "At each step:\n"
+             "1. THOUGHT: Reason about what you know and what you're uncertain about\n"
+             "2. ACTION: Either REQUEST[channel_name] to get more info, or "
+             "COMMIT[diagnosis1 > diagnosis2 > ...] to submit your final ranking\n"
+             "3. You will receive an OBSERVATION with the requested data\n\n"
+             "Be strategic about which information to request. Stop when additional "
+             "information is unlikely to change your diagnosis.\n\n"
+             "Format your response EXACTLY as:\n"
+             "THOUGHT: ...\n"
+             "ACTION: REQUEST[channel_name] or COMMIT[ranked diagnoses with probabilities]"
+         )
+
+         initial_context = format_acquired_info(case.get_text_context([]))
+         candidates_str = "\n".join(
+             f"  {i + 1}. {c}" for i, c in enumerate(case.candidates)
+         )
+
+         # Build channel descriptions
+         channel_desc_lines = []
+         for name, ch in case.requestable_channels.items():
+             channel_desc_lines.append(
+                 f"  - {name}: {ch.description} (cost: ${ch.cost:,.0f})"
+             )
+         channel_desc = "\n".join(channel_desc_lines)
+
+         conversation_text = (
+             f"Initial information:\n{initial_context}\n\n"
+             f"Candidate diagnoses:\n{candidates_str}\n\n"
+             f"Available channels:\n{channel_desc}\n"
+         )
+
+         images = case.get_initial_images()
+
+         for step_idx in range(max_steps):
+             available = [n for n in case.requestable_names if n not in acquired]
+             if not available:
+                 break
+
+             user_text = conversation_text
+             if acquired:
+                 acq_context = format_acquired_info(case.get_text_context(acquired))
+                 user_text += (
+                     f"\n\nAcquired information so far:\n{acq_context}\n\n"
+                     f"Remaining channels: {', '.join(available)}\n"
+                 )
+
+             response = self.client.call_with_retry(
+                 system_prompt=system_prompt,
+                 user_text=user_text,
+                 images=images,
+                 temperature=config.TEMPERATURE,
+                 max_tokens=config.MAX_TOKENS,
+             )
+
+             result.total_latency_ms += response.latency_ms
+             result.total_input_tokens += response.input_tokens
+             result.total_output_tokens += response.output_tokens
+
+             text = response.text
+
+             # Parse COMMIT
+             commit_match = re.search(
+                 r"COMMIT\[(.+?)\]", text, re.DOTALL
+             )
+             if commit_match:
+                 result.committed_early = True
+                 result.final_ranking = self._parse_commit_text(
+                     commit_match.group(1), case.candidates
+                 )
+                 result.final_raw_response = text
+
+                 step = AcquisitionStep(
+                     step=step_idx,
+                     tool_call=None,
+                     requested_channel=None,
+                     reasoning=_extract_thought(text),
+                     differential=result.final_ranking,
+                     committed=True,
+                     raw_response=text,
+                     latency_ms=response.latency_ms,
+                 )
+                 result.steps.append(step)
+                 break
+
+             # Parse REQUEST
+             request_match = re.search(
+                 r"REQUEST\[(\w+)\]", text, re.IGNORECASE
+             )
+             if request_match:
+                 requested = request_match.group(1).strip().lower()
+                 matched = _match_channel_name(requested, available)
+                 if matched is None:
+                     matched = available[0]
+
+                 acquired.append(matched)
+                 result.acquired_channels.append(matched)
+
+                 # Add new images if the channel is an image
+                 ch = case.get_channel(matched)
+                 if ch and ch.channel_type == "image" and ch.value:
+                     if isinstance(ch.value, list):
+                         images.extend(ch.value)
+                     else:
+                         images.append(ch.value)
+
+                 step = AcquisitionStep(
+                     step=step_idx,
+                     tool_call=None,
+                     requested_channel=matched,
+                     reasoning=_extract_thought(text),
+                     differential=[],
+                     committed=False,
+                     raw_response=text,
+                     latency_ms=response.latency_ms,
+                 )
485
+ result.steps.append(step)
486
+ else:
487
+ # No parseable action — fallback to first available
488
+ matched = available[0]
489
+ acquired.append(matched)
490
+ result.acquired_channels.append(matched)
491
+
492
+ step = AcquisitionStep(
493
+ step=step_idx,
494
+ tool_call=None,
495
+ requested_channel=matched,
496
+ reasoning=f"(unparseable response, fallback to {matched})",
497
+ differential=[],
498
+ committed=False,
499
+ raw_response=text,
500
+ latency_ms=response.latency_ms,
501
+ )
502
+ result.steps.append(step)
503
+
504
+ # Final diagnosis if not committed
505
+ if not result.committed_early or not result.final_ranking:
506
+ agent = ActiveMedAgent(self.client, self.prompt_variant, budget=0)
507
+ final_ranking, resp = agent.get_diagnosis_at_state(case, acquired)
508
+ result.final_ranking = final_ranking
509
+ result.final_raw_response = resp.text
510
+ result.total_latency_ms += resp.latency_ms
511
+ result.total_input_tokens += resp.input_tokens
512
+ result.total_output_tokens += resp.output_tokens
513
+
514
+ result.acquired_channels = acquired
515
+ result.acquisition_cost = case.get_acquisition_cost(acquired)
516
+ result.total_case_cost = case.get_total_cost(acquired)
517
+ return result
518
+
519
+ def _parse_commit_text(
520
+ self, commit_str: str, candidates: list[str]
521
+ ) -> list[dict]:
522
+ """Parse a COMMIT[...] string into a ranking."""
523
+ ranking = []
524
+ # Try "Diagnosis (0.XX)" pattern
525
+ pattern = r"([^>,(]+?)\s*\(?([\d.]+)\)?"
526
+ parts = re.split(r"\s*>\s*", commit_str)
527
+ for i, part in enumerate(parts):
528
+ match = re.match(pattern, part.strip())
529
+ if match:
530
+ name = match.group(1).strip()
531
+ try:
532
+ conf = float(match.group(2))
533
+ except (ValueError, IndexError):
534
+ conf = max(0.1, 1.0 - i * 0.2)
535
+ ranking.append({
536
+ "name": name,
537
+ "confidence": conf,
538
+ "rank": i + 1,
539
+ })
540
+
541
+ if not ranking:
542
+ ranking = _extract_ranking_from_text(commit_str, candidates)
543
+
544
+ ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
545
+ for i, entry in enumerate(ranking):
546
+ entry["rank"] = i + 1
547
+ return ranking
548
+
549
+
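The COMMIT payload parsing above can be exercised on its own. A minimal standalone sketch (the `parse_commit` name and the toy diagnoses are hypothetical; it mirrors the split-on-`>` logic of `_parse_commit_text` without the candidate fallback):

```python
import re

# Simplified COMMIT[...] parsing: split on ">", pull the "(0.XX)" confidence,
# fall back to a rank-decayed default if the float conversion fails.
# Parts with no number at all fail the match and are skipped, as in the original.
def parse_commit(commit_str: str) -> list[dict]:
    ranking = []
    for i, part in enumerate(re.split(r"\s*>\s*", commit_str)):
        m = re.match(r"([^>,(]+?)\s*\(?([\d.]+)\)?", part.strip())
        if not m:
            continue
        try:
            conf = float(m.group(2))
        except ValueError:
            conf = max(0.1, 1.0 - i * 0.2)
        ranking.append({"name": m.group(1).strip(), "confidence": conf, "rank": i + 1})
    return ranking

ranked = parse_commit("Melanoma (0.7) > Nevus (0.2) > Seborrheic keratosis (0.1)")
```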
550
+ # ================================================================
+ # Baseline 5: Chain-of-Thought Single Pass
+ # ================================================================
+
+ class CoTSinglePassBaseline:
+     """
+     Standard chain-of-thought: give the VLM all available info and
+     ask it to reason step by step in a single pass.
+
+     No multi-turn reasoning, no tool use, no acquisition decisions.
+     Just: "Here's everything, think step by step, give me your answer."
+     """
+
+     def __init__(self, client: BaseVLMClient, prompt_variant: str = "A"):
+         self.client = client
+         self.prompt_variant = prompt_variant
+
+     def diagnose(self, case: MedicalCase) -> AgentResult:
+         all_channels = list(case.requestable_channels.keys())
+
+         result = AgentResult(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             prompt_variant=self.prompt_variant,
+             backend=self.client.model,
+             budget=len(all_channels),
+             acquired_channels=all_channels,
+         )
+
+         images = case.get_all_images_up_to(all_channels)
+         text_context = case.get_text_context(all_channels)
+         acquired_str = format_acquired_info(text_context)
+         candidates_str = "\n".join(
+             f" {i + 1}. {c}" for i, c in enumerate(case.candidates)
+         )
+
+         system_prompt = (
+             "You are a medical diagnostic expert. Analyze the following "
+             "clinical information and provide your diagnosis.\n\n"
+             "Think step by step:\n"
+             "1. Summarize the key findings\n"
+             "2. Consider each candidate diagnosis\n"
+             "3. Identify supporting and refuting evidence for each\n"
+             "4. Rank all candidates with calibrated probabilities (0-1, sum to 1)\n\n"
+             "Format your final answer as:\n"
+             "RANKING:\n"
+             "1. DiagnosisName (confidence: X.XX) - key evidence\n"
+             "2. DiagnosisName (confidence: X.XX) - key evidence\n"
+             "..."
+         )
+
+         user_text = (
+             f"Clinical information:\n{acquired_str}\n\n"
+             f"Candidate diagnoses:\n{candidates_str}\n\n"
+             f"Think step by step and provide your ranked diagnosis."
+         )
+
+         response = self.client.call_with_retry(
+             system_prompt=system_prompt,
+             user_text=user_text,
+             images=images,
+             temperature=config.TEMPERATURE,
+             max_tokens=config.MAX_TOKENS,
+         )
+
+         result.total_latency_ms = response.latency_ms
+         result.total_input_tokens = response.input_tokens
+         result.total_output_tokens = response.output_tokens
+         result.final_raw_response = response.text
+         result.final_ranking = _extract_ranking_from_text(
+             response.text, case.candidates
+         )
+         result.acquisition_cost = case.get_acquisition_cost(all_channels)
+         result.total_case_cost = case.get_total_cost(all_channels)
+         return result
+
+
+ # ================================================================
+ # Helpers
+ # ================================================================
+
+ def _extract_thought(text: str) -> str:
+     """Extract the THOUGHT section from a ReAct-style response."""
+     match = re.search(r"THOUGHT:\s*(.+?)(?=ACTION:|$)", text, re.DOTALL)
+     if match:
+         return match.group(1).strip()[:500]
+     return text[:200]
+
+
+ def _match_channel_name(requested: str, available: list[str]) -> str | None:
+     """Fuzzy match a requested channel name against the available channels."""
+     requested = requested.lower().strip().replace(" ", "_")
+     if requested in available:
+         return requested
+     for ch in available:
+         if requested in ch or ch in requested:
+             return ch
+     return None
+
650
+ def _extract_ranking_from_text(
+     text: str, candidates: list[str]
+ ) -> list[dict]:
+     """Extract a ranking from a free-form text response."""
+     ranking = []
+     pattern = (
+         r"(\d+)\.\s*(.+?)\s*"
+         r"\((?:confidence|probability|prob|conf):\s*([\d.]+)\)"
+     )
+     matches = re.findall(pattern, text, re.IGNORECASE)
+     if matches:
+         for rank_str, name, conf_str in matches:
+             try:
+                 ranking.append({
+                     "name": name.strip(),
+                     "confidence": float(conf_str),
+                     "rank": int(rank_str),
+                 })
+             except ValueError:
+                 continue
+     if not ranking and candidates:
+         # Fallback: any candidate mentioned in the text, with rank-decayed confidence
+         for i, candidate in enumerate(candidates):
+             if candidate.lower() in text.lower():
+                 ranking.append({
+                     "name": candidate,
+                     "confidence": max(0.1, 1.0 - i * 0.2),
+                     "rank": len(ranking) + 1,
+                 })
+     ranking.sort(key=lambda x: x.get("confidence", 0), reverse=True)
+     for i, entry in enumerate(ranking):
+         entry["rank"] = i + 1
+     return ranking
+
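The regex above targets the `RANKING:` format requested in the single-pass prompt. A quick standalone check of what it captures (toy text with hypothetical diagnoses, not model output):

```python
import re

# The expected line shape: "N. Name (confidence: X.XX) - evidence"
text = """RANKING:
1. Pulmonary embolism (confidence: 0.62) - acute dyspnea, elevated D-dimer
2. Pneumonia (confidence: 0.28) - fever, focal crackles
3. Heart failure (confidence: 0.10) - no peripheral edema"""

pattern = (
    r"(\d+)\.\s*(.+?)\s*"
    r"\((?:confidence|probability|prob|conf):\s*([\d.]+)\)"
)
# Each match is a (rank, name, confidence) string triple.
matches = re.findall(pattern, text, re.IGNORECASE)
```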
684
+ # ================================================================
+ # Registry
+ # ================================================================
+
+ BASELINE_REGISTRY = {
+     "all_at_once": AllAtOnceBaseline,
+     "random_order": RandomOrderBaseline,
+     "clinical_guideline": ClinicalGuidelineBaseline,
+     "react": ReactBaseline,
+     "cot_single_pass": CoTSinglePassBaseline,
+ }
calibration.py ADDED
@@ -0,0 +1,519 @@
+ """
+ Calibration Analysis for ActiveMedAgent.
+
+ Measures whether the VLM's reported probabilities match empirical
+ accuracy. Key analyses for the ACL/EMNLP submission:
+
+ 1. Reliability Diagram: binned confidence vs accuracy
+ 2. Expected Calibration Error (ECE): scalar miscalibration summary
+ 3. Temperature Scaling: post-hoc recalibration on a held-out set
+ 4. Robustness to Miscalibration: does the method work with noisy probabilities?
+ 5. Per-Step Calibration: is calibration better or worse at different steps?
+ """
+ import json
+ import logging
+ import math
+ from dataclasses import dataclass, field
+ from pathlib import Path
+
+ import numpy as np
+ from scipy.optimize import minimize_scalar
+
+ from agent import AgentResult, AcquisitionStep
+ from datasets.base import MedicalCase
+ from evaluation import evaluate_single_case, CaseMetrics
+
+ logger = logging.getLogger(__name__)
+
+
+ # ================================================================
+ # Core Calibration Metrics
+ # ================================================================
+
+ @dataclass
+ class CalibrationBin:
+     """A single bin in a reliability diagram."""
+     bin_lower: float
+     bin_upper: float
+     bin_center: float
+     avg_confidence: float
+     avg_accuracy: float
+     count: int
+     gap: float  # |avg_confidence - avg_accuracy|
+
+
+ @dataclass
+ class CalibrationResult:
+     """Full calibration analysis for a set of predictions."""
+     ece: float  # Expected Calibration Error
+     mce: float  # Maximum Calibration Error
+     ace: float  # Average Calibration Error
+     bins: list[CalibrationBin]
+     n_predictions: int
+     mean_confidence: float
+     mean_accuracy: float
+     overconfidence_ratio: float  # Fraction of bins where conf > acc
+     brier_score: float  # Brier score (MSE of probabilities)
+
+
+ def compute_calibration(
+     confidences: list[float],
+     correctness: list[bool],
+     n_bins: int = 10,
+ ) -> CalibrationResult:
+     """
+     Compute calibration metrics from confidence-correctness pairs.
+
+     Args:
+         confidences: Model's stated probability for its top prediction
+         correctness: Whether the top prediction was correct
+         n_bins: Number of bins for the reliability diagram
+
+     Returns:
+         CalibrationResult with ECE, MCE, bins, etc.
+     """
+     if not confidences:
+         return CalibrationResult(
+             ece=0, mce=0, ace=0, bins=[], n_predictions=0,
+             mean_confidence=0, mean_accuracy=0,
+             overconfidence_ratio=0, brier_score=0,
+         )
+
+     confs = np.array(confidences, dtype=np.float64)
+     accs = np.array(correctness, dtype=np.float64)
+     n = len(confs)
+
+     bin_boundaries = np.linspace(0.0, 1.0, n_bins + 1)
+     bins = []
+     ece = 0.0
+     mce = 0.0
+     overconf_count = 0
+
+     for i in range(n_bins):
+         lower = bin_boundaries[i]
+         upper = bin_boundaries[i + 1]
+         # Half-open bins (lower, upper]; include 0.0 in the first bin so a
+         # zero-confidence prediction is not silently dropped.
+         mask = (confs > lower) & (confs <= upper)
+         if i == 0:
+             mask |= confs == lower
+         count = mask.sum()
+
+         if count == 0:
+             bins.append(CalibrationBin(
+                 bin_lower=lower, bin_upper=upper,
+                 bin_center=(lower + upper) / 2,
+                 avg_confidence=0, avg_accuracy=0,
+                 count=0, gap=0,
+             ))
+             continue
+
+         avg_conf = confs[mask].mean()
+         avg_acc = accs[mask].mean()
+         gap = abs(avg_conf - avg_acc)
+
+         ece += (count / n) * gap
+         mce = max(mce, gap)
+
+         if avg_conf > avg_acc:
+             overconf_count += 1
+
+         bins.append(CalibrationBin(
+             bin_lower=lower, bin_upper=upper,
+             bin_center=(lower + upper) / 2,
+             avg_confidence=float(avg_conf),
+             avg_accuracy=float(avg_acc),
+             count=int(count),
+             gap=float(gap),
+         ))
+
+     non_empty_bins = [b for b in bins if b.count > 0]
+     ace = np.mean([b.gap for b in non_empty_bins]) if non_empty_bins else 0.0
+
+     # Brier score
+     brier = np.mean((confs - accs) ** 2)
+
+     return CalibrationResult(
+         ece=float(ece),
+         mce=float(mce),
+         ace=float(ace),
+         bins=bins,
+         n_predictions=n,
+         mean_confidence=float(confs.mean()),
+         mean_accuracy=float(accs.mean()),
+         overconfidence_ratio=overconf_count / len(non_empty_bins) if non_empty_bins else 0,
+         brier_score=float(brier),
+     )
+
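As a sanity check on the binning logic, the ECE computation reduces to a short loop over half-open confidence bins. A toy example on synthetic predictions (the values are illustrative, not experimental data):

```python
import numpy as np

# Five synthetic (confidence, correct) pairs and a 10-bin ECE.
confs = np.array([0.9, 0.8, 0.7, 0.6, 0.55])
correct = np.array([1, 1, 0, 1, 0], dtype=float)

n_bins, n = 10, len(confs)
edges = np.linspace(0.0, 1.0, n_bins + 1)
ece = 0.0
for lo, hi in zip(edges[:-1], edges[1:]):
    mask = (confs > lo) & (confs <= hi)
    if mask.sum():
        # Bin weight times |mean confidence - mean accuracy|
        ece += (mask.sum() / n) * abs(confs[mask].mean() - correct[mask].mean())
# ece == 0.4*0.075 + 0.2*0.7 + 0.2*0.2 + 0.2*0.1 = 0.23
```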
145
+ # ================================================================
+ # Extract Predictions from Agent Results
+ # ================================================================
+
+ def extract_predictions(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+ ) -> tuple[list[float], list[bool]]:
+     """
+     Extract (confidence, correctness) pairs from agent results.
+
+     Returns:
+         confidences: top-1 stated probability
+         correctness: whether top-1 matches ground truth
+     """
+     confidences = []
+     correctness = []
+
+     for result, case in zip(results, cases):
+         if not result.final_ranking:
+             continue
+
+         top = result.final_ranking[0]
+         conf = top.get("confidence", 0.0)
+         name = top.get("name", "").strip().lower()
+         gt = case.ground_truth.strip().lower()
+
+         correct = name == gt or name in gt or gt in name
+
+         confidences.append(conf)
+         correctness.append(correct)
+
+     return confidences, correctness
+
+
+ def extract_per_step_predictions(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+ ) -> dict[int, tuple[list[float], list[bool]]]:
+     """
+     Extract predictions at each acquisition step.
+
+     Returns:
+         {step_idx: (confidences, correctness)}
+     """
+     step_data: dict[int, tuple[list, list]] = {}
+
+     for result, case in zip(results, cases):
+         gt = case.ground_truth.strip().lower()
+
+         for step in result.steps:
+             if not step.differential:
+                 continue
+
+             idx = step.step
+             if idx not in step_data:
+                 step_data[idx] = ([], [])
+
+             top = max(step.differential, key=lambda d: d.get("confidence", 0))
+             conf = top.get("confidence", 0.0)
+             name = top.get("name", "").strip().lower()
+             correct = name == gt or name in gt or gt in name
+
+             step_data[idx][0].append(conf)
+             step_data[idx][1].append(correct)
+
+     return step_data
+
214
+ # ================================================================
+ # Temperature Scaling
+ # ================================================================
+
+ def temperature_scale(
+     confidences: list[float],
+     correctness: list[bool],
+     candidates_per_case: list[int] | None = None,
+ ) -> tuple[float, float]:
+     """
+     Find the temperature T that minimizes ECE on held-out data.
+
+     Temperature scaling: p_calibrated = softmax(logit(p) / T).
+     For a single top-1 probability we use the binary simplification:
+         logit = log(p / (1 - p))
+         scaled_logit = logit / T
+         p_scaled = sigmoid(scaled_logit)
+
+     Args:
+         confidences: Raw model confidences
+         correctness: Whether predictions were correct
+         candidates_per_case: Number of candidates per case (for proper scaling)
+
+     Returns:
+         (optimal_temperature, calibrated_ece)
+     """
+     confs = np.array(confidences, dtype=np.float64)
+     accs = np.array(correctness, dtype=np.float64)
+
+     # Clip to avoid log(0)
+     confs = np.clip(confs, 1e-6, 1 - 1e-6)
+     logits = np.log(confs / (1 - confs))
+
+     def ece_at_temperature(T):
+         scaled_logits = logits / T
+         scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
+         # Compute binned ECE at this temperature
+         n_bins = 10
+         bins = np.linspace(0, 1, n_bins + 1)
+         ece = 0.0
+         n = len(scaled_confs)
+         for i in range(n_bins):
+             mask = (scaled_confs > bins[i]) & (scaled_confs <= bins[i + 1])
+             if mask.sum() == 0:
+                 continue
+             bin_conf = scaled_confs[mask].mean()
+             bin_acc = accs[mask].mean()
+             ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
+         return ece
+
+     result = minimize_scalar(
+         ece_at_temperature,
+         bounds=(0.1, 10.0),
+         method="bounded",
+     )
+
+     optimal_T = result.x
+     calibrated_ece = ece_at_temperature(optimal_T)
+
+     return float(optimal_T), float(calibrated_ece)
+
+
+ def apply_temperature(
+     confidences: list[float], temperature: float
+ ) -> list[float]:
+     """Apply temperature scaling to a list of confidences."""
+     confs = np.array(confidences, dtype=np.float64)
+     confs = np.clip(confs, 1e-6, 1 - 1e-6)
+     logits = np.log(confs / (1 - confs))
+     scaled_logits = logits / temperature
+     scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
+     return scaled_confs.tolist()
+
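To see the effect of the binary temperature map used above, here is a minimal standalone version: T > 1 shrinks logits toward 0, pulling probabilities toward 0.5 (the `apply_temp` name is hypothetical):

```python
import numpy as np

def apply_temp(confidences, T):
    # Binary temperature scaling: logit -> logit / T -> sigmoid
    p = np.clip(np.asarray(confidences, dtype=np.float64), 1e-6, 1 - 1e-6)
    logits = np.log(p / (1 - p))
    return 1.0 / (1.0 + np.exp(-logits / T))

scaled = apply_temp([0.99, 0.9, 0.5], T=2.0)
# 0.9 has logit ln(9); halving it gives sigmoid(ln(3)) = 0.75,
# while 0.5 (logit 0) is a fixed point of the map.
```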
288
+ # ================================================================
+ # Robustness to Miscalibration
+ # ================================================================
+
+ def test_calibration_robustness(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+     noise_levels: list[float] | None = None,
+     n_trials: int = 10,
+     seed: int = 42,
+ ) -> dict[float, dict]:
+     """
+     Test whether the agent's acquisition decisions are robust to
+     probability miscalibration.
+
+     For each noise level, we perturb the agent's reported probabilities
+     and check whether the same acquisition order and stopping decisions
+     would be made.
+
+     Args:
+         noise_levels: Standard deviations of Gaussian noise added to logits
+         n_trials: Number of random trials per noise level
+
+     Returns:
+         {noise_level: {order_stability, stop_stability, ...}}
+     """
+     from scipy.stats import spearmanr
+
+     if noise_levels is None:
+         noise_levels = [0.0, 0.1, 0.25, 0.5, 1.0, 2.0]
+
+     rng = np.random.RandomState(seed)
+     robustness = {}
+
+     # Collect original acquisition orders and stopping points
+     original_orders = []
+     original_stop_steps = []
+     original_distributions = []
+
+     for result in results:
+         original_orders.append(tuple(result.acquired_channels))
+         original_stop_steps.append(len(result.acquired_channels))
+
+         step_dists = []
+         for step in result.steps:
+             if step.differential:
+                 dist = {
+                     d.get("name", ""): d.get("confidence", 0)
+                     for d in step.differential
+                 }
+                 step_dists.append(dist)
+         original_distributions.append(step_dists)
+
+     for noise in noise_levels:
+         order_matches = 0
+         stop_matches = 0
+         total = len(results)
+
+         if noise == 0.0:
+             robustness[noise] = {
+                 "order_stability": 1.0,
+                 "stop_stability": 1.0,
+                 "mean_rank_correlation": 1.0,
+                 "n_cases": total,
+             }
+             continue
+
+         rank_correlations = []
+
+         for trial in range(n_trials):
+             trial_order_matches = 0
+             trial_stop_matches = 0
+             trial_rank_corrs = []
+
+             for i, (result, dists) in enumerate(
+                 zip(results, original_distributions)
+             ):
+                 if not dists:
+                     continue
+
+                 # Perturb each step's distribution
+                 for dist in dists:
+                     probs = np.array(list(dist.values()), dtype=np.float64)
+                     probs = np.clip(probs, 1e-6, 1 - 1e-6)
+
+                     # Add noise in logit space
+                     logits = np.log(probs / (1 - probs))
+                     noisy_logits = logits + rng.normal(0, noise, len(logits))
+                     noisy_probs = 1.0 / (1.0 + np.exp(-noisy_logits))
+                     noisy_probs /= noisy_probs.sum()
+
+                     # Check whether the ranking order is preserved
+                     orig_order = np.argsort(-probs)
+                     noisy_order = np.argsort(-noisy_probs)
+
+                     # Spearman rank correlation between the two orderings
+                     if len(orig_order) > 1:
+                         corr, _ = spearmanr(orig_order, noisy_order)
+                         trial_rank_corrs.append(corr)
+
+                 # Simplified: without re-running the acquisition policy we
+                 # cannot recompute the order/stop decisions under noise, so
+                 # these counters are optimistic placeholders; the informative
+                 # signal is the rank correlation above.
+                 if tuple(result.acquired_channels) == original_orders[i]:
+                     trial_order_matches += 1
+                     trial_stop_matches += 1
+
+             if total > 0:
+                 order_matches += trial_order_matches / total
+                 stop_matches += trial_stop_matches / total
+             if trial_rank_corrs:
+                 rank_correlations.extend(trial_rank_corrs)
+
+         robustness[noise] = {
+             "order_stability": order_matches / n_trials if n_trials > 0 else 0,
+             "stop_stability": stop_matches / n_trials if n_trials > 0 else 0,
+             "mean_rank_correlation": float(np.mean(rank_correlations)) if rank_correlations else 1.0,
+             "n_cases": total,
+         }
+
+     return robustness
+
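The perturbation inside the trial loop can be reproduced in isolation. A sketch of one noisy draw (the seed and the toy distribution are arbitrary):

```python
import numpy as np

# One logit-space perturbation of a reported distribution, as in the
# robustness test: add Gaussian noise to logits, map back through the
# sigmoid, then renormalize to a probability vector.
rng = np.random.RandomState(42)
probs = np.clip(np.array([0.6, 0.3, 0.1]), 1e-6, 1 - 1e-6)
logits = np.log(probs / (1 - probs))
noisy_logits = logits + rng.normal(0, 0.25, len(logits))
noisy_probs = 1.0 / (1.0 + np.exp(-noisy_logits))
noisy_probs /= noisy_probs.sum()
# At this noise scale the argmax ordering is usually preserved.
```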
410
+ # ================================================================
+ # Full Calibration Analysis Pipeline
+ # ================================================================
+
+ def run_calibration_analysis(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+     save_dir: Path | None = None,
+ ) -> dict:
+     """
+     Run the complete calibration analysis suite.
+
+     Returns a dict with all metrics; saves to disk if save_dir is provided.
+     """
+     logger.info("Running calibration analysis...")
+
+     # 1. Overall calibration
+     confidences, correctness = extract_predictions(results, cases)
+     overall = compute_calibration(confidences, correctness)
+
+     logger.info(f" ECE: {overall.ece:.4f}")
+     logger.info(f" MCE: {overall.mce:.4f}")
+     logger.info(f" Brier Score: {overall.brier_score:.4f}")
+     logger.info(f" Mean Confidence: {overall.mean_confidence:.3f}")
+     logger.info(f" Mean Accuracy: {overall.mean_accuracy:.3f}")
+     logger.info(f" Overconfidence Ratio: {overall.overconfidence_ratio:.2f}")
+
+     # 2. Temperature scaling
+     if len(confidences) >= 10:
+         # Split into calibration and test halves
+         n = len(confidences)
+         mid = n // 2
+         cal_confs, cal_correct = confidences[:mid], correctness[:mid]
+         test_confs, test_correct = confidences[mid:], correctness[mid:]
+
+         opt_T, cal_ece = temperature_scale(cal_confs, cal_correct)
+         scaled_test = apply_temperature(test_confs, opt_T)
+         post_cal = compute_calibration(scaled_test, test_correct)
+
+         logger.info(f" Optimal Temperature: {opt_T:.3f}")
+         logger.info(f" Post-calibration ECE: {post_cal.ece:.4f}")
+     else:
+         opt_T = 1.0
+         post_cal = overall
+
+     # 3. Per-step calibration
+     step_data = extract_per_step_predictions(results, cases)
+     per_step_cal = {}
+     for step_idx, (step_confs, step_correct) in sorted(step_data.items()):
+         if len(step_confs) >= 5:
+             step_cal = compute_calibration(step_confs, step_correct, n_bins=5)
+             per_step_cal[step_idx] = {
+                 "ece": step_cal.ece,
+                 "mean_confidence": step_cal.mean_confidence,
+                 "mean_accuracy": step_cal.mean_accuracy,
+                 "n_predictions": step_cal.n_predictions,
+             }
+             logger.info(
+                 f" Step {step_idx}: ECE={step_cal.ece:.4f}, "
+                 f"Conf={step_cal.mean_confidence:.3f}, "
+                 f"Acc={step_cal.mean_accuracy:.3f} (n={step_cal.n_predictions})"
+             )
+
+     # 4. Robustness analysis
+     robustness = test_calibration_robustness(results, cases)
+     for noise, metrics in robustness.items():
+         logger.info(
+             f" Noise={noise:.2f}: rank_corr={metrics['mean_rank_correlation']:.3f}"
+         )
+
+     # Compile output
+     output = {
+         "overall": {
+             "ece": overall.ece,
+             "mce": overall.mce,
+             "ace": overall.ace,
+             "brier_score": overall.brier_score,
+             "mean_confidence": overall.mean_confidence,
+             "mean_accuracy": overall.mean_accuracy,
+             "overconfidence_ratio": overall.overconfidence_ratio,
+             "n_predictions": overall.n_predictions,
+             "bins": [
+                 {
+                     "center": b.bin_center,
+                     "confidence": b.avg_confidence,
+                     "accuracy": b.avg_accuracy,
+                     "count": b.count,
+                     "gap": b.gap,
+                 }
+                 for b in overall.bins
+             ],
+         },
+         "temperature_scaling": {
+             "optimal_temperature": opt_T,
+             "pre_calibration_ece": overall.ece,
+             "post_calibration_ece": post_cal.ece,
+         },
+         "per_step_calibration": per_step_cal,
+         "robustness": {
+             str(k): v for k, v in robustness.items()
+         },
+     }
+
+     if save_dir:
+         save_dir.mkdir(parents=True, exist_ok=True)
+         with open(save_dir / "calibration_analysis.json", "w") as f:
+             json.dump(output, f, indent=2)
+         logger.info(f" Saved to {save_dir / 'calibration_analysis.json'}")
+
+     return output
config.py ADDED
@@ -0,0 +1,276 @@
+ """
+ Configuration for ActiveMedAgent experiments.
+ """
+ import os
+ from pathlib import Path
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # ============================================================
+ # API Configuration
+ # ============================================================
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+ ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
+ TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY", "")
+
+ # Model identifiers per backend
+ MODELS = {
+     "openai": "gpt-4o-2024-11-20",
+     "openai_mini": "gpt-4o-mini",
+     "anthropic": "claude-sonnet-4-20250514",
+     "together": "Qwen/Qwen2.5-VL-72B-Instruct",
+ }
+
+ # Rate limiting (requests per minute)
+ RATE_LIMITS = {
+     "openai": 30,
+     "openai_mini": 60,
+     "anthropic": 30,
+     "together": 20,
+ }
+
+ # Max tokens for generation — tool calls produce structured JSON with
+ # probability distributions, evidence chains, and expected impact analysis,
+ # which requires more tokens than free-text responses.
+ MAX_TOKENS = 4096
+
+ # Temperature — low for reproducibility
+ TEMPERATURE = 0.1
+
+ # ============================================================
+ # Dataset Paths (update these to your local paths)
+ # ============================================================
+ DATA_ROOT = Path(os.getenv("DATA_ROOT", "./data"))
+
+ DATASET_PATHS = {
+     "midas": DATA_ROOT / "midas",
+     "nejm": DATA_ROOT / "nejm",
+     "olives": DATA_ROOT / "OLIVES",
+ }
+
+ # ============================================================
+ # Experiment Configuration
+ # ============================================================
+
+ # Prompt variants for robustness analysis (see prompts.py)
+ PROMPT_VARIANTS = ["A", "B", "C"]
+
+ # Default backends to run
+ DEFAULT_BACKENDS = ["openai"]
+
+ # Context management mode for the acquisition loop.
+ # "full" — keep entire multi-turn conversation history (best for capable models)
+ # "condensed" — each turn gets a fresh single-turn call with a compressed state
+ #     summary (best for weaker/smaller models that lose track in long context)
+ # "adaptive" — auto-select based on model: "full" for GPT-4o/Claude/Qwen-72B,
+ #     "condensed" for GPT-4o-mini and other small models
+ CONTEXT_MODE = "adaptive"
+
+ # Models that should use condensed context (too weak for long multi-turn)
+ CONDENSED_MODELS = {
+     "gpt-4o-mini",
+ }
+
+ # Early commit threshold — agent may commit if top diagnosis probability exceeds
+ # this AND the gap to #2 exceeds COMMIT_GAP_THRESHOLD
+ COMMIT_CONFIDENCE_THRESHOLD = 0.85
+ COMMIT_GAP_THRESHOLD = 0.30
+
+ # Number of bootstrap resamples for confidence intervals
+ N_BOOTSTRAP = 1000
+
+ # Random seed
+ SEED = 42
+
+ # Cost penalty strength for learned policies.
+ # Utility reward = diagnostic improvement - lambda * normalized_channel_cost
+ COST_PENALTY_LAMBDA = float(os.getenv("COST_PENALTY_LAMBDA", "0.05"))
+
+ # ============================================================
+ # Dataset-Specific Channel Definitions
+ # ============================================================
+
+ MIDAS_CHANNELS = {
+     "patient_demographics": {
+         "description": "Patient age, sex, and Fitzpatrick skin type",
+         "type": "text",
+         "always_given": True,
+         "tier": "free",
+         "cost": 0.0,
+         "order": 0,
+     },
+     "lesion_metadata": {
+         "description": "Anatomic location, lesion length and width",
+         "type": "text",
+         "always_given": True,
+         "tier": "cheap",
+         "cost": 25.0,
+         "order": 1,
+     },
+     "clinical_30cm": {
+         "description": "Clinical photograph taken at 30cm distance",
+         "type": "image",
+         "always_given": False,
+         "tier": "moderate",
+         "cost": 50.0,
+         "order": 2,
+     },
+     "clinical_15cm": {
+         "description": "Clinical photograph taken at 15cm distance (closer view)",
+         "type": "image",
+         "always_given": False,
+         "tier": "moderate",
+         "cost": 50.0,
+         "order": 3,
+     },
+     "dermoscopy": {
+         "description": "Dermoscopic image showing subsurface skin structures",
+         "type": "image",
+         "always_given": False,
+         "tier": "expensive",
+         "cost": 250.0,
+         "order": 4,
+     },
+ }
+
+ NEJM_CHANNELS = {
+     "demographics": {
+         "description": "Patient age, sex, and ethnicity if mentioned",
+         "type": "text",
+         "always_given": True,
+         "tier": "free",
+         "cost": 0.0,
+         "order": 0,
+     },
+     "chief_complaint": {
+         "description": "The presenting symptom(s) and their duration",
+         "type": "text",
+         "always_given": True,
+         "tier": "free",
+         "cost": 0.0,
+         "order": 1,
+     },
+     "medical_history": {
+         "description": "Past medical conditions, medications, family and social history",
+         "type": "text",
+         "always_given": True,
+         "tier": "free",
+         "cost": 0.0,
+         "order": 2,
+     },
+     "exam_findings": {
+         "description": "Physical examination results and observations",
+         "type": "text",
+         "always_given": False,
+         "tier": "cheap",
+         "cost": 75.0,
+         "order": 3,
+     },
+     "investigations": {
+         "description": "Laboratory values, prior imaging results, and test outcomes",
+         "type": "text",
+         "always_given": False,
+         "tier": "moderate",
+         "cost": 250.0,
+         "order": 4,
+     },
+     "image": {
+         "description": "The primary diagnostic image",
+         "type": "image",
+         "always_given": False,
+         "tier": "expensive",
+         "cost": 800.0,
+         "order": 5,
+     },
+ }
+
+ OLIVES_CHANNELS = {
+     "disease_context": {
+         "description": "Disease type and treatment context",
+         "type": "text",
+         "always_given": True,
+         "tier": "free",
+         "cost": 0.0,
+         "order": 0,
+     },
+     "clinical_measurements": {
+         "description": "Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
+         "type": "text",
+         "always_given": False,
+         "tier": "cheap",
+         "cost": 20.0,
+         "order": 1,
+     },
+     "biomarker_hints": {
+         "description": "Expert-graded presence of retinal biomarkers (partial list)",
+         "type": "text",
+         "always_given": False,
+         "tier": "cheap",
209
+ "tier": "moderate",
210
+ "cost": 100.0,
211
+ "order": 2,
212
+ },
213
+ "oct_scan": {
214
+ "description": "Optical Coherence Tomography B-scan showing retinal cross-section",
215
+ "type": "image",
216
+ "always_given": False,
217
+ "tier": "expensive",
218
+ "cost": 300.0,
219
+ "order": 3,
220
+ },
221
+ "additional_oct": {
222
+ "description": "Additional OCT B-scans from different retinal locations",
223
+ "type": "image",
224
+ "always_given": False,
225
+ "tier": "very_expensive",
226
+ "cost": 150.0,
227
+ "order": 4,
228
+ },
229
+ }
230
+
231
+ CHANNEL_CONFIGS = {
232
+ "midas": MIDAS_CHANNELS,
233
+ "nejm": NEJM_CHANNELS,
234
+ "olives": OLIVES_CHANNELS,
235
+ }
236
+
237
+ # ============================================================
238
+ # OLIVES Biomarker Tier Definitions
239
+ # ============================================================
240
+
241
+ OLIVES_BIOMARKER_TIERS = {
242
+ "fundus_visible": [
243
+ "hard_exudates",
244
+ "hemorrhage",
245
+ "microaneurysms",
246
+ "cotton_wool_spots",
247
+ ],
248
+ "oct_dependent": [
249
+ "fluid_irf", # Intraretinal fluid
250
+ "fluid_srf", # Subretinal fluid
251
+ "dril", # Disorganization of retinal inner layers
252
+ "ez_disruption", # Ellipsoid zone disruption
253
+ "ez_absent",
254
+ "drt_me", # Diffuse retinal thickening / macular edema
255
+ "shrm", # Subretinal hyperreflective material
256
+ "full_thickness", # Full thickness involvement
257
+ "preretinal_tissue",
258
+ "vitreous_debris",
259
+ ],
260
+ "clinical_dependent": [
261
+ "drt_me", # Also correlates with CST
262
+ ],
263
+ }
264
+
265
+ # ============================================================
266
+ # Results / Logging
267
+ # ============================================================
268
+ RESULTS_DIR = Path(os.getenv("RESULTS_DIR", "./results"))
269
+ RESULTS_DIR.mkdir(parents=True, exist_ok=True)
270
+
271
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
272
+
273
+
274
+ def get_channel_definition(dataset: str, channel_name: str) -> dict:
275
+ """Return canonical metadata for a dataset channel."""
276
+ return CHANNEL_CONFIGS.get(dataset, {}).get(channel_name, {})
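The commit thresholds and cost penalty above combine into a simple decision rule and reward. Below is a minimal standalone sketch of how they fit together; `should_commit` and `utility_reward` are hypothetical helper names for illustration, not functions in this repository.

```python
# Illustrative combination of the config values above; should_commit and
# utility_reward are hypothetical helpers, not part of the repo's API.
COMMIT_CONFIDENCE_THRESHOLD = 0.85
COMMIT_GAP_THRESHOLD = 0.30
COST_PENALTY_LAMBDA = 0.05


def should_commit(probs: dict[str, float]) -> bool:
    """Commit early when the top diagnosis is confident AND well separated from #2."""
    ranked = sorted(probs.values(), reverse=True)
    top = ranked[0]
    gap = top - (ranked[1] if len(ranked) > 1 else 0.0)
    return top >= COMMIT_CONFIDENCE_THRESHOLD and gap >= COMMIT_GAP_THRESHOLD


def utility_reward(diagnostic_improvement: float, channel_cost: float, max_cost: float) -> float:
    """Reward = diagnostic improvement - lambda * normalized channel cost."""
    normalized = channel_cost / max_cost if max_cost > 0 else 0.0
    return diagnostic_improvement - COST_PENALTY_LAMBDA * normalized
```

Both thresholds must hold for a commit, so a merely confident but crowded posterior (e.g. 0.5 vs. 0.4) keeps the agent acquiring.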
datasets/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .base import MedicalCase, DatasetBase
+ from .midas import MIDASDataset
+ from .nejm import NEJMDataset
+ from .olives import OLIVESDataset
+
+ DATASET_REGISTRY = {
+     "midas": MIDASDataset,
+     "nejm": NEJMDataset,
+     "olives": OLIVESDataset,
+ }
+
+
+ def load_dataset(name: str, **kwargs) -> DatasetBase:
+     """Load a dataset by name."""
+     if name not in DATASET_REGISTRY:
+         raise ValueError(f"Unknown dataset: {name}. Choose from {list(DATASET_REGISTRY.keys())}")
+     return DATASET_REGISTRY[name](**kwargs)
datasets/base.py ADDED
@@ -0,0 +1,146 @@
+ """
+ Abstract base class for medical datasets in the ActiveMedAgent framework.
+
+ Every dataset must expose cases in a unified format:
+ - An initial observation (always-given channels)
+ - A set of requestable channels (additional info the agent can acquire)
+ - A candidate list (diagnoses to rank)
+ - Ground truth (correct ranking)
+ """
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any
+
+
+ @dataclass
+ class ChannelData:
+     """A single information channel's content."""
+     name: str
+     channel_type: str  # "image" or "text"
+     description: str  # Human-readable description of this channel
+     value: Any = None  # Text content (str) or base64-encoded image (str)
+     image_path: Path | None = None  # Original image path if applicable
+     cost: float = 0.0
+     tier: str = "unknown"
+     always_given: bool = False
+
+
+ @dataclass
+ class MedicalCase:
+     """
+     A single diagnostic case in the unified format.
+
+     The agent starts with `initial_channels` and can request from
+     `requestable_channels`. It must produce a ranked list over `candidates`.
+     """
+     case_id: str
+     dataset: str  # "midas", "nejm", "olives"
+     initial_channels: dict[str, ChannelData] = field(default_factory=dict)
+     requestable_channels: dict[str, ChannelData] = field(default_factory=dict)
+     candidates: list[str] = field(default_factory=list)
+     ground_truth: str = ""  # Correct diagnosis label
+     ground_truth_rank: int = 0  # Index in candidates (0-indexed)
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     @property
+     def all_channel_names(self) -> list[str]:
+         return list(self.initial_channels.keys()) + list(self.requestable_channels.keys())
+
+     @property
+     def requestable_names(self) -> list[str]:
+         return list(self.requestable_channels.keys())
+
+     def get_channel(self, name: str) -> ChannelData | None:
+         """Retrieve a channel by name from either initial or requestable."""
+         if name in self.initial_channels:
+             return self.initial_channels[name]
+         if name in self.requestable_channels:
+             return self.requestable_channels[name]
+         return None
+
+     def get_initial_images(self) -> list[str]:
+         """Get base64-encoded images from initial channels."""
+         images = []
+         for ch in self.initial_channels.values():
+             if ch.channel_type == "image" and ch.value is not None:
+                 images.append(ch.value)
+         return images
+
+     def get_all_images_up_to(self, acquired: list[str]) -> list[str]:
+         """Get all images from initial + acquired channels."""
+         images = self.get_initial_images()
+         for name in acquired:
+             ch = self.get_channel(name)
+             if ch and ch.channel_type == "image" and ch.value is not None:
+                 if isinstance(ch.value, list):
+                     images.extend(ch.value)
+                 else:
+                     images.append(ch.value)
+         return images
+
+     def get_text_context(self, acquired: list[str]) -> dict[str, dict]:
+         """Get all text info from initial + acquired channels."""
+         context = {}
+         for name, ch in self.initial_channels.items():
+             if ch.channel_type == "text" and ch.value:
+                 context[name] = {"type": "text", "value": ch.value}
+             elif ch.channel_type == "image":
+                 context[name] = {"type": "image", "value": "(image provided)"}
+         for name in acquired:
+             ch = self.get_channel(name)
+             if ch:
+                 if ch.channel_type == "text" and ch.value:
+                     context[name] = {"type": "text", "value": ch.value}
+                 elif ch.channel_type == "image":
+                     context[name] = {"type": "image", "value": "(image provided)"}
+         return context
+
+     def get_channel_cost(self, name: str) -> float:
+         """Return the configured acquisition cost for a channel."""
+         ch = self.get_channel(name)
+         return float(ch.cost) if ch else 0.0
+
+     def get_initial_cost(self) -> float:
+         """Total cost of channels already available at case start."""
+         return float(sum(ch.cost for ch in self.initial_channels.values()))
+
+     def get_acquisition_cost(self, acquired: list[str]) -> float:
+         """Total incremental cost of acquired requestable channels."""
+         return float(sum(self.get_channel_cost(name) for name in acquired))
+
+     def get_total_cost(self, acquired: list[str]) -> float:
+         """Initial cost plus any additional acquired channels."""
+         return self.get_initial_cost() + self.get_acquisition_cost(acquired)
+
+     def get_max_requestable_cost(self) -> float:
+         """Upper bound if every requestable channel were acquired."""
+         return float(sum(ch.cost for ch in self.requestable_channels.values()))
+
+
+ class DatasetBase(ABC):
+     """Abstract base class for dataset loaders."""
+
+     def __init__(self, data_dir: str | Path, split: str = "test"):
+         self.data_dir = Path(data_dir)
+         self.split = split
+         self.cases: list[MedicalCase] = []
+
+     @abstractmethod
+     def load(self) -> list[MedicalCase]:
+         """Load and return all cases in unified format."""
+         pass
+
+     def __len__(self) -> int:
+         return len(self.cases)
+
+     def __getitem__(self, idx: int) -> MedicalCase:
+         return self.cases[idx]
+
+     def __iter__(self):
+         return iter(self.cases)
+
+     @abstractmethod
+     def get_name(self) -> str:
+         """Return dataset identifier string."""
+         pass
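The cost-accounting methods on `MedicalCase` can be exercised with a toy case. The sketch below uses trimmed copies of the dataclasses (only the fields the cost methods touch) so it runs on its own; the example channels and values are made up.

```python
# Standalone sketch of MedicalCase cost accounting; the dataclasses here are
# trimmed copies of datasets/base.py so the example runs without the repo.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ChannelData:
    name: str
    channel_type: str
    value: Any = None
    cost: float = 0.0
    always_given: bool = False


@dataclass
class MedicalCase:
    case_id: str
    initial_channels: dict = field(default_factory=dict)
    requestable_channels: dict = field(default_factory=dict)

    def get_channel(self, name):
        return self.initial_channels.get(name) or self.requestable_channels.get(name)

    def get_initial_cost(self) -> float:
        return float(sum(ch.cost for ch in self.initial_channels.values()))

    def get_acquisition_cost(self, acquired) -> float:
        return float(sum(self.get_channel(n).cost for n in acquired if self.get_channel(n)))

    def get_total_cost(self, acquired) -> float:
        return self.get_initial_cost() + self.get_acquisition_cost(acquired)


case = MedicalCase(
    case_id="demo",
    initial_channels={"demographics": ChannelData("demographics", "text", "62F", cost=0.0, always_given=True)},
    requestable_channels={
        "dermoscopy": ChannelData("dermoscopy", "image", cost=250.0),
        "clinical_30cm": ChannelData("clinical_30cm", "image", cost=50.0),
    },
)
```

Total cost is the sum of what the case started with plus each acquired channel's configured cost, so the agent's spending is fully determined by its acquisition trace.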
datasets/midas.py ADDED
@@ -0,0 +1,444 @@
+ """
+ MIDAS (MRA-MIDAS) Dataset Loader.
+
+ Actual Stanford AIMI MIDAS dataset structure:
+     midas/
+     ├── images/              (flat directory of all images)
+     │   ├── s-prd-398966407.jpg
+     │   └── ...
+     └── release_midas.xlsx   (metadata with midas_record_id grouping)
+
+ Each record_id groups images of one lesion at multiple modalities:
+     - midas_distance='1ft'    → clinical_30cm
+     - midas_distance='6in'    → clinical_15cm
+     - midas_distance='dscope' → dermoscopy
+
+ Each case becomes a multi-channel acquisition problem:
+     - Initial: patient_demographics and lesion_metadata (always given per config)
+     - Requestable: clinical_30cm, clinical_15cm, dermoscopy
+ """
+ import csv
+ import hashlib
+ import json
+ import logging
+ import random
+ from pathlib import Path
+ from collections import Counter, defaultdict
+
+ from .base import DatasetBase, MedicalCase, ChannelData
+ from api_client import encode_image_to_base64
+ import config
+
+ logger = logging.getLogger(__name__)
+
+ # Map raw midas_distance values to our channel names
+ DISTANCE_TO_CHANNEL = {
+     "1ft": "clinical_30cm",
+     "6in": "clinical_15cm",
+     "dscope": "dermoscopy",
+ }
+
+ # Map raw midas_path values to canonical diagnosis names
+ PATH_TO_DIAGNOSIS = {
+     "malignant- bcc": "basal_cell_carcinoma",
+     "malignant- melanoma": "melanoma_invasive",
+     "malignant- scc": "squamous_cell_carcinoma",
+     "malignant- sccis": "squamous_cell_carcinoma_in_situ",
+     "malignant- ak": "actinic_keratosis",
+     "benign-melanocytic nevus": "melanocytic_nevus",
+     "benign-seborrheic keratosis": "seborrheic_keratosis",
+     "benign-other": "benign_other",
+     "other- melanocytic lesion, possible re-excision (severe, spitz, aimp)": "dysplastic_nevus",
+ }
+
+ # MIDAS diagnosis taxonomy, grouped for candidate generation
+ MIDAS_DIAGNOSIS_GROUPS = {
+     "malignant_melanocytic": [
+         "melanoma_invasive",
+         "melanoma_in_situ",
+     ],
+     "benign_melanocytic": [
+         "melanocytic_nevus",
+         "dysplastic_nevus",
+         "blue_nevus",
+         "spitz_nevus",
+     ],
+     "malignant_nonmelanocytic": [
+         "basal_cell_carcinoma",
+         "squamous_cell_carcinoma",
+         "squamous_cell_carcinoma_in_situ",
+         "actinic_keratosis",
+     ],
+     "benign_nonmelanocytic": [
+         "seborrheic_keratosis",
+         "dermatofibroma",
+         "angioma",
+         "solar_lentigo",
+         "benign_other",
+     ],
+     "inflammatory": [
+         "eczema",
+         "psoriasis",
+         "lichen_planus",
+     ],
+ }
+
+ # Flattened list of all possible diagnoses
+ ALL_DIAGNOSES = []
+ for group in MIDAS_DIAGNOSIS_GROUPS.values():
+     ALL_DIAGNOSES.extend(group)
+
+
+ def _case_rng(case_id: str) -> random.Random:
+     """Create a deterministic RNG seeded by case ID for reproducible candidate generation."""
+     seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
+     return random.Random(seed)
+
+
+ class MIDASDataset(DatasetBase):
+     """Loader for the MRA-MIDAS dermatology dataset."""
+
+     def __init__(self, data_dir: str | Path | None = None, split: str = "test", n_candidates: int = 5):
+         super().__init__(data_dir or config.DATASET_PATHS["midas"], split)
+         self.n_candidates = n_candidates
+
+     def get_name(self) -> str:
+         return "midas"
+
+     def load(self) -> list[MedicalCase]:
+         logger.info(f"Loading MIDAS dataset from {self.data_dir}")
+
+         # ---- Discover metadata file ----
+         metadata_path = self._find_metadata_file()
+         if metadata_path is None:
+             logger.error(f"No metadata file found in {self.data_dir}")
+             return []
+
+         records = self._load_metadata(metadata_path)
+         logger.info(f"Found {len(records)} records in metadata")
+
+         # ---- Group records by lesion (midas_record_id) ----
+         lesion_groups = defaultdict(list)
+         for r in records:
+             rid = r.get("midas_record_id", r.get("lesion_id", ""))
+             if rid:
+                 lesion_groups[str(rid)].append(r)
+
+         logger.info(f"Found {len(lesion_groups)} unique lesions")
+
+         # ---- Build diagnosis distribution for candidate sampling ----
+         all_dx = []
+         for rid, recs in lesion_groups.items():
+             dx = self._get_diagnosis(recs[0])
+             if dx:
+                 all_dx.append(dx)
+         dx_counter = Counter(all_dx)
+
+         # ---- Convert each lesion group to MedicalCase ----
+         self.cases = []
+         for rid, recs in lesion_groups.items():
+             case = self._build_case(rid, recs, dx_counter)
+             if case is not None:
+                 self.cases.append(case)
+
+         logger.info(f"Loaded {len(self.cases)} MIDAS cases")
+         return self.cases
+
+     def _find_metadata_file(self) -> Path | None:
+         """Find the metadata file (xlsx, csv, or json)."""
+         # Try xlsx first (actual MIDAS format)
+         for name in ["release_midas.xlsx", "metadata.xlsx"]:
+             p = self.data_dir / name
+             if p.exists():
+                 return p
+         # Then CSV
+         for name in ["metadata.csv", "labels.csv", "midas_metadata.csv"]:
+             p = self.data_dir / name
+             if p.exists():
+                 return p
+         # Then JSON
+         for name in ["metadata.json", "labels.json"]:
+             p = self.data_dir / name
+             if p.exists():
+                 return p
+         # Glob fallback
+         for pattern in ["*.xlsx", "*.csv"]:
+             matches = list(self.data_dir.glob(pattern))
+             if matches:
+                 return matches[0]
+         return None
+
+     def _load_metadata(self, path: Path) -> list[dict]:
+         """Load metadata from xlsx, csv, or json."""
+         if path.suffix == ".xlsx":
+             return self._load_xlsx(path)
+         elif path.suffix == ".json":
+             with open(path, encoding="utf-8") as f:
+                 return json.load(f)
+         else:
+             with open(path, newline="", encoding="utf-8-sig") as f:
+                 reader = csv.DictReader(f)
+                 return list(reader)
+
+     def _load_xlsx(self, path: Path) -> list[dict]:
+         """Load metadata from an Excel file."""
+         import openpyxl
+         wb = openpyxl.load_workbook(path, read_only=True)
+         ws = wb[wb.sheetnames[0]]
+         rows = list(ws.iter_rows(values_only=True))
+         wb.close()
+
+         if not rows:
+             return []
+         headers = [str(h) if h is not None else f"col_{i}" for i, h in enumerate(rows[0])]
+         return [dict(zip(headers, row)) for row in rows[1:]]
+
+     def _get_diagnosis(self, record: dict) -> str | None:
+         """Extract canonical diagnosis from a record."""
+         raw_path = record.get("midas_path", record.get("diagnosis", ""))
+         if raw_path is None or not str(raw_path).strip():
+             return None
+         raw_path = str(raw_path).strip().lower()
+         if raw_path == "none":
+             return None
+         # Try exact match in mapping
+         for key, canonical in PATH_TO_DIAGNOSIS.items():
+             if key.lower() == raw_path:
+                 return canonical
+         # Fuzzy fallback; check specific terms before generic ones so that
+         # e.g. "actinic keratosis" is not swallowed by the "keratosis" match
+         if "melanoma" in raw_path:
+             return "melanoma_invasive"
+         if "bcc" in raw_path or "basal" in raw_path:
+             return "basal_cell_carcinoma"
+         if "sccis" in raw_path:
+             return "squamous_cell_carcinoma_in_situ"
+         if "scc" in raw_path or "squamous" in raw_path:
+             return "squamous_cell_carcinoma"
+         if "actinic" in raw_path or "ak" in raw_path.split():
+             return "actinic_keratosis"
+         if "nevus" in raw_path or "melanocytic" in raw_path:
+             return "melanocytic_nevus"
+         if "seborrheic" in raw_path or "keratosis" in raw_path:
+             return "seborrheic_keratosis"
+         if "benign" in raw_path:
+             return "benign_other"
+         return None
+
+     def _find_image_by_filename(self, filename: str) -> Path | None:
+         """Find an image by its filename in the images directory."""
+         if not filename:
+             return None
+         # Try images/ subdir, then root, case-insensitive
+         search_dirs = [
+             self.data_dir / "images",
+             self.data_dir,
+         ]
+         for d in search_dirs:
+             if not d.exists():
+                 continue
+             p = d / filename
+             if p.exists():
+                 return p
+             # Case-insensitive search
+             for ext_p in d.iterdir():
+                 if ext_p.name.lower() == filename.lower():
+                     return ext_p
+         return None
+
+     def _build_case(
+         self,
+         record_id: str,
+         records: list[dict],
+         dx_counter: Counter,
+     ) -> MedicalCase | None:
+         """Convert a lesion's grouped records into a MedicalCase."""
+         # Use the first non-control record with a diagnosis for metadata
+         primary = None
+         for r in records:
+             if str(r.get("midas_iscontrol", "no")).lower() != "yes":
+                 dx = self._get_diagnosis(r)
+                 if dx:
+                     primary = r
+                     break
+         if primary is None:
+             return None  # Skip control-only lesions
+
+         diagnosis = self._get_diagnosis(primary)
+         if not diagnosis:
+             return None
+
+         # ---- Build channels from all records in this lesion group ----
+         all_channels = {}
+
+         # Group images by modality
+         descriptions = {
+             "clinical_30cm": "Clinical photograph at 30cm distance",
+             "clinical_15cm": "Clinical photograph at 15cm distance (closer view)",
+             "dermoscopy": "Dermoscopic image showing subsurface skin structures",
+         }
+         for r in records:
+             if str(r.get("midas_iscontrol", "no")).lower() == "yes":
+                 continue
+             distance = str(r.get("midas_distance", "")).strip().lower()
+             channel_name = DISTANCE_TO_CHANNEL.get(distance)
+             if not channel_name:
+                 continue
+             if channel_name in all_channels:
+                 continue  # Already have this modality
+
+             filename = r.get("midas_file_name", "")
+             img_path = self._find_image_by_filename(filename)
+             if img_path is None:
+                 continue
+
+             try:
+                 img_b64 = encode_image_to_base64(img_path)
+             except Exception:
+                 continue
+
+             ch_meta = config.get_channel_definition("midas", channel_name)
+             all_channels[channel_name] = ChannelData(
+                 name=channel_name,
+                 channel_type="image",
+                 description=descriptions.get(channel_name, channel_name),
+                 value=img_b64,
+                 image_path=img_path,
+                 cost=float(ch_meta.get("cost", 0.0)),
+                 tier=ch_meta.get("tier", "unknown"),
+                 always_given=bool(ch_meta.get("always_given", False)),
+             )
+
+         # Patient demographics
+         age = primary.get("midas_age", primary.get("age", ""))
+         sex = primary.get("midas_gender", primary.get("sex", ""))
+         fitz = primary.get("midas_fitzpatrick", primary.get("fitzpatrick", ""))
+         ethnicity = primary.get("midas_ethnicity", "")
+         race = primary.get("midas_race", "")
+         if any([age, sex, fitz]):
+             demo_parts = []
+             if age:
+                 demo_parts.append(f"Age: {age}")
+             if sex:
+                 demo_parts.append(f"Sex: {sex}")
+             if fitz:
+                 demo_parts.append(f"Fitzpatrick skin type: {fitz}")
+             if ethnicity and str(ethnicity).lower() not in ("no", "none", ""):
+                 demo_parts.append(f"Ethnicity: {ethnicity}")
+             if race and str(race).lower() not in ("no", "none", ""):
+                 demo_parts.append(f"Race: {race}")
+             ch_meta = config.get_channel_definition("midas", "patient_demographics")
+             all_channels["patient_demographics"] = ChannelData(
+                 name="patient_demographics",
+                 channel_type="text",
+                 description="Patient age, sex, and Fitzpatrick skin type",
+                 value="; ".join(demo_parts),
+                 cost=float(ch_meta.get("cost", 0.0)),
+                 tier=ch_meta.get("tier", "unknown"),
+                 always_given=bool(ch_meta.get("always_given", False)),
+             )
+
+         # Lesion metadata
+         location = primary.get("midas_location", primary.get("location", ""))
+         length = primary.get("length_(mm)", primary.get("length_mm", ""))
+         width = primary.get("width_(mm)", primary.get("width_mm", ""))
+         if any([location, length, width]):
+             meta_parts = []
+             if location:
+                 meta_parts.append(f"Anatomic location: {location}")
+             if length:
+                 meta_parts.append(f"Lesion length: {length}mm")
+             if width:
+                 meta_parts.append(f"Lesion width: {width}mm")
+             ch_meta = config.get_channel_definition("midas", "lesion_metadata")
+             all_channels["lesion_metadata"] = ChannelData(
+                 name="lesion_metadata",
+                 channel_type="text",
+                 description="Anatomic location, lesion length and width",
+                 value="; ".join(meta_parts),
+                 cost=float(ch_meta.get("cost", 0.0)),
+                 tier=ch_meta.get("tier", "unknown"),
+                 always_given=bool(ch_meta.get("always_given", False)),
+             )
+
+         if not all_channels:
+             return None
+
+         initial_channels = {
+             name: ch for name, ch in all_channels.items() if ch.always_given
+         }
+         requestable = {
+             name: ch for name, ch in all_channels.items() if not ch.always_given
+         }
+
+         if not initial_channels and not requestable:
+             return None
+
+         # ---- Build candidate list (correct + distractors) ----
+         case_id = f"midas_{record_id}"
+         candidates = self._generate_candidates(diagnosis, dx_counter, case_id)
+
+         if diagnosis not in candidates:
+             logger.warning(f"Ground truth '{diagnosis}' not in candidate list for {case_id}, forcing inclusion")
+             candidates[0] = diagnosis
+             rng = _case_rng(case_id)
+             rng.shuffle(candidates)
+
+         return MedicalCase(
+             case_id=case_id,
+             dataset="midas",
+             initial_channels=initial_channels,
+             requestable_channels=requestable,
+             candidates=candidates,
+             ground_truth=diagnosis,
+             ground_truth_rank=candidates.index(diagnosis),
+             metadata={
+                 "lesion_id": record_id,
+                 "original_record": {k: str(v) for k, v in primary.items()
+                                     if k not in ("image", "img")},
+             },
+         )
+
+     def _generate_candidates(self, correct_dx: str, dx_counter: Counter, case_id: str) -> list[str]:
+         """
+         Generate N candidate diagnoses: 1 correct + (N-1) distractors.
+
+         Uses a per-case deterministic RNG for reproducibility across conditions.
+         Distractors are sampled to be clinically plausible:
+         - At least one from the same diagnostic group
+         - Others from different groups, weighted by dataset frequency
+         """
+         n = self.n_candidates
+         rng = _case_rng(case_id)
+
+         # Find which group the correct dx belongs to
+         correct_group = None
+         for group_name, members in MIDAS_DIAGNOSIS_GROUPS.items():
+             if correct_dx in members:
+                 correct_group = group_name
+                 break
+
+         distractors = set()
+
+         # Add one same-group distractor if possible
+         if correct_group:
+             same_group = [d for d in MIDAS_DIAGNOSIS_GROUPS[correct_group] if d != correct_dx]
+             if same_group:
+                 distractors.add(rng.choice(same_group))
+
+         # Fill the rest from other groups, weighted by dataset frequency
+         other_dx = [d for d in ALL_DIAGNOSES if d != correct_dx and d not in distractors]
+         weights = [dx_counter.get(d, 1) for d in other_dx]
+         total_w = sum(weights)
+         weights = [w / total_w for w in weights]
+
+         while len(distractors) < n - 1 and other_dx:
+             pick = rng.choices(other_dx, weights=weights, k=1)[0]
+             distractors.add(pick)
+             idx = other_dx.index(pick)
+             other_dx.pop(idx)
+             weights.pop(idx)
+             if weights:
+                 total_w = sum(weights)
+                 weights = [w / total_w for w in weights]
+
+         candidates = [correct_dx] + list(distractors)
+         rng.shuffle(candidates)
+         return candidates[:n]
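The per-case deterministic RNG is what makes candidate lists reproducible across runs and experimental conditions: seeding from a hash of the case ID gives an independent, fixed stream per case. A standalone check of that property (re-creating `_case_rng` so it runs without the repo):

```python
# Standalone sketch of the per-case deterministic RNG: the same case_id
# always produces the same shuffle, so candidate order is reproducible.
import hashlib
import random


def case_rng(case_id: str) -> random.Random:
    # Seed from the first 8 hex digits of SHA-256(case_id), as in midas.py
    seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
    return random.Random(seed)


candidates = ["melanoma_invasive", "melanocytic_nevus", "seborrheic_keratosis",
              "basal_cell_carcinoma", "dysplastic_nevus"]

a = candidates[:]
case_rng("midas_42").shuffle(a)
b = candidates[:]
case_rng("midas_42").shuffle(b)
```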
datasets/nejm.py ADDED
@@ -0,0 +1,440 @@
+ """
+ NEJM Image Challenge Dataset Loader.
+
+ Expects the cx0/nejm-image-challenge dataset structure:
+     nejm/
+     ├── data.json              (or nejm_data.json)
+     │     Each entry: {date, image_url, prompt (clinical vignette),
+     │                  options [A..E], correct_answer, votes}
+     ├── images/                (downloaded images, named by date YYYYMMDD.jpg)
+     └── parsed_vignettes.json  (pre-parsed structured fields, optional)
+
+ The clinical vignette is decomposed into five structured text channels
+ (demographics, chief_complaint, and medical_history are always given;
+ exam_findings and investigations are requestable) using LLM-based
+ parsing (see scripts/parse_nejm_vignettes.py).
+ """
+ import json
+ import logging
+ import random
+ import re
+ from pathlib import Path
+
+ from .base import DatasetBase, MedicalCase, ChannelData
+ from api_client import encode_image_to_base64
+ import config
+
+ logger = logging.getLogger(__name__)
+
+ # ---- Vignette parsing schema ----
+ VIGNETTE_FIELDS = [
+     "demographics",
+     "chief_complaint",
+     "medical_history",
+     "exam_findings",
+     "investigations",
+ ]
+
+ VIGNETTE_PARSE_PROMPT = """You are a medical data extraction system. Parse the following clinical \
+ vignette into exactly 5 structured fields. Extract ONLY information that is explicitly stated. \
+ If a field has no relevant information, write "Not mentioned."
+
+ FIELDS:
+ 1. demographics: Patient age, sex, race/ethnicity if stated.
+ 2. chief_complaint: The primary presenting symptom(s) and their duration.
+ 3. medical_history: Past medical conditions, medications, surgical history, family history, social history (smoking, alcohol, etc.).
+ 4. exam_findings: Physical examination findings, vital signs.
+ 5. investigations: Laboratory results, imaging findings, test results (anything with numbers or test names).
+
+ CLINICAL VIGNETTE:
+ {vignette}
+
+ Respond in EXACTLY this JSON format (no markdown, no extra text):
+ {{"demographics": "...", "chief_complaint": "...", "medical_history": "...", "exam_findings": "...", "investigations": "..."}}"""
+
+
+ class NEJMDataset(DatasetBase):
+     """Loader for the NEJM Image Challenge dataset."""
+
+     def __init__(
+         self,
+         data_dir: str | Path | None = None,
+         split: str = "test",
+         vlm_client=None,
+         use_cached_parse: bool = True,
+     ):
+         super().__init__(data_dir or config.DATASET_PATHS["nejm"], split)
+         self.vlm_client = vlm_client
+         self.use_cached_parse = use_cached_parse
+         self._parsed_cache_path = self.data_dir / "parsed_vignettes.json"
+
+     def get_name(self) -> str:
+         return "nejm"
+
+     def load(self) -> list[MedicalCase]:
+         logger.info(f"Loading NEJM dataset from {self.data_dir}")
+
+         # ---- Load raw data ----
+         raw_data = self._load_raw_data()
+         if not raw_data:
+             return []
+         logger.info(f"Found {len(raw_data)} NEJM cases")
+
+         # ---- Load or create parsed vignettes ----
+         parsed = self._load_or_parse_vignettes(raw_data)
+
+         # ---- Build cases ----
+         self.cases = []
+         for entry in raw_data:
+             case_id = entry.get("date", entry.get("id", "unknown"))
+             case = self._build_case(entry, parsed.get(case_id, {}))
+             if case is not None:
+                 self.cases.append(case)
+
+         logger.info(f"Loaded {len(self.cases)} NEJM cases")
+         return self.cases
+
+     def _load_raw_data(self) -> list[dict]:
+         """Load the raw NEJM dataset JSON."""
+         for name in ["data.json", "nejm_data.json", "nejm.json", "dataset.json"]:
+             p = self.data_dir / name
+             if p.exists():
+                 with open(p, encoding="utf-8") as f:
+                     data = json.load(f)
+                 if isinstance(data, dict):
+                     # Handle {date: entry} format
+                     return [{"date": k, **v} if isinstance(v, dict) else v
+                             for k, v in data.items()]
+                 return data
+         # Fall back to the first JSON file found
+         jsons = list(self.data_dir.glob("*.json"))
+         if jsons:
+             with open(jsons[0], encoding="utf-8") as f:
+                 return json.load(f)
+         logger.error(f"No data file found in {self.data_dir}")
+         return []
+
+     def _load_or_parse_vignettes(self, raw_data: list[dict]) -> dict:
+         """Load cached parsed vignettes or parse them with the LLM."""
+         # Try the cache first
+         if self.use_cached_parse and self._parsed_cache_path.exists():
+             logger.info(f"Loading cached vignette parses from {self._parsed_cache_path}")
+             with open(self._parsed_cache_path, encoding="utf-8") as f:
+                 return json.load(f)
+
+         # Parse with the LLM if a client is available
+         if self.vlm_client is not None:
+             logger.info("Parsing vignettes with LLM (this may take a while)...")
+             parsed = {}
+             for entry in raw_data:
+                 case_id = entry.get("date", entry.get("id", "unknown"))
+                 vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
+                 if vignette:
+                     parsed[case_id] = self._parse_vignette_with_llm(vignette)
+             # Cache results
+             with open(self._parsed_cache_path, "w", encoding="utf-8") as f:
+                 json.dump(parsed, f, indent=2)
+             logger.info(f"Cached {len(parsed)} parsed vignettes")
+             return parsed
+
+         # Fallback: rule-based parsing
+         logger.info("No LLM client available. Using rule-based vignette parsing (less accurate).")
+         parsed = {}
+         for entry in raw_data:
+             case_id = entry.get("date", entry.get("id", "unknown"))
+             vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
+             if vignette:
+                 parsed[case_id] = self._parse_vignette_rules(vignette)
+         return parsed
+
+     def _parse_vignette_with_llm(self, vignette: str) -> dict:
+         """Parse a single vignette using the LLM API."""
+         prompt = VIGNETTE_PARSE_PROMPT.format(vignette=vignette)
+         try:
+             response = self.vlm_client.call_with_retry(
+                 system_prompt="You are a medical data extraction system. Respond only with valid JSON.",
+                 user_text=prompt,
+                 images=None,
+                 temperature=0.0,
+                 max_tokens=1024,
+             )
+             # Parse JSON from the response
+             text = response.text.strip()
+             # Strip markdown code fences if present
+             text = re.sub(r"^```(?:json)?\s*", "", text)
+             text = re.sub(r"\s*```$", "", text)
+             parsed = json.loads(text)
+             # Validate expected fields
166
+ for field in VIGNETTE_FIELDS:
167
+ if field not in parsed:
168
+ parsed[field] = "Not mentioned."
169
+ return parsed
170
+ except Exception as e:
171
+ logger.warning(f"LLM vignette parsing failed: {e}. Falling back to rules.")
172
+ return self._parse_vignette_rules(vignette)
173
+
174
+ def _parse_vignette_rules(self, vignette: str) -> dict:
175
+ """
176
+ Rule-based fallback for vignette parsing.
177
+ Uses heuristic sentence classification.
178
+ """
179
+ result = {f: "" for f in VIGNETTE_FIELDS}
180
+ sentences = re.split(r'(?<=[.!?])\s+', vignette)
181
+
182
+ # Patterns for classification
183
+ demo_pattern = re.compile(
184
+ r'\b(\d{1,3})[-\s]year[-\s]old\b|'
185
+ r'\b(male|female|man|woman|boy|girl)\b',
186
+ re.IGNORECASE,
187
+ )
188
+ complaint_pattern = re.compile(
189
+ r'\bpresent(?:s|ed|ing)\b|\bcomplain(?:s|ed|ing)\b|\breport(?:s|ed|ing)\b|'
190
+ r'\bseek(?:s|ing)\b|\badmitted\b',
191
+ re.IGNORECASE,
192
+ )
193
+ history_pattern = re.compile(
194
+ r'\bhistory\b|\bprevious(?:ly)?\b|\bmedication\b|\btaking\b|\bdiagnosed\b|'
195
+ r'\bsmok(?:es|ing|er)\b|\balcohol\b|\bfamily\b|\bsurgery\b',
196
+ re.IGNORECASE,
197
+ )
198
+ exam_pattern = re.compile(
199
+ r'\bexamination\b|\bexam\b|\bpalpat(?:ion|ed)\b|\bauscult(?:ation|ed)\b|'
200
+ r'\bvital\b|\bblood\s+pressure\b|\bheart\s+rate\b|\btemperature\b|'
201
+ r'\bappears\b|\btender\b|\bswollen\b|\berythema\b',
202
+ re.IGNORECASE,
203
+ )
204
+ invest_pattern = re.compile(
205
+ r'\b(?:hemoglobin|WBC|platelet|creatinine|BUN|glucose|sodium|potassium)\b|'
206
+ r'\b(?:CT|MRI|X[-\s]?ray|ultrasound|ECG|EKG|biopsy)\b|'
207
+ r'\b\d+\.?\d*\s*(?:mg|g|mL|mmol|mEq|U|IU|mmHg|\/dL|\/L)\b|'
208
+ r'\blaboratory\b|\blab(?:s)?\b|\btest\b|\blevel\b|\bfinding\b',
209
+ re.IGNORECASE,
210
+ )
211
+
212
+ for sent in sentences:
213
+ sent = sent.strip()
214
+ if not sent:
215
+ continue
216
+
217
+ # Demographics: typically the first sentence
218
+ if demo_pattern.search(sent) and not result["demographics"]:
219
+ result["demographics"] = sent
220
+ continue
221
+
222
+ # Check each pattern (a sentence can match multiple, take first)
223
+ matched = False
224
+ for field, pattern in [
225
+ ("investigations", invest_pattern),
226
+ ("exam_findings", exam_pattern),
227
+ ("medical_history", history_pattern),
228
+ ("chief_complaint", complaint_pattern),
229
+ ]:
230
+ if pattern.search(sent):
231
+ if result[field]:
232
+ result[field] += " " + sent
233
+ else:
234
+ result[field] = sent
235
+ matched = True
236
+ break
237
+
238
+ # Unmatched sentences go to chief_complaint as default
239
+ if not matched:
240
+ if result["chief_complaint"]:
241
+ result["chief_complaint"] += " " + sent
242
+ else:
243
+ result["chief_complaint"] = sent
244
+
245
+ # Replace empty fields
246
+ for field in VIGNETTE_FIELDS:
247
+ if not result[field].strip():
248
+ result[field] = "Not mentioned."
249
+
250
+ return result
251
+
252
+ @staticmethod
253
+ def _date_to_yyyymmdd(date_str: str) -> str | None:
254
+ """Convert 'apr-01-2010' style date to '20100401' for image lookup."""
255
+ from datetime import datetime
256
+ for fmt in ("%b-%d-%Y", "%B-%d-%Y", "%Y-%m-%d", "%Y%m%d"):
257
+ try:
258
+ dt = datetime.strptime(date_str, fmt)
259
+ return dt.strftime("%Y%m%d")
260
+ except ValueError:
261
+ continue
262
+ return None
263
+
+    def _build_case(self, entry: dict, parsed_vignette: dict) -> MedicalCase | None:
+        """Convert a raw NEJM entry + parsed vignette into a MedicalCase."""
+        case_id = entry.get("date", entry.get("id", "unknown"))
+
+        # ---- Find image ----
+        img_b64 = None
+        img_dir = self.data_dir / "images"
+        # Build candidate filenames: original case_id + YYYYMMDD conversion
+        name_candidates = [case_id]
+        yyyymmdd = self._date_to_yyyymmdd(case_id)
+        if yyyymmdd:
+            name_candidates.append(yyyymmdd)
+
+        if img_dir.exists():
+            for name in name_candidates:
+                for ext in [".jpg", ".jpeg", ".png"]:
+                    p = img_dir / f"{name}{ext}"
+                    if p.exists():
+                        try:
+                            img_b64 = encode_image_to_base64(p)
+                        except Exception:
+                            pass
+                        break
+                if img_b64 is not None:
+                    break
+            if img_b64 is None:
+                # Glob for any match
+                for name in name_candidates:
+                    matches = list(img_dir.glob(f"*{name}*"))
+                    if matches:
+                        try:
+                            img_b64 = encode_image_to_base64(matches[0])
+                        except Exception:
+                            pass
+                        break
+
+        # ---- Build all available channels, then split by config ----
+        all_channels = {}
+        if img_b64 is not None:
+            image_meta = config.get_channel_definition("nejm", "image")
+            all_channels["image"] = ChannelData(
+                name="image",
+                channel_type="image",
+                description="The primary diagnostic image",
+                value=img_b64,
+                cost=float(image_meta.get("cost", 0.0)),
+                tier=image_meta.get("tier", "unknown"),
+                always_given=bool(image_meta.get("always_given", False)),
+            )
+
+        field_descriptions = {
+            "demographics": "Patient age, sex, and ethnicity if mentioned",
+            "chief_complaint": "The presenting symptom(s) and their duration",
+            "medical_history": "Past medical conditions, medications, family and social history",
+            "exam_findings": "Physical examination results and observations",
+            "investigations": "Laboratory values, prior imaging results, and test outcomes",
+        }
+
+        for field in VIGNETTE_FIELDS:
+            value = parsed_vignette.get(field, "Not mentioned.")
+            field_meta = config.get_channel_definition("nejm", field)
+            # Use the parsed text if present; otherwise an explicit placeholder
+            text_value = (
+                value
+                if value and value.strip() != "Not mentioned."
+                else "No additional information available for this category."
+            )
+            all_channels[field] = ChannelData(
+                name=field,
+                channel_type="text",
+                description=field_descriptions.get(field, field),
+                value=text_value,
+                cost=float(field_meta.get("cost", 0.0)),
+                tier=field_meta.get("tier", "unknown"),
+                always_given=bool(field_meta.get("always_given", False)),
+            )
+
+        initial_channels = {
+            name: ch for name, ch in all_channels.items() if ch.always_given
+        }
+        requestable = {
+            name: ch for name, ch in all_channels.items() if not ch.always_given
+        }
+
+        if not initial_channels and not requestable:
+            logger.debug(f"Skipping NEJM {case_id}: no usable channels found")
+            return None
+
+        # ---- Candidates: the 5 MCQ options ----
+        options = entry.get("options", [])
+        correct = entry.get("correct_answer", entry.get("answer", ""))
+
+        # Handle flat option_A..option_E keys (cx0/nejm-image-challenge format)
+        if not options:
+            flat_options = {}
+            for letter in "ABCDE":
+                val = entry.get(f"option_{letter}", "")
+                if val:
+                    flat_options[letter] = val
+            if flat_options:
+                options = flat_options
+
+        if isinstance(options, dict):
+            # {A: "...", B: "...", ...}
+            candidates = [f"{k}. {v}" for k, v in sorted(options.items())]
+            gt_label = None
+            for k, v in sorted(options.items()):
+                if k == correct:
+                    gt_label = f"{k}. {v}"
+                    break
+            if gt_label is None:
+                gt_label = candidates[0] if candidates else ""
+        elif isinstance(options, list) and options:
+            candidates = options
+            if isinstance(correct, int):
+                gt_label = options[correct] if correct < len(options) else options[0]
+            elif isinstance(correct, str) and len(correct) == 1:
+                # Letter answer (A=0, B=1, ...)
+                idx = ord(correct.upper()) - ord("A")
+                gt_label = options[idx] if idx < len(options) else options[0]
+            else:
+                gt_label = correct
+        else:
+            candidates = [correct] if correct else ["Unknown"]
+            gt_label = correct
+
+        # ---- Votes (physician response distribution) ----
+        votes = entry.get("votes", {})
+        # Handle flat vote keys (option_A_votes, etc.)
+        if not votes:
+            for letter in "ABCDE":
+                val = entry.get(f"option_{letter}_votes", "")
+                if val:
+                    votes[letter] = val
+
+        return MedicalCase(
+            case_id=f"nejm_{case_id}",
+            dataset="nejm",
+            initial_channels=initial_channels,
+            requestable_channels=requestable,
+            candidates=candidates,
+            ground_truth=gt_label,
+            ground_truth_rank=(candidates.index(gt_label) if gt_label in candidates else 0),
+            metadata={
+                "date": case_id,
+                "votes": votes,
+                "full_vignette": entry.get("question", entry.get("prompt", entry.get("vignette", ""))),
+                "parsed_fields": parsed_vignette,
+            },
+        )
+
+    def get_human_difficulty(self, case: MedicalCase) -> float | None:
+        """
+        Compute human difficulty score from physician vote distribution.
+
+        Returns: proportion of physicians who answered correctly (0-1),
+        or None if votes unavailable.
+        """
+        votes = case.metadata.get("votes", {})
+        if not votes:
+            return None
+        # votes might be {A: 0.12, B: 0.65, ...} or {A: 120, B: 650, ...}
+        total = sum(float(v) for v in votes.values())
+        if total == 0:
+            return None
+        # Find the correct answer key by its option-letter prefix; a plain
+        # substring test would false-positive on letters inside words.
+        gt = case.ground_truth
+        for key, val in votes.items():
+            if gt == key or gt.startswith(f"{key}.") or gt.startswith(f"{key} "):
+                return float(val) / total if total > 1 else float(val)
+        return None
datasets/olives.py ADDED
@@ -0,0 +1,470 @@
+"""
+OLIVES Dataset Loader.
+
+Adapted for the actual Zenodo OLIVES dataset structure:
+    data/
+    ├── OLIVES/OLIVES/
+    │   ├── Prime_FULL/Prime_FULL/      (DR patients — OCT B-scans)
+    │   │   └── <patient_id>/<visit>/<eye>/*.png
+    │   └── TREX_DME/TREX DME/          (DME patients — OCT B-scans)
+    │       └── <arm>/<patient_id>/<visit>/<eye>/*.tif
+    └── OLIVES_Dataset_Labels/OLIVES_Dataset_Labels/
+        └── full_labels/Biomarker_Clinical_Data_Images.csv
+
+Task: Biomarker profile ranking.
+    - Given an OCT B-scan, rank candidate biomarker profiles
+    - Each profile is a subset of the 16 annotated biomarkers
+    - Correct profile = actual biomarker vector for this eye
+    - Distractors = profiles from other eyes
+
+Channels:
+    - Initial: single OCT B-scan (middle slice)
+    - Requestable: additional OCT slices, clinical measurements (BCVA/CST),
+      biomarker hints (fundus-visible subset), treatment history
+"""
+import csv
+import hashlib
+import logging
+import random
+from pathlib import Path
+from collections import defaultdict
+
+import numpy as np
+
+from .base import DatasetBase, MedicalCase, ChannelData
+from api_client import encode_image_to_base64
+import config
+
+logger = logging.getLogger(__name__)
+
+# The biomarker columns as they appear in the CSV
+OLIVES_CSV_BIOMARKERS = {
+    "Fluid (IRF)": "fluid_irf",
+    "Fluid (SRF)": "fluid_srf",
+    "DRT/ME": "drt_me",
+    "SHRM": "shrm",
+    "Preretinal tissue/hemorrhage": "preretinal_tissue",
+    "Vitreous debris": "vitreous_debris",
+    "DRIL": "dril",
+    "Disruption of EZ": "ez_disruption",
+    "IR hemorrhages": "hemorrhage",
+    "IR HRF": "ir_hrf",
+    "Disruption of RPE": "rpe_disruption",
+    "PED (serous)": "ped_serous",
+    "Atrophy / thinning of retinal layers": "atrophy",
+    "VMT": "vmt",
+    "Partially attached vitreous face": "partial_vitreous",
+    "Fully attached vitreous face": "full_vitreous",
+}
+
+# Canonical biomarker names for profiles
+OLIVES_BIOMARKERS = sorted(OLIVES_CSV_BIOMARKERS.values())
+
+
+def biomarker_vector_to_profile_string(vector: dict[str, bool]) -> str:
+    """Convert a biomarker dict to a human-readable profile string."""
+    present = [
+        name.replace("_", " ").title()
+        for name, val in sorted(vector.items()) if val
+    ]
+    if not present:
+        return "No biomarkers detected"
+    return "Present biomarkers: " + ", ".join(present)
+
+
+def compute_profile_distance(profile_a: dict, profile_b: dict) -> int:
+    """Hamming distance between two biomarker profiles."""
+    dist = 0
+    for key in OLIVES_BIOMARKERS:
+        if profile_a.get(key, False) != profile_b.get(key, False):
+            dist += 1
+    return dist
+
+
+def _case_rng(case_id: str) -> random.Random:
+    seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
+    return random.Random(seed)
+
+
+class OLIVESDataset(DatasetBase):
+    """Loader for OLIVES ophthalmology dataset."""
+
+    def __init__(
+        self,
+        data_dir: str | Path = None,
+        split: str = "test",
+        n_candidates: int = 5,
+        n_oct_samples: int = 3,
+    ):
+        super().__init__(data_dir or config.DATASET_PATHS["olives"], split)
+        self.n_candidates = n_candidates
+        self.n_oct_samples = n_oct_samples
+
+    def get_name(self) -> str:
+        return "olives"
+
+    def load(self) -> list[MedicalCase]:
+        logger.info(f"Loading OLIVES dataset from {self.data_dir}")
+
+        # ---- Find the CSV ----
+        csv_path = self._find_csv()
+        if csv_path is None:
+            logger.error("No biomarker CSV found")
+            return []
+
+        # ---- Load records ----
+        with open(csv_path, newline="", encoding="utf-8-sig") as f:
+            rows = list(csv.DictReader(f))
+        logger.info(f"Found {len(rows)} records in {csv_path.name}")
+
+        # ---- Find the image root ----
+        image_root = self._find_image_root()
+        if image_root is None:
+            logger.error("No image directory found")
+            return []
+        logger.info(f"Image root: {image_root}")
+
+        # ---- Group by eye ----
+        eye_groups = defaultdict(list)
+        for r in rows:
+            pid = r.get("Patient_ID", "")
+            path_str = r.get(
+                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
+            )
+            parts = path_str.strip("/").split("/")
+            if len(parts) >= 5:
+                eye = parts[4]  # OD or OS
+            else:
+                eye = r.get("Eye_ID", "unknown")
+            eye_key = f"{pid}_{eye}"
+            r["_eye_key"] = eye_key
+            r["_path_parts"] = parts
+            eye_groups[eye_key].append(r)
+
+        logger.info(f"Found {len(eye_groups)} unique eyes")
+
+        # ---- Build biomarker profiles ----
+        all_profiles = {}
+        for eye_key, records in eye_groups.items():
+            latest = records[-1]
+            all_profiles[eye_key] = self._extract_biomarker_vector(latest)
+
+        # ---- Build cases ----
+        self.cases = []
+        for eye_key, records in eye_groups.items():
+            case = self._build_case(
+                eye_key, records, all_profiles, image_root
+            )
+            if case is not None:
+                self.cases.append(case)
+
+        logger.info(f"Loaded {len(self.cases)} OLIVES cases")
+        return self.cases
+
+    def _find_csv(self) -> Path | None:
+        """Find the biomarker CSV in various locations."""
+        search_paths = [
+            self.data_dir / "Biomarker_Clinical_Data_Images.csv",
+            self.data_dir / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
+            self.data_dir.parent / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
+        ]
+        for p in search_paths:
+            if p.exists():
+                return p
+        # Glob fallback
+        csvs = list(self.data_dir.rglob("Biomarker*Clinical*.csv"))
+        if csvs:
+            return csvs[0]
+        # Check parent
+        csvs = list(self.data_dir.parent.rglob("Biomarker*Clinical*.csv"))
+        if csvs:
+            return csvs[0]
+        return None
+
+    def _find_image_root(self) -> Path | None:
+        """Find the root directory containing Prime_FULL and TREX_DME."""
+        search = [
+            self.data_dir / "OLIVES",
+            self.data_dir / "OLIVES" / "OLIVES",
+            self.data_dir,
+        ]
+        for d in search:
+            if (d / "Prime_FULL").exists() or (d / "TREX_DME").exists():
+                return d
+        # Search deeper
+        for p in self.data_dir.rglob("Prime_FULL"):
+            return p.parent
+        return None
+
+    def _extract_biomarker_vector(self, record: dict) -> dict[str, bool]:
+        """Extract biomarker vector from a CSV row."""
+        vector = {}
+        for csv_col, canonical_name in OLIVES_CSV_BIOMARKERS.items():
+            val = record.get(csv_col, "0")
+            if isinstance(val, str):
+                vector[canonical_name] = val.strip() == "1"
+            else:
+                vector[canonical_name] = bool(int(float(val or 0)))
+        return vector
+
+    def _find_oct_images(
+        self, records: list[dict], image_root: Path, n: int = 3
+    ) -> list[Path]:
+        """Find OCT B-scan images for an eye."""
+        # Try to locate images from the path in the CSV
+        for r in records:
+            path_str = r.get(
+                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
+            )
+            parts = path_str.strip("/").split("/")
+            if len(parts) < 5:
+                continue
+
+            # Construct search directory (without the image filename)
+            # Path format: /Trial/Arm/Patient/Visit/Eye/Image
+            trial = parts[0]
+            remaining = "/".join(parts[1:-1])
+
+            search_dirs = [
+                image_root / trial / remaining,
+                image_root / parts[0].replace(" ", "_") / remaining,
+            ]
+
+            # For Prime: Prime_FULL/Prime_FULL/Patient/Visit/Eye/
+            if "Prime" in trial or "prime" in trial:
+                pid = parts[2] if len(parts) > 2 else ""
+                visit = parts[3] if len(parts) > 3 else ""
+                eye = parts[4] if len(parts) > 4 else ""
+                search_dirs.extend([
+                    image_root / "Prime_FULL" / "Prime_FULL" / pid / visit / eye,
+                    image_root / "Prime_FULL" / pid / visit / eye,
+                ])
+
+            # For TREX: TREX_DME/TREX DME/Arm/Patient/Visit/Eye/
+            if "TREX" in trial:
+                arm = parts[1] if len(parts) > 1 else ""
+                pid = parts[2] if len(parts) > 2 else ""
+                visit = parts[3] if len(parts) > 3 else ""
+                eye = parts[4] if len(parts) > 4 else ""
+                search_dirs.extend([
+                    image_root / "TREX_DME" / "TREX DME" / arm / pid / visit / eye,
+                    image_root / "TREX_DME" / trial / arm / pid / visit / eye,
+                ])
+
+            for d in search_dirs:
+                if not d.exists():
+                    continue
+                images = sorted(
+                    list(d.glob("*.png")) + list(d.glob("*.tif"))
+                    + list(d.glob("*.jpg"))
+                )
+                if images:
+                    # Sample N evenly spaced scans
+                    if len(images) <= n:
+                        return images
+                    indices = np.linspace(
+                        0, len(images) - 1, n, dtype=int
+                    )
+                    return [images[i] for i in indices]
+
+        return []
+
+    def _build_case(
+        self,
+        eye_key: str,
+        records: list[dict],
+        all_profiles: dict[str, dict[str, bool]],
+        image_root: Path,
+    ) -> MedicalCase | None:
+        """Convert an eye's records into a MedicalCase."""
+        latest = records[-1]
+
+        # ---- Find OCT images ----
+        oct_images = self._find_oct_images(records, image_root, self.n_oct_samples + 1)
+        if not oct_images:
+            logger.debug(f"Skipping eye {eye_key}: no images found")
+            return None
+
+        # Build all available channels, then split by config
+        all_channels = {}
+
+        # Use middle scan as canonical first-line OCT, rest as optional extras
+        mid_idx = len(oct_images) // 2
+        initial_image = oct_images[mid_idx]
+        additional_images = [
+            img for i, img in enumerate(oct_images) if i != mid_idx
+        ]
+
+        try:
+            initial_b64 = encode_image_to_base64(initial_image)
+        except Exception as e:
+            logger.debug(f"Skipping eye {eye_key}: encode failed: {e}")
+            return None
+
+        oct_meta = config.get_channel_definition("olives", "oct_scan")
+        all_channels["oct_scan"] = ChannelData(
+            name="oct_scan",
+            channel_type="image",
+            description="OCT B-scan showing retinal cross-section",
+            value=initial_b64,
+            image_path=initial_image,
+            cost=float(oct_meta.get("cost", 0.0)),
+            tier=oct_meta.get("tier", "unknown"),
+            always_given=bool(oct_meta.get("always_given", False)),
+        )
+
+        # Additional OCT slices
+        if additional_images:
+            try:
+                add_b64 = [encode_image_to_base64(p) for p in additional_images]
+                ch_meta = config.get_channel_definition("olives", "additional_oct")
+                all_channels["additional_oct"] = ChannelData(
+                    name="additional_oct",
+                    channel_type="image",
+                    description="Additional OCT B-scans from different retinal locations",
+                    value=add_b64,
+                    cost=float(ch_meta.get("cost", 0.0)),
+                    tier=ch_meta.get("tier", "unknown"),
+                    always_given=bool(ch_meta.get("always_given", False)),
+                )
+            except Exception:
+                pass
+
+        # Clinical measurements (BCVA and CST)
+        bcva = latest.get("BCVA", "")
+        cst = latest.get("CST", "")
+        if bcva or cst:
+            parts = []
+            if bcva:
+                parts.append(f"BCVA (logMAR): {bcva}")
+            if cst:
+                parts.append(f"CST: {cst} um")
+            ch_meta = config.get_channel_definition("olives", "clinical_measurements")
+            all_channels["clinical_measurements"] = ChannelData(
+                name="clinical_measurements",
+                channel_type="text",
+                description="Visual acuity (BCVA) and retinal thickness (CST)",
+                value="; ".join(parts),
+                cost=float(ch_meta.get("cost", 0.0)),
+                tier=ch_meta.get("tier", "unknown"),
+                always_given=bool(ch_meta.get("always_given", False)),
+            )
+
+        # Biomarker hints (subset — only the most obvious ones)
+        biomarker_vec = all_profiles[eye_key]
+        obvious_markers = ["fluid_irf", "fluid_srf", "hemorrhage", "drt_me"]
+        hint_parts = []
+        for m in obvious_markers:
+            if m in biomarker_vec:
+                status = "Present" if biomarker_vec[m] else "Not detected"
+                hint_parts.append(
+                    f"{m.replace('_', ' ').title()}: {status}"
+                )
+        if hint_parts:
+            ch_meta = config.get_channel_definition("olives", "biomarker_hints")
+            all_channels["biomarker_hints"] = ChannelData(
+                name="biomarker_hints",
+                channel_type="text",
+                description="Partial biomarker annotations (subset)",
+                value="; ".join(hint_parts),
+                cost=float(ch_meta.get("cost", 0.0)),
+                tier=ch_meta.get("tier", "unknown"),
+                always_given=bool(ch_meta.get("always_given", False)),
+            )
+
+        # Disease type hint
+        path_str = latest.get(
+            "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
+        )
+        disease = "DME" if "TREX" in path_str else "DR"
+        ch_meta = config.get_channel_definition("olives", "disease_context")
+        all_channels["disease_context"] = ChannelData(
+            name="disease_context",
+            channel_type="text",
+            description="Disease type and treatment context",
+            value=f"Disease: {disease}",
+            cost=float(ch_meta.get("cost", 0.0)),
+            tier=ch_meta.get("tier", "unknown"),
+            always_given=bool(ch_meta.get("always_given", False)),
+        )
+
+        initial_channels = {
+            name: ch for name, ch in all_channels.items() if ch.always_given
+        }
+        requestable = {
+            name: ch for name, ch in all_channels.items() if not ch.always_given
+        }
+
+        # ---- Build candidates ----
+        case_id = f"olives_{eye_key}"
+        correct_profile = biomarker_vector_to_profile_string(biomarker_vec)
+        candidates = self._generate_profile_candidates(
+            eye_key, biomarker_vec, all_profiles, case_id
+        )
+
+        if correct_profile not in candidates:
+            candidates[0] = correct_profile
+            rng = _case_rng(case_id)
+            rng.shuffle(candidates)
+
+        return MedicalCase(
+            case_id=case_id,
+            dataset="olives",
+            initial_channels=initial_channels,
+            requestable_channels=requestable,
+            candidates=candidates,
+            ground_truth=correct_profile,
+            ground_truth_rank=(
+                candidates.index(correct_profile)
+                if correct_profile in candidates else 0
+            ),
+            metadata={
+                "eye_id": eye_key,
+                "disease": disease,
+                "biomarker_vector": biomarker_vec,
+            },
+        )
+
+ def _generate_profile_candidates(
430
+ self,
431
+ eye_id: str,
432
+ correct_vec: dict[str, bool],
433
+ all_profiles: dict[str, dict[str, bool]],
434
+ case_id: str,
435
+ ) -> list[str]:
436
+ """Generate biomarker profile candidates."""
437
+ n = self.n_candidates
438
+ rng = _case_rng(case_id)
439
+ correct_str = biomarker_vector_to_profile_string(correct_vec)
440
+
441
+ scored = []
442
+ for eid, vec in all_profiles.items():
443
+ if eid == eye_id:
444
+ continue
445
+ dist = compute_profile_distance(correct_vec, vec)
446
+ profile_str = biomarker_vector_to_profile_string(vec)
447
+ if profile_str != correct_str:
448
+ scored.append((dist, profile_str, vec))
449
+
450
+ scored.sort(key=lambda x: x[0])
451
+
452
+ distractors = []
453
+ if scored:
454
+ distractors.append(scored[0][1]) # Hard distractor
455
+ if len(scored) > 1:
456
+ distractors.append(scored[-1][1]) # Easy distractor
457
+ mid_pool = scored[len(scored) // 4: 3 * len(scored) // 4]
458
+ rng.shuffle(mid_pool)
459
+ for dist, prof, vec in mid_pool:
460
+ if prof not in distractors and len(distractors) < n - 1:
461
+ distractors.append(prof)
462
+
463
+ while len(distractors) < n - 1 and scored:
464
+ pick = rng.choice(scored)
465
+ if pick[1] not in distractors:
466
+ distractors.append(pick[1])
467
+
468
+ candidates = [correct_str] + distractors[:n - 1]
469
+ rng.shuffle(candidates)
470
+ return candidates
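The profile helpers at the top of this file are pure functions, which makes the candidate-generation logic easy to sanity-check. A standalone sketch (re-implemented here with hypothetical names, not imported from the module):

```python
def profile_string(vector: dict[str, bool]) -> str:
    # Human-readable summary of the biomarkers marked present.
    present = [name.replace("_", " ").title()
               for name, val in sorted(vector.items()) if val]
    if not present:
        return "No biomarkers detected"
    return "Present biomarkers: " + ", ".join(present)


def hamming(a: dict[str, bool], b: dict[str, bool], keys: list[str]) -> int:
    # Count biomarkers on which the two profiles disagree.
    return sum(a.get(k, False) != b.get(k, False) for k in keys)


keys = ["fluid_irf", "fluid_srf", "drt_me"]
a = {"fluid_irf": True, "fluid_srf": False, "drt_me": True}
b = {"fluid_irf": True, "fluid_srf": True, "drt_me": False}
print(profile_string(a))    # → Present biomarkers: Drt Me, Fluid Irf
print(hamming(a, b, keys))  # → 2
```

The Hamming distance is what makes the "hard distractor" hard: the profile from another eye with the smallest distance differs from the correct answer in the fewest biomarkers.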
demo_cases/chest_xray_ipf.png ADDED

Git LFS Details

  • SHA256: 68275cef7e60ce4fa6c2402c4f6a18cd70fa32056300979063f3acd221335ea3
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
demo_cases/ct_pulmonary_pe.png ADDED

Git LFS Details

  • SHA256: 341515b895e6bb9c1a226b80a4c8373e744522126643a68aa2137e2fe60de263
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
demo_cases/fundus_dme.png ADDED

Git LFS Details

  • SHA256: 663d0080e4f79c1b3a8b91d7c514cfa892852a9815bf717ad5c144547522b188
  • Pointer size: 131 Bytes
  • Size of remote file: 283 kB
demo_cases/oct_bscan_dme.png ADDED

Git LFS Details

  • SHA256: 31c2470dc3ce6be6d5ab876bdcaeba8353a70843db7cf278c4fa56e784225960
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
demo_cases/skin_lesion_dermoscopy.png ADDED

Git LFS Details

  • SHA256: 9de6739f5d6a4d3771fde0d8c5f473fd652ab66a23868986d51b24f96dfaf9f7
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
evaluation/__init__.py ADDED
@@ -0,0 +1,455 @@
+"""
+Evaluation Metrics for ActiveMedAgent.
+
+Unified metrics across all three datasets:
+- MRR (Mean Reciprocal Rank)
+- Acquisition Efficiency (normalized improvement)
+- Top-1 Accuracy
+- Acquisition Precision
+- Uncertainty Calibration (ECE-style)
+- Information-Theoretic Metrics (entropy, IG, VoI)
+- Bootstrap confidence intervals
+"""
+import logging
+from dataclasses import dataclass, field
+
+import numpy as np
+from scipy import stats
+
+from agent import AgentResult
+from datasets.base import MedicalCase
+from information_gain import BeliefTrajectory, compute_information_metrics
+import config
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CaseMetrics:
+    """Metrics for a single case."""
+    case_id: str
+    dataset: str
+    top1_correct: bool = False
+    reciprocal_rank: float = 0.0
+    ground_truth_rank: int = -1  # 1-indexed rank of correct answer
+    n_acquired: int = 0
+    acquired_channels: list[str] = field(default_factory=list)
+    committed_early: bool = False
+    top1_confidence: float = 0.0  # Confidence of the top-ranked diagnosis
+    acquisition_cost: float = 0.0
+    total_case_cost: float = 0.0
+
+
+@dataclass
+class DatasetMetrics:
+    """Aggregated metrics for a dataset."""
+    dataset: str
+    n_cases: int
+    top1_accuracy: float
+    mrr: float  # Mean Reciprocal Rank
+    top1_accuracy_ci: tuple = (0.0, 0.0)  # 95% CI
+    mrr_ci: tuple = (0.0, 0.0)
+    mean_channels_acquired: float = 0.0
+    early_commit_rate: float = 0.0
+    per_channel_request_rate: dict = field(default_factory=dict)
+    mean_acquisition_cost: float = 0.0
+    mean_total_case_cost: float = 0.0
+
+
+def compute_reciprocal_rank(
+    ranking: list[dict],
+    ground_truth: str,
+    candidates: list[str],
+) -> float:
+    """
+    Compute reciprocal rank of the ground truth in the agent's ranking.
+
+    Returns 1/rank if found; 1/(N+1) if absent from a non-empty ranking
+    of N entries; 0 for an empty ranking.
+    """
+    if not ranking:
+        return 0.0
+
+    gt_lower = ground_truth.lower().strip()
+
+    for entry in ranking:
+        name = entry.get("name", "").lower().strip()
+        rank = entry.get("rank", 999)
+
+        # Flexible matching: check substring containment both ways
+        if gt_lower in name or name in gt_lower:
+            return 1.0 / rank
+
+        # Check if it matches any candidate that matches ground truth
+        for candidate in candidates:
+            if (
+                gt_lower in candidate.lower()
+                and (name in candidate.lower() or candidate.lower() in name)
+            ):
+                return 1.0 / rank
+
+    # Ground truth not found in ranking: score as if it were ranked just
+    # below the last position (the empty-ranking case was handled above).
+    return 1.0 / (len(ranking) + 1)
+
94
+ def evaluate_single_case(
95
+ result: AgentResult,
96
+ case: MedicalCase,
97
+ ) -> CaseMetrics:
98
+ """Evaluate a single agent result against ground truth."""
99
+ ranking = result.final_ranking
100
+ gt = case.ground_truth
101
+ candidates = case.candidates
102
+
103
+ rr = compute_reciprocal_rank(ranking, gt, candidates)
104
+ top1_correct = rr == 1.0 # RR=1 means correct answer is ranked first
105
+
106
+ top1_conf = ranking[0]["confidence"] if ranking else 0.0
107
+
108
+ # Determine ground truth rank in agent's output
109
+ gt_rank = -1
110
+ gt_lower = gt.lower().strip()
111
+ for entry in ranking:
112
+ name = entry.get("name", "").lower().strip()
113
+ if gt_lower in name or name in gt_lower:
114
+ gt_rank = entry.get("rank", -1)
115
+ break
116
+
117
+ return CaseMetrics(
118
+ case_id=result.case_id,
119
+ dataset=result.dataset,
120
+ top1_correct=top1_correct,
121
+ reciprocal_rank=rr,
122
+ ground_truth_rank=gt_rank,
123
+ n_acquired=len(result.acquired_channels),
124
+ acquired_channels=result.acquired_channels,
125
+ committed_early=result.committed_early,
126
+ top1_confidence=top1_conf,
127
+ acquisition_cost=result.acquisition_cost,
128
+ total_case_cost=result.total_case_cost,
129
+ )
130
+
131
+
132
+ def aggregate_metrics(
133
+ case_metrics: list[CaseMetrics],
134
+ dataset_name: str,
135
+ n_bootstrap: int = None,
136
+ ) -> DatasetMetrics:
137
+ """Aggregate per-case metrics into dataset-level stats with bootstrap CIs."""
138
+ if n_bootstrap is None:
139
+ n_bootstrap = config.N_BOOTSTRAP
140
+
141
+ n = len(case_metrics)
142
+ if n == 0:
143
+ return DatasetMetrics(dataset=dataset_name, n_cases=0, top1_accuracy=0, mrr=0)
144
+
145
+ accuracies = np.array([int(cm.top1_correct) for cm in case_metrics])
146
+ rrs = np.array([cm.reciprocal_rank for cm in case_metrics])
147
+
148
+ top1_acc = float(np.mean(accuracies))
149
+ mrr = float(np.mean(rrs))
150
+
151
+ # Bootstrap CIs
152
+ acc_ci = _bootstrap_ci(accuracies, n_bootstrap)
153
+ mrr_ci = _bootstrap_ci(rrs, n_bootstrap)
154
+
155
+ # Channel request rates
156
+ channel_counts: dict[str, int] = {}
157
+ for cm in case_metrics:
158
+ for ch in cm.acquired_channels:
159
+ channel_counts[ch] = channel_counts.get(ch, 0) + 1
160
+ channel_rates = {ch: count / n for ch, count in channel_counts.items()}
161
+
162
+ return DatasetMetrics(
163
+ dataset=dataset_name,
164
+ n_cases=n,
165
+ top1_accuracy=top1_acc,
166
+ mrr=mrr,
167
+ top1_accuracy_ci=acc_ci,
168
+ mrr_ci=mrr_ci,
169
+ mean_channels_acquired=float(np.mean([cm.n_acquired for cm in case_metrics])),
170
+ early_commit_rate=float(np.mean([int(cm.committed_early) for cm in case_metrics])),
171
+ per_channel_request_rate=channel_rates,
172
+ mean_acquisition_cost=float(np.mean([cm.acquisition_cost for cm in case_metrics])),
173
+ mean_total_case_cost=float(np.mean([cm.total_case_cost for cm in case_metrics])),
174
+ )
175
+
176
+
177
+ def compute_acquisition_efficiency(
178
+ mrr_at_k: float,
179
+ mrr_passive: float,
180
+ mrr_oracle: float,
181
+ ) -> float:
182
+ """
183
+ Normalized Acquisition Efficiency.
184
+
185
+ AE(K) = (MRR_K - MRR_passive) / (MRR_oracle - MRR_passive)
186
+
187
+ Returns 0 if oracle = passive (no room for improvement),
188
+ can exceed 1 if active outperforms oracle (shouldn't happen normally).
189
+ """
190
+ denom = mrr_oracle - mrr_passive
191
+ if abs(denom) < 1e-8:
192
+ return 0.0
193
+ return (mrr_at_k - mrr_passive) / denom
194
+
195
+
196
+ def compute_acquisition_precision(
197
+ active_results: list[AgentResult],
198
+ passive_results: list[AgentResult],
199
+ cases: list[MedicalCase],
200
+ ) -> dict:
201
+ """
202
+ Acquisition Precision: when the agent requests info, does the diagnosis change?
203
+
204
+ Two sub-metrics:
205
+ - request_change_rate: fraction of acquisitions that changed the top-1 diagnosis
206
+ - change_correctness: among diagnosis changes, fraction that were improvements
207
+ """
208
+ assert len(active_results) == len(passive_results) == len(cases)
209
+
210
+ total_acquisitions = 0
211
+ diagnosis_changed = 0
212
+ change_improved = 0
213
+
214
+ for active, passive, case in zip(active_results, passive_results, cases):
215
+ passive_top1 = _get_top1_name(passive.final_ranking)
216
+ active_top1 = _get_top1_name(active.final_ranking)
217
+
218
+ n_acq = len(active.acquired_channels)
219
+ if n_acq > 0:
220
+ total_acquisitions += 1
221
+ if passive_top1 != active_top1:
222
+ diagnosis_changed += 1
223
+ # Did it change to the correct answer?
224
+ gt = case.ground_truth.lower().strip()
225
+ if gt in active_top1.lower() or active_top1.lower() in gt:
226
+ change_improved += 1
227
+
228
+ return {
229
+ "total_cases_with_acquisition": total_acquisitions,
230
+ "request_change_rate": (
231
+ diagnosis_changed / total_acquisitions if total_acquisitions > 0 else 0
232
+ ),
233
+ "change_correctness": (
234
+ change_improved / diagnosis_changed if diagnosis_changed > 0 else 0
235
+ ),
236
+ }
237
+
238
+
239
+ def compute_prompt_agreement(
240
+ results_by_variant: dict[str, list[AgentResult]],
241
+ ) -> dict:
242
+ """
243
+ Prompt sensitivity analysis: measure agreement across prompt variants.
244
+
245
+ Returns:
246
+ - top1_agreement: fraction of cases where all variants agree on top-1
247
+ - acquisition_agreement: fraction of cases where all variants request
248
+ the same first channel
249
+ """
250
+ variants = list(results_by_variant.keys())
251
+ if len(variants) < 2:
252
+ return {"top1_agreement": 1.0, "acquisition_agreement": 1.0}
253
+
254
+ # Align by case_id
255
+ case_ids = set()
256
+ for results in results_by_variant.values():
257
+ case_ids.update(r.case_id for r in results)
258
+
259
+ by_case: dict[str, dict[str, AgentResult]] = {}
260
+ for variant, results in results_by_variant.items():
261
+ for r in results:
262
+ if r.case_id not in by_case:
263
+ by_case[r.case_id] = {}
264
+ by_case[r.case_id][variant] = r
265
+
266
+ top1_agree_count = 0
267
+ acq_agree_count = 0
268
+ total = 0
269
+
270
+ for case_id, variant_results in by_case.items():
271
+ if len(variant_results) < len(variants):
272
+ continue # Skip cases not in all variants
273
+ total += 1
274
+
275
+ # Top-1 agreement
276
+ top1s = set()
277
+ for vr in variant_results.values():
278
+ top1s.add(_get_top1_name(vr.final_ranking).lower())
279
+ if len(top1s) == 1:
280
+ top1_agree_count += 1
281
+
282
+ # First acquisition agreement
283
+ first_acqs = set()
284
+ for vr in variant_results.values():
285
+ if vr.acquired_channels:
286
+ first_acqs.add(vr.acquired_channels[0])
287
+ else:
288
+ first_acqs.add("_committed_")
289
+ if len(first_acqs) == 1:
290
+ acq_agree_count += 1
291
+
292
+ return {
293
+ "top1_agreement": top1_agree_count / total if total > 0 else 0,
294
+ "acquisition_agreement": acq_agree_count / total if total > 0 else 0,
295
+ "n_cases_compared": total,
296
+ }
297
+
298
+
299
+ def compute_regret_analysis(
300
+ active_results: list[AgentResult],
301
+ oracle_results: list[AgentResult],
302
+ cases: list[MedicalCase],
303
+ ) -> dict:
304
+ """
305
+ Regret Analysis: when the agent gets a case wrong, could a different
306
+ acquisition strategy have saved it?
307
+
308
+ For each case where active got it wrong:
309
+ 1. Did the oracle get it right? (recoverable error)
310
+ 2. Which channels were available but not requested? (missed channels)
311
+ 3. Among recoverable errors, which missing channels correlate most
312
+ with oracle success? (high-regret channels)
313
+
314
+ Returns a rich dict with per-case traces and aggregate statistics.
315
+ """
316
+ assert len(active_results) == len(oracle_results) == len(cases)
317
+
318
+ per_case_regret = []
319
+ n_active_wrong = 0
320
+ n_oracle_right_when_active_wrong = 0 # recoverable
321
+ n_both_wrong = 0 # unrecoverable — VLM reasoning bottleneck
322
+ missed_channel_counts: dict[str, int] = {} # channels not requested in recoverable cases
323
+ missed_channel_total: dict[str, int] = {} # total times a channel was missed (all wrong)
324
+
325
+ for active, oracle, case in zip(active_results, oracle_results, cases):
326
+ active_rr = compute_reciprocal_rank(active.final_ranking, case.ground_truth, case.candidates)
327
+ oracle_rr = compute_reciprocal_rank(oracle.final_ranking, case.ground_truth, case.candidates)
328
+ active_correct = active_rr == 1.0
329
+ oracle_correct = oracle_rr == 1.0
330
+
331
+ if active_correct:
332
+ continue # No regret if agent got it right
333
+
334
+ n_active_wrong += 1
335
+
336
+ # Channels available but not acquired
337
+ all_requestable = set(case.requestable_channels.keys())
338
+ acquired = set(active.acquired_channels)
339
+ missed = all_requestable - acquired
340
+
341
+ case_entry = {
342
+ "case_id": case.case_id,
343
+ "ground_truth": case.ground_truth,
344
+ "active_top1": _get_top1_name(active.final_ranking),
345
+ "oracle_top1": _get_top1_name(oracle.final_ranking),
346
+ "active_correct": False,
347
+ "oracle_correct": oracle_correct,
348
+ "acquired_channels": list(acquired),
349
+ "missed_channels": list(missed),
350
+ "recoverable": oracle_correct,
351
+ "active_rr": active_rr,
352
+ "oracle_rr": oracle_rr,
353
+ }
354
+
355
+ for ch in missed:
356
+ missed_channel_total[ch] = missed_channel_total.get(ch, 0) + 1
357
+
358
+ if oracle_correct:
359
+ n_oracle_right_when_active_wrong += 1
360
+ for ch in missed:
361
+ missed_channel_counts[ch] = missed_channel_counts.get(ch, 0) + 1
362
+ else:
363
+ n_both_wrong += 1
364
+
365
+ per_case_regret.append(case_entry)
366
+
367
+ # Compute per-channel regret score: how often a missed channel appears
368
+ # in recoverable errors vs all errors
369
+ channel_regret_scores = {}
370
+ for ch in set(list(missed_channel_counts.keys()) + list(missed_channel_total.keys())):
371
+ recoverable_miss = missed_channel_counts.get(ch, 0)
372
+ total_miss = missed_channel_total.get(ch, 0)
373
+ # Regret score: fraction of times this channel was missed AND oracle succeeded
374
+ channel_regret_scores[ch] = {
375
+ "missed_in_recoverable": recoverable_miss,
376
+ "missed_in_all_wrong": total_miss,
377
+ "regret_rate": recoverable_miss / total_miss if total_miss > 0 else 0.0,
378
+ }
379
+
380
+ # Sort channels by regret rate descending
381
+ sorted_channels = sorted(
382
+ channel_regret_scores.items(),
383
+ key=lambda x: (-x[1]["regret_rate"], -x[1]["missed_in_recoverable"]),
384
+ )
385
+
386
+ return {
387
+ "n_cases": len(cases),
388
+ "n_active_wrong": n_active_wrong,
389
+ "n_recoverable": n_oracle_right_when_active_wrong,
390
+ "n_unrecoverable": n_both_wrong,
391
+ "recovery_rate": (
392
+ n_oracle_right_when_active_wrong / n_active_wrong
393
+ if n_active_wrong > 0 else 0.0
394
+ ),
395
+ "error_rate": n_active_wrong / len(cases) if cases else 0.0,
396
+ "channel_regret_scores": dict(sorted_channels),
397
+ "per_case_regret": per_case_regret,
398
+ "summary": {
399
+ "total_errors": n_active_wrong,
400
+ "recoverable_pct": (
401
+ n_oracle_right_when_active_wrong / n_active_wrong * 100
402
+ if n_active_wrong > 0 else 0.0
403
+ ),
404
+ "unrecoverable_pct": (
405
+ n_both_wrong / n_active_wrong * 100
406
+ if n_active_wrong > 0 else 0.0
407
+ ),
408
+ "highest_regret_channel": sorted_channels[0][0] if sorted_channels else None,
409
+ },
410
+ }
411
+
412
+
413
+ def compute_info_theoretic_metrics(
414
+ results: list[AgentResult],
415
+ ) -> dict:
416
+ """
417
+ Compute information-theoretic metrics from belief trajectories.
418
+
419
+ Extracts BeliefTrajectory objects from AgentResults and computes
420
+ aggregate entropy, information gain, and per-channel value metrics.
421
+ """
422
+ trajectories = [
423
+ r.belief_trajectory for r in results
424
+ if r.belief_trajectory and r.belief_trajectory.states
425
+ ]
426
+ if not trajectories:
427
+ return {"n_cases_with_trajectory": 0}
428
+
429
+ metrics = compute_information_metrics(trajectories)
430
+ metrics["n_cases_with_trajectory"] = len(trajectories)
431
+ return metrics
432
+
433
+
434
+ def _get_top1_name(ranking: list[dict]) -> str:
435
+ """Get the name of the top-ranked diagnosis."""
436
+ if not ranking:
437
+ return ""
438
+ return ranking[0].get("name", "")
439
+
440
+
441
+ def _bootstrap_ci(
442
+ values: np.ndarray, n_bootstrap: int = 1000, ci: float = 0.95
443
+ ) -> tuple[float, float]:
444
+ """Compute bootstrap confidence interval."""
445
+ if len(values) == 0:
446
+ return (0.0, 0.0)
447
+ rng = np.random.RandomState(config.SEED)
448
+ boot_means = []
449
+ for _ in range(n_bootstrap):
450
+ sample = rng.choice(values, size=len(values), replace=True)
451
+ boot_means.append(np.mean(sample))
452
+ alpha = (1 - ci) / 2
453
+ lower = float(np.percentile(boot_means, alpha * 100))
454
+ upper = float(np.percentile(boot_means, (1 - alpha) * 100))
455
+ return (lower, upper)
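The reciprocal-rank fallback and the AE(K) normalization above compose as follows. A minimal standalone sketch (hypothetical simplified functions mirroring the module's logic; the candidate cross-matching step is omitted here):

```python
def reciprocal_rank(ranking: list[dict], ground_truth: str) -> float:
    """1/rank if found; smoothed floor 1/(N+1) if absent; 0 for an empty ranking."""
    gt = ground_truth.lower().strip()
    for entry in ranking:
        name = entry.get("name", "").lower().strip()
        if gt in name or name in gt:
            return 1.0 / entry.get("rank", len(ranking) + 1)
    return 1.0 / (len(ranking) + 1) if ranking else 0.0


def acquisition_efficiency(mrr_k: float, mrr_passive: float, mrr_oracle: float) -> float:
    """AE(K) = (MRR_K - MRR_passive) / (MRR_oracle - MRR_passive)."""
    denom = mrr_oracle - mrr_passive
    return 0.0 if abs(denom) < 1e-8 else (mrr_k - mrr_passive) / denom


ranking = [{"rank": 1, "name": "Pulmonary embolism"},
           {"rank": 2, "name": "Pneumonia"}]
print(reciprocal_rank(ranking, "pneumonia"))    # correct answer at rank 2 -> 0.5
print(reciprocal_rank(ranking, "sarcoidosis"))  # absent -> 1/(2+1), not a hard 0
print(acquisition_efficiency(0.75, 0.5, 1.0))   # halfway from passive to oracle -> 0.5
```

The 1/(N+1) floor keeps a missed ground truth from being scored identically to an empty ranking, so longer differentials are penalized slightly less than total misses.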
evaluation/analysis.py ADDED
@@ -0,0 +1,546 @@
+ """
+ Cross-dataset analysis and figure generation.
+
+ Produces the key figures for the paper:
+ 1. Acquisition Efficiency curves (all 3 datasets, shared y-axis)
+ 2. Per-channel request frequency heatmap
+ 3. Prompt sensitivity agreement matrix
+ 4. OLIVES biomarker-tier acquisition analysis
+ 5. NEJM difficulty-vs-acquisition scatter
+ """
+ import json
+ import logging
+ from pathlib import Path
+ from dataclasses import asdict
+
+ import numpy as np
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from scipy import stats
+
+ from agent import AgentResult
+ from datasets.base import MedicalCase
+ from evaluation import (
+     CaseMetrics,
+     DatasetMetrics,
+     evaluate_single_case,
+     aggregate_metrics,
+     compute_acquisition_efficiency,
+     compute_acquisition_precision,
+     compute_prompt_agreement,
+     compute_regret_analysis,
+ )
+ import config
+
+ matplotlib.rcParams["font.family"] = "serif"
+ matplotlib.rcParams["font.size"] = 11
+
+ logger = logging.getLogger(__name__)
+
+
+ class ExperimentAnalyzer:
+     """Analyze and visualize results across all experiments."""
+
+     def __init__(self, results_dir: Path | None = None):
+         self.results_dir = results_dir or config.RESULTS_DIR
+         self.figures_dir = self.results_dir / "figures"
+         self.figures_dir.mkdir(parents=True, exist_ok=True)
+
+     def load_results(self, experiment_name: str) -> dict:
+         """Load saved experiment results."""
+         path = self.results_dir / f"{experiment_name}.json"
+         if not path.exists():
+             logger.error(f"Results file not found: {path}")
+             return {}
+         with open(path) as f:
+             return json.load(f)
+
+     def save_results(self, data: dict, experiment_name: str):
+         """Save experiment results."""
+         path = self.results_dir / f"{experiment_name}.json"
+         with open(path, "w") as f:
+             json.dump(data, f, indent=2, default=str)
+         logger.info(f"Results saved to {path}")
+
+     # ================================================================
+     # Figure 1: Acquisition Efficiency Curves
+     # ================================================================
+
+     def plot_acquisition_efficiency(
+         self,
+         results_by_dataset: dict[str, dict[int, DatasetMetrics]],
+         passive_metrics: dict[str, DatasetMetrics],
+         oracle_metrics: dict[str, DatasetMetrics],
+         save_name: str = "fig1_acquisition_efficiency",
+     ):
+         """
+         Main result figure: normalized acquisition efficiency vs. budget K.
+
+         Args:
+             results_by_dataset: {dataset_name: {K: DatasetMetrics}}
+             passive_metrics: {dataset_name: DatasetMetrics} at K=0
+             oracle_metrics: {dataset_name: DatasetMetrics} with all channels
+         """
+         fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
+
+         colors = {"midas": "#E07A5F", "nejm": "#3D405B", "olives": "#81B29A"}
+         markers = {"midas": "o", "nejm": "s", "olives": "D"}
+         labels = {"midas": "MIDAS (Dermatology)", "nejm": "NEJM (Multi-Specialty)",
+                   "olives": "OLIVES (Ophthalmology)"}
+
+         # Left panel: raw MRR vs. K
+         ax = axes[0]
+         for ds_name in ["midas", "nejm", "olives"]:
+             if ds_name not in results_by_dataset:
+                 continue
+             ks = sorted(results_by_dataset[ds_name].keys())
+             mrrs = [results_by_dataset[ds_name][k].mrr for k in ks]
+             cis = [results_by_dataset[ds_name][k].mrr_ci for k in ks]
+
+             # Prepend the passive baseline at K=0
+             all_k = [0] + list(ks)
+             all_mrr = [passive_metrics[ds_name].mrr] + mrrs
+             all_lower = [passive_metrics[ds_name].mrr_ci[0]] + [c[0] for c in cis]
+             all_upper = [passive_metrics[ds_name].mrr_ci[1]] + [c[1] for c in cis]
+
+             ax.plot(all_k, all_mrr, color=colors[ds_name], marker=markers[ds_name],
+                     label=labels[ds_name], linewidth=2, markersize=7)
+             ax.fill_between(all_k, all_lower, all_upper, alpha=0.15, color=colors[ds_name])
+
+             # Oracle line
+             ax.axhline(y=oracle_metrics[ds_name].mrr, color=colors[ds_name],
+                        linestyle="--", alpha=0.4, linewidth=1)
+
+         ax.set_xlabel("Acquisition Budget (K)")
+         ax.set_ylabel("Mean Reciprocal Rank (MRR)")
+         ax.set_title("(a) Diagnostic Quality vs. Budget")
+         ax.legend(fontsize=9)
+         ax.set_xticks(range(max(4, max(max(r.keys()) for r in results_by_dataset.values()) + 1)))
+         ax.grid(True, alpha=0.3)
+
+         # Right panel: normalized acquisition efficiency
+         ax = axes[1]
+         for ds_name in ["midas", "nejm", "olives"]:
+             if ds_name not in results_by_dataset:
+                 continue
+             ks = sorted(results_by_dataset[ds_name].keys())
+             effs = []
+             for k in ks:
+                 ae = compute_acquisition_efficiency(
+                     results_by_dataset[ds_name][k].mrr,
+                     passive_metrics[ds_name].mrr,
+                     oracle_metrics[ds_name].mrr,
+                 )
+                 effs.append(ae)
+
+             all_k = [0] + list(ks)
+             all_eff = [0.0] + effs
+
+             ax.plot(all_k, all_eff, color=colors[ds_name], marker=markers[ds_name],
+                     label=labels[ds_name], linewidth=2, markersize=7)
+
+         ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, linewidth=1,
+                    label="Oracle ceiling")
+         ax.set_xlabel("Acquisition Budget (K)")
+         ax.set_ylabel("Acquisition Efficiency")
+         ax.set_title("(b) Normalized Efficiency")
+         ax.legend(fontsize=9)
+         ax.set_ylim(-0.05, 1.15)
+         ax.grid(True, alpha=0.3)
+
+         plt.tight_layout()
+         save_path = self.figures_dir / f"{save_name}.pdf"
+         fig.savefig(save_path, dpi=300, bbox_inches="tight")
+         plt.close(fig)
+         logger.info(f"Saved figure: {save_path}")
+
+     # ================================================================
+     # Figure 2: Per-Channel Request Frequency
+     # ================================================================
+
+     def plot_channel_request_heatmap(
+         self,
+         results_by_dataset: dict[str, list[AgentResult]],
+         save_name: str = "fig2_channel_requests",
+     ):
+         """Heatmap showing which channels the agent requests most, by dataset."""
+         fig, axes = plt.subplots(1, 3, figsize=(14, 4))
+         dataset_names = ["midas", "nejm", "olives"]
+         titles = ["MIDAS", "NEJM", "OLIVES"]
+
+         for idx, (ds_name, title) in enumerate(zip(dataset_names, titles)):
+             if ds_name not in results_by_dataset:
+                 continue
+
+             results = results_by_dataset[ds_name]
+
+             # Count first-request frequency
+             first_requests: dict[str, int] = {}
+             for r in results:
+                 if r.acquired_channels:
+                     ch = r.acquired_channels[0]
+                     first_requests[ch] = first_requests.get(ch, 0) + 1
+
+             # Count overall request frequency
+             all_requests: dict[str, int] = {}
+             for r in results:
+                 for ch in r.acquired_channels:
+                     all_requests[ch] = all_requests.get(ch, 0) + 1
+
+             if not all_requests:
+                 continue
+
+             channels = sorted(all_requests.keys())
+             n = len(results)
+
+             ax = axes[idx]
+             data = np.array([
+                 [first_requests.get(ch, 0) / n for ch in channels],
+                 [all_requests.get(ch, 0) / n for ch in channels],
+             ])
+
+             sns.heatmap(
+                 data,
+                 ax=ax,
+                 xticklabels=[ch.replace("_", "\n") for ch in channels],
+                 yticklabels=["First\nRequest", "Any\nRequest"],
+                 annot=True,
+                 fmt=".2f",
+                 cmap="YlOrRd",
+                 vmin=0,
+                 vmax=1,
+                 cbar_kws={"shrink": 0.8},
+             )
+             ax.set_title(title)
+
+         plt.tight_layout()
+         save_path = self.figures_dir / f"{save_name}.pdf"
+         fig.savefig(save_path, dpi=300, bbox_inches="tight")
+         plt.close(fig)
+         logger.info(f"Saved figure: {save_path}")
+
+     # ================================================================
+     # Figure 3: OLIVES Biomarker Tier Analysis
+     # ================================================================
+
+     def plot_olives_biomarker_tiers(
+         self,
+         results: list[AgentResult],
+         cases: list[MedicalCase],
+         save_name: str = "fig3_olives_biomarker_tiers",
+     ):
+         """
+         For OLIVES: does the agent request OCT more often for OCT-dependent
+         biomarkers than for fundus-visible ones?
+         """
+         oct_request_by_tier: dict[str, list[bool]] = {
+             "fundus_visible": [],
+             "oct_dependent": [],
+         }
+
+         for result, case in zip(results, cases):
+             if case.dataset != "olives":
+                 continue
+             case_tiers = case.metadata.get("biomarker_tier_labels", {})
+             requested_oct = "oct_scan" in result.acquired_channels
+
+             # Eyes with fundus-visible biomarkers
+             if case_tiers.get("fundus_visible"):
+                 oct_request_by_tier["fundus_visible"].append(requested_oct)
+
+             # Eyes with OCT-dependent biomarkers
+             if case_tiers.get("oct_dependent"):
+                 oct_request_by_tier["oct_dependent"].append(requested_oct)
+
+         fig, ax = plt.subplots(figsize=(6, 4))
+
+         tiers = ["fundus_visible", "oct_dependent"]
+         tier_labels = ["Fundus-Visible\nBiomarkers", "OCT-Dependent\nBiomarkers"]
+         rates = []
+         cis_lower = []
+         cis_upper = []
+
+         for tier in tiers:
+             vals = oct_request_by_tier.get(tier, [])
+             if vals:
+                 rate = np.mean(vals)
+                 rates.append(rate)
+                 # Wilson score CI for proportions
+                 n = len(vals)
+                 z = 1.96
+                 p = rate
+                 denom = 1 + z ** 2 / n
+                 center = (p + z ** 2 / (2 * n)) / denom
+                 margin = z * np.sqrt((p * (1 - p) + z ** 2 / (4 * n)) / n) / denom
+                 cis_lower.append(center - margin)
+                 cis_upper.append(center + margin)
+             else:
+                 rates.append(0)
+                 cis_lower.append(0)
+                 cis_upper.append(0)
+
+         colors_bar = ["#81B29A", "#E07A5F"]
+         ax.bar(tier_labels, rates, color=colors_bar, edgecolor="white", width=0.5)
+         ax.errorbar(
+             tier_labels, rates,
+             yerr=[np.array(rates) - np.array(cis_lower),
+                   np.array(cis_upper) - np.array(rates)],
+             fmt="none", ecolor="black", capsize=5,
+         )
+
+         ax.set_ylabel("OCT Request Rate")
+         ax.set_title("Agent's OCT Request Rate by Biomarker Type")
+         ax.set_ylim(0, 1.05)
+         ax.grid(True, axis="y", alpha=0.3)
+
+         # Annotate with sample counts
+         for i, tier in enumerate(tiers):
+             n = len(oct_request_by_tier.get(tier, []))
+             ax.text(i, rates[i] + 0.05, f"n={n}", ha="center", fontsize=10)
+
+         plt.tight_layout()
+         save_path = self.figures_dir / f"{save_name}.pdf"
+         fig.savefig(save_path, dpi=300, bbox_inches="tight")
+         plt.close(fig)
+         logger.info(f"Saved figure: {save_path}")
+
+     # ================================================================
+     # Figure 4: NEJM Difficulty vs. Acquisition Behavior
+     # ================================================================
+
+     def plot_nejm_difficulty_analysis(
+         self,
+         results: list[AgentResult],
+         cases: list[MedicalCase],
+         save_name: str = "fig4_nejm_difficulty",
+     ):
+         """
+         Scatter: human difficulty (physician correct rate) vs. the agent's
+         acquisition behavior (number of channels requested + early commit).
+         """
+         difficulties = []
+         n_acquired = []
+         committed_early = []
+
+         for result, case in zip(results, cases):
+             if case.dataset != "nejm":
+                 continue
+             votes = case.metadata.get("votes", {})
+             if not votes:
+                 continue
+
+             # Compute human difficulty (proportion of correct votes)
+             total_votes = sum(float(v) for v in votes.values())
+             if total_votes == 0:
+                 continue
+             gt = case.ground_truth
+             human_correct = 0.0
+             for key, val in votes.items():
+                 if key in gt or gt.startswith(key):
+                     human_correct = float(val) / total_votes if total_votes > 1 else float(val)
+                     break
+
+             difficulties.append(human_correct)
+             n_acquired.append(len(result.acquired_channels))
+             committed_early.append(result.committed_early)
+
+         if not difficulties:
+             logger.warning("No NEJM cases with difficulty data found")
+             return
+
+         fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
+
+         # Left: difficulty vs. number of channels acquired
+         ax = axes[0]
+         ax.scatter(difficulties, n_acquired, alpha=0.5, s=30, color="#3D405B", edgecolors="white")
+         # Trend line
+         if len(difficulties) > 10:
+             z = np.polyfit(difficulties, n_acquired, 1)
+             p = np.poly1d(z)
+             x_line = np.linspace(min(difficulties), max(difficulties), 100)
+             ax.plot(x_line, p(x_line), "--", color="#E07A5F", linewidth=2,
+                     label=f"Trend (slope={z[0]:.2f})")
+             # Correlation
+             r, pval = stats.pearsonr(difficulties, n_acquired)
+             ax.text(0.05, 0.95, f"r={r:.3f}, p={pval:.3f}",
+                     transform=ax.transAxes, fontsize=9, verticalalignment="top")
+         ax.set_xlabel("Human Correct Rate (easier →)")
+         ax.set_ylabel("Channels Acquired by Agent")
+         ax.set_title("(a) Case Difficulty vs. Acquisition Amount")
+         ax.legend(fontsize=9)
+         ax.grid(True, alpha=0.3)
+
+         # Right: difficulty bins vs. early-commit rate
+         ax = axes[1]
+         diff_arr = np.array(difficulties)
+         commit_arr = np.array(committed_early, dtype=float)
+         bins = [0, 0.25, 0.50, 0.75, 1.01]
+         bin_labels = ["<25%", "25-50%", "50-75%", ">75%"]
+         bin_rates = []
+         bin_ns = []
+
+         for i in range(len(bins) - 1):
+             mask = (diff_arr >= bins[i]) & (diff_arr < bins[i + 1])
+             if mask.sum() > 0:
+                 bin_rates.append(commit_arr[mask].mean())
+                 bin_ns.append(mask.sum())
+             else:
+                 bin_rates.append(0)
+                 bin_ns.append(0)
+
+         bar_colors = ["#E07A5F", "#F2CC8F", "#81B29A", "#3D405B"]
+         ax.bar(bin_labels, bin_rates, color=bar_colors, edgecolor="white", width=0.6)
+         for i, (rate, n) in enumerate(zip(bin_rates, bin_ns)):
+             ax.text(i, rate + 0.02, f"n={n}", ha="center", fontsize=9)
+         ax.set_xlabel("Human Correct Rate (easier →)")
+         ax.set_ylabel("Agent Early Commit Rate")
+         ax.set_title("(b) Early Commitment vs. Difficulty")
+         ax.set_ylim(0, 1.05)
+         ax.grid(True, axis="y", alpha=0.3)
+
+         plt.tight_layout()
+         save_path = self.figures_dir / f"{save_name}.pdf"
+         fig.savefig(save_path, dpi=300, bbox_inches="tight")
+         plt.close(fig)
+         logger.info(f"Saved figure: {save_path}")
+
+     # ================================================================
+     # Figure 5: Regret Analysis
+     # ================================================================
+
+     def plot_regret_analysis(
+         self,
+         regret: dict,
+         dataset_name: str = "",
+         save_name: str = "fig5_regret_analysis",
+     ):
+         """
+         Visualize regret analysis results.
+
+         Left: bar chart decomposing correct cases, recoverable errors,
+         and unrecoverable errors.
+         Right: per-channel regret scores (which missed channels cost the most).
+         """
+         fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
+         title_suffix = f" — {dataset_name.upper()}" if dataset_name else ""
+
+         # ---- Left panel: error decomposition ----
+         ax = axes[0]
+         n_correct = regret["n_cases"] - regret["n_active_wrong"]
+         n_recoverable = regret["n_recoverable"]
+         n_unrecoverable = regret["n_unrecoverable"]
+
+         categories = ["Agent\nCorrect", "Recoverable\nErrors", "Unrecoverable\nErrors"]
+         values = [n_correct, n_recoverable, n_unrecoverable]
+         colors_bar = ["#81B29A", "#F2CC8F", "#E07A5F"]
+
+         bars = ax.bar(categories, values, color=colors_bar, edgecolor="white", width=0.55)
+         for bar, val in zip(bars, values):
+             pct = val / regret["n_cases"] * 100 if regret["n_cases"] > 0 else 0
+             ax.text(
+                 bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
+                 f"{val}\n({pct:.0f}%)", ha="center", fontsize=10,
+             )
+
+         ax.set_ylabel("Number of Cases")
+         ax.set_title(f"(a) Error Decomposition{title_suffix}")
+         ax.grid(True, axis="y", alpha=0.3)
+
+         # ---- Right panel: per-channel regret ----
+         ax = axes[1]
+         channel_scores = regret["channel_regret_scores"]
+
+         if channel_scores:
+             channels = list(channel_scores.keys())
+             regret_rates = [channel_scores[ch]["regret_rate"] for ch in channels]
+             miss_counts = [channel_scores[ch]["missed_in_recoverable"] for ch in channels]
+
+             # Sort by regret rate, descending
+             sorted_idx = sorted(range(len(channels)), key=lambda i: -regret_rates[i])
+             channels = [channels[i] for i in sorted_idx]
+             regret_rates = [regret_rates[i] for i in sorted_idx]
+             miss_counts = [miss_counts[i] for i in sorted_idx]
+
+             y_pos = range(len(channels))
+             bar_colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(channels)))
+             ax.barh(
+                 y_pos, regret_rates, color=bar_colors, edgecolor="white", height=0.6,
+             )
+
+             ax.set_yticks(y_pos)
+             ax.set_yticklabels([ch.replace("_", " ").title() for ch in channels], fontsize=9)
+             ax.set_xlabel("Regret Rate")
+             ax.set_xlim(0, 1.05)
+             ax.invert_yaxis()
+
+             # Annotate with counts
+             for i, (rate, count) in enumerate(zip(regret_rates, miss_counts)):
+                 ax.text(
+                     rate + 0.02, i, f"n={count}",
+                     va="center", fontsize=9, color="#333",
+                 )
+         else:
+             ax.text(0.5, 0.5, "No channel data", ha="center", va="center",
+                     transform=ax.transAxes, fontsize=12)
+
+         ax.set_title(f"(b) Channel Regret Scores{title_suffix}")
+         ax.grid(True, axis="x", alpha=0.3)
+
+         plt.tight_layout()
+         save_path = self.figures_dir / f"{save_name}.pdf"
+         fig.savefig(save_path, dpi=300, bbox_inches="tight")
+         plt.close(fig)
+         logger.info(f"Saved figure: {save_path}")
+
+     def print_regret_summary(self, regret: dict):
+         """Print a concise text summary of the regret analysis."""
+         s = regret["summary"]
+         print("\n" + "=" * 55)
+         print(" REGRET ANALYSIS")
+         print("=" * 55)
+         print(f" Total cases: {regret['n_cases']}")
+         print(f" Agent errors: {s['total_errors']} ({regret['error_rate']*100:.1f}%)")
+         print(f" Recoverable: {regret['n_recoverable']} ({s['recoverable_pct']:.1f}% of errors)")
+         print(f" Unrecoverable: {regret['n_unrecoverable']} ({s['unrecoverable_pct']:.1f}% of errors)")
+         print(f" Highest-regret channel: {s['highest_regret_channel']}")
+         print()
+         print(" Per-channel regret:")
+         for ch, scores in regret["channel_regret_scores"].items():
+             print(f" {ch:<25} regret={scores['regret_rate']:.2f} "
+                   f"(missed in {scores['missed_in_recoverable']}/{scores['missed_in_all_wrong']} errors)")
+         print("=" * 55)
+
+     # ================================================================
+     # Summary Table
+     # ================================================================
+
+     def print_summary_table(
519
+ self,
520
+ all_metrics: dict[str, dict[str, DatasetMetrics]],
521
+ ):
522
+ """
523
+ Print the main results table.
524
+
525
+ Args:
526
+ all_metrics: {condition: {dataset: DatasetMetrics}}
527
+ where condition is "passive", "K=1", "K=2", "K=3",
528
+ "fixed_order", "oracle"
529
+ """
530
+ header = f"{'Condition':<15} {'Dataset':<12} {'Top-1 Acc':<15} {'MRR':<15} {'Avg K':<8}"
531
+ print("=" * len(header))
532
+ print(header)
533
+ print("=" * len(header))
534
+
535
+ for condition in ["passive", "K=1", "K=2", "K=3", "fixed_order", "oracle"]:
536
+ if condition not in all_metrics:
537
+ continue
538
+ for ds in ["midas", "nejm", "olives"]:
539
+ if ds not in all_metrics[condition]:
540
+ continue
541
+ m = all_metrics[condition][ds]
542
+ acc_str = f"{m.top1_accuracy:.3f} ({m.top1_accuracy_ci[0]:.3f}-{m.top1_accuracy_ci[1]:.3f})"
543
+ mrr_str = f"{m.mrr:.3f} ({m.mrr_ci[0]:.3f}-{m.mrr_ci[1]:.3f})"
544
+ print(f"{condition:<15} {ds:<12} {acc_str:<15} {mrr_str:<15} {m.mean_channels_acquired:<8.1f}")
545
+
546
+ print("=" * len(header))
information_gain.py ADDED
@@ -0,0 +1,441 @@
+ """
+ Information-theoretic computation for ActiveMedAgent.
+
+ Provides grounded entropy and expected information gain (EIG) computation
+ from the agent's reported probability distributions. This transforms the
+ "information-theoretic framing" from a prompt label into actual computation.
+
+ Key concepts:
+ - Belief State: The agent's probability distribution over candidate diagnoses
+ - Shannon Entropy: H(p) = -sum(p_i * log2(p_i)) — measures diagnostic uncertainty
+ - Information Gain: H(before) - H(after) — how much a channel reduced uncertainty
+ - Expected Information Gain (EIG): Estimated reduction in entropy from acquiring a channel
+ - Value of Information (VoI): Whether acquiring more data is worth the cost
+
+ No training required — these are computed analytically from the probability
+ distributions the agent reports through tool calls at each step.
+ """
+ from __future__ import annotations
+
+ import logging
+ import math
+ from dataclasses import dataclass, field
+
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class BeliefState:
+     """
+     The agent's probability distribution over candidate diagnoses at a given step.
+
+     Extracted directly from the tool call's `current_differential` parameter,
+     so no parsing heuristics are needed.
+     """
+     step: int
+     distribution: dict[str, float]  # {diagnosis_name: probability}
+     entropy: float = 0.0
+     channel_acquired: str | None = None
+
+     def __post_init__(self):
+         self.entropy = compute_entropy(self.distribution)
+
+
+ @dataclass
+ class BeliefTrajectory:
+     """
+     Full trajectory of belief states across the acquisition process.
+
+     Tracks how the agent's uncertainty evolves as it acquires information,
+     enabling information-theoretic analysis of acquisition quality.
+     """
+     case_id: str
+     states: list[BeliefState] = field(default_factory=list)
+
+     @property
+     def initial_entropy(self) -> float:
+         return self.states[0].entropy if self.states else 0.0
+
+     @property
+     def final_entropy(self) -> float:
+         return self.states[-1].entropy if self.states else 0.0
+
+     @property
+     def total_information_gain(self) -> float:
+         """Total reduction in entropy across all acquisitions."""
+         return self.initial_entropy - self.final_entropy
+
+     @property
+     def per_step_information_gain(self) -> list[float]:
+         """Information gain at each acquisition step."""
+         gains = []
+         for i in range(1, len(self.states)):
+             gains.append(self.states[i - 1].entropy - self.states[i].entropy)
+         return gains
+
+     @property
+     def entropy_trajectory(self) -> list[float]:
+         """Entropy at each step."""
+         return [s.entropy for s in self.states]
+
+     @property
+     def information_efficiency(self) -> float:
+         """
+         Information efficiency: actual IG / maximum possible IG.
+
+         Maximum possible IG is going from initial entropy to 0 (perfect certainty).
+         Returns a ratio in [0, 1].
+         """
+         if self.initial_entropy < 1e-10:
+             return 1.0  # Already certain
+         return self.total_information_gain / self.initial_entropy
+
+     def get_channel_information_values(self) -> dict[str, float]:
+         """Map each acquired channel to its observed information gain."""
+         values = {}
+         for i in range(1, len(self.states)):
+             ch = self.states[i].channel_acquired
+             if ch:
+                 values[ch] = self.states[i - 1].entropy - self.states[i].entropy
+         return values
+
+
+ # ============================================================
+ # Core Computations
+ # ============================================================
+
+ def compute_entropy(distribution: dict[str, float]) -> float:
+     """
+     Shannon entropy H(p) = -sum(p_i * log2(p_i)) in bits.
+
+     Handles edge cases: p=0 contributes 0; normalizes if the sum != 1.
+     """
+     probs = np.array(list(distribution.values()), dtype=np.float64)
+
+     # Normalize if needed (VLM probabilities may not sum exactly to 1)
+     total = probs.sum()
+     if total < 1e-10:
+         return 0.0
+     probs = probs / total
+
+     # Compute entropy, handling p=0
+     entropy = 0.0
+     for p in probs:
+         if p > 1e-15:
+             entropy -= p * math.log2(p)
+     return entropy
+
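The entropy computation above can be sanity-checked on its own. A minimal standalone sketch (it mirrors `compute_entropy` rather than importing the module; `entropy_bits` is a hypothetical helper name):

```python
import math

def entropy_bits(dist):
    """Shannon entropy in bits; normalizes and skips zero-probability terms."""
    total = sum(dist.values())
    if total <= 0:
        return 0.0
    return -sum((p / total) * math.log2(p / total)
                for p in dist.values() if p / total > 1e-15)

# Uniform over 4 diagnoses: maximum uncertainty, log2(4) = 2 bits.
uniform = {"a": 0.25, "b": 0.25, "c": 0.25, "d": 0.25}
# Peaked distribution: near-certain, entropy close to 0.
peaked = {"a": 0.97, "b": 0.01, "c": 0.01, "d": 0.01}

h_uniform = entropy_bits(uniform)   # 2.0 bits
h_peaked = entropy_bits(peaked)     # well under 1 bit
```

The gap between these two values is exactly what the per-step information gain measures: how far an acquisition moves the belief from uniform toward peaked.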
+
+ def compute_kl_divergence(p: dict[str, float], q: dict[str, float]) -> float:
+     """
+     KL divergence D_KL(p || q) = sum(p_i * log2(p_i / q_i)).
+
+     Measures how much the belief shifted from q (prior) to p (posterior).
+     """
+     all_keys = set(list(p.keys()) + list(q.keys()))
+     p_arr = np.array([p.get(k, 1e-10) for k in all_keys], dtype=np.float64)
+     q_arr = np.array([q.get(k, 1e-10) for k in all_keys], dtype=np.float64)
+
+     # Normalize
+     p_arr = p_arr / p_arr.sum()
+     q_arr = q_arr / q_arr.sum()
+
+     # Smoothing to avoid log(0)
+     q_arr = np.maximum(q_arr, 1e-10)
+
+     kl = 0.0
+     for pi, qi in zip(p_arr, q_arr):
+         if pi > 1e-15:
+             kl += pi * math.log2(pi / qi)
+     return kl
+
+
+ def estimate_expected_information_gain(
+     current_distribution: dict[str, float],
+     channel_name: str,
+     expected_impact: dict[str, str],
+     candidates: list[str],
+ ) -> float:
+     """
+     Estimate expected information gain (EIG) for a candidate channel.
+
+     Uses the agent's stated expected_impact (from the tool call) to estimate
+     how much the entropy would decrease. This is a lightweight approximation:
+     we model two scenarios (positive/negative finding) and compute the
+     expected entropy reduction.
+
+     Args:
+         current_distribution: Current belief state
+         channel_name: Channel being evaluated
+         expected_impact: {"if_positive": diagnosis_name, "if_negative": diagnosis_name}
+         candidates: All candidate diagnoses
+
+     Returns:
+         Estimated information gain in bits
+     """
+     current_entropy = compute_entropy(current_distribution)
+
+     # Model the positive scenario: the indicated diagnosis gets boosted
+     pos_target = expected_impact.get("if_positive", "")
+     neg_target = expected_impact.get("if_negative", "")
+
+     # Estimate posterior distributions under each scenario
+     pos_posterior = _shift_belief(current_distribution, pos_target, boost=0.3)
+     neg_posterior = _shift_belief(current_distribution, neg_target, boost=0.3)
+
+     # Weight scenarios by current probability of the positive-target diagnosis
+     p_positive = current_distribution.get(pos_target, 0.5)
+     p_negative = 1.0 - p_positive
+
+     expected_posterior_entropy = (
+         p_positive * compute_entropy(pos_posterior)
+         + p_negative * compute_entropy(neg_posterior)
+     )
+
+     eig = current_entropy - expected_posterior_entropy
+     return max(0.0, eig)  # EIG should be non-negative
+
+
+ def _shift_belief(
+     distribution: dict[str, float],
+     target: str,
+     boost: float = 0.3,
+ ) -> dict[str, float]:
+     """
+     Shift probability mass toward a target diagnosis.
+
+     Simple model: add `boost` to the target, then renormalize.
+     Used for EIG estimation only.
+     """
+     result = dict(distribution)
+
+     # Find the best matching key (case-insensitive substring match)
+     matched_key = None
+     target_lower = target.lower().strip()
+     for key in result:
+         if target_lower in key.lower() or key.lower() in target_lower:
+             matched_key = key
+             break
+
+     if matched_key is None:
+         return result
+
+     result[matched_key] = result.get(matched_key, 0.0) + boost
+
+     # Renormalize
+     total = sum(result.values())
+     if total > 0:
+         result = {k: v / total for k, v in result.items()}
+
+     return result
+
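The boost-and-renormalize posterior model used for EIG estimation can be verified in isolation. A minimal sketch (function and diagnosis names are illustrative; the 0.3 boost mirrors the default above, without the substring matching):

```python
def shift_belief(dist, target, boost=0.3):
    """Add `boost` probability mass to `target`, then renormalize to sum to 1."""
    out = dict(dist)
    if target in out:
        out[target] += boost
    total = sum(out.values())
    return {k: v / total for k, v in out.items()}

prior = {"pneumonia": 0.5, "pe": 0.3, "chf": 0.2}
posterior = shift_belief(prior, "pe")
# "pe" gains mass (0.6 / 1.3 ≈ 0.46) and overtakes "pneumonia" (0.5 / 1.3 ≈ 0.38)
```

EIG then falls out by computing entropy before and after the shift, weighted by how likely each finding is.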
+
+
+ # ============================================================
+ # Stopping Criterion: When Has the Agent Gathered Enough?
+ # ============================================================
+
+ def should_commit(
+     trajectory: BeliefTrajectory,
+     available_channels: list[str],
+     min_steps: int = 0,
+ ) -> tuple[bool, str]:
+     """
+     Principled stopping criterion based on the agent's belief trajectory.
+
+     CRITICAL DESIGN PRINCIPLE: Never trust raw VLM probabilities from a
+     single observation. Weaker models (GPT-4o-mini) routinely assign 0.85
+     to wrong diagnoses after seeing just one image. Stopping criteria must
+     be grounded in OBSERVED BELIEF DYNAMICS (how beliefs changed after
+     seeing evidence), not in the raw probability the VLM reports.
+
+     Three conditions, all requiring evidence of belief stability:
+
+     1. CONVERGENCE: The last acquisition produced negligible IG (< 0.05 bits).
+        Requires >= 2 belief states. If new evidence doesn't change the
+        agent's mind, further evidence probably won't either.
+
+     2. CONFIRMED DOMINANCE: The top diagnosis has probability >= 0.90 AND
+        the gap to #2 is >= 0.40, AND the agent has acquired >= 2 channels.
+        Raw first-impression confidence is meaningless — dominance only
+        counts after the belief has SURVIVED multiple evidence updates.
+
+     3. DIMINISHING RETURNS: The last 2 acquisitions both had IG < 0.1 bits.
+        Requires >= 3 belief states. The agent hit a plateau.
+
+     Returns:
+         (should_commit: bool, reason: str)
+     """
+     n_states = len(trajectory.states)
+
+     if n_states < max(1, min_steps):
+         return False, "min_steps not reached"
+
+     if not trajectory.states:
+         return False, "no belief states yet"
+
+     # Count actual acquisitions (states with a channel acquired)
+     n_acquired = sum(
+         1 for s in trajectory.states if s.channel_acquired is not None
+     )
+
+     latest = trajectory.states[-1]
+     dist = latest.distribution
+
+     if not dist:
+         return False, "empty distribution"
+
+     # Normalize
+     total = sum(dist.values())
+     if total < 1e-10:
+         return False, "zero distribution"
+     probs = sorted(dist.values(), reverse=True)
+     probs = [p / total for p in probs]
+
+     top1_prob = probs[0] if probs else 0
+     top2_prob = probs[1] if len(probs) > 1 else 0
+     gap = top1_prob - top2_prob
+
+     # Condition 1: CONVERGENCE — last step had negligible IG.
+     # Requires at least 2 states (before/after an acquisition).
+     if n_states >= 2:
+         last_ig = (
+             trajectory.states[-2].entropy - trajectory.states[-1].entropy
+         )
+         if last_ig < 0.05 and n_acquired >= 1:
+             return True, (
+                 f"convergence: last IG={last_ig:.3f} bits < 0.05 threshold "
+                 f"(after {n_acquired} acquisition(s))"
+             )
+
+     # Condition 2: CONFIRMED DOMINANCE — high confidence AFTER evidence.
+     # Must have acquired >= 2 channels. A first-impression 0.85 is not
+     # dominance — it's overconfidence. True dominance is when the belief
+     # stays dominant after being tested by new evidence.
+     if n_acquired >= 2 and top1_prob >= 0.90 and gap >= 0.40:
+         return True, (
+             f"confirmed dominance: top1={top1_prob:.2f}, gap={gap:.2f} "
+             f"(after {n_acquired} acquisitions)"
+         )
+
+     # Condition 3: DIMINISHING RETURNS — last 2 acquisitions both low IG.
+     # Requires at least 3 states.
+     if n_states >= 3:
+         ig_n1 = trajectory.states[-3].entropy - trajectory.states[-2].entropy
+         ig_n2 = trajectory.states[-2].entropy - trajectory.states[-1].entropy
+         if ig_n1 < 0.1 and ig_n2 < 0.1 and n_acquired >= 2:
+             return True, (
+                 f"diminishing returns: last 2 IGs={ig_n1:.3f}, {ig_n2:.3f} "
+                 f"(after {n_acquired} acquisitions)"
+             )
+
+     # No remaining channels
+     if not available_channels:
+         return True, "no channels remaining"
+
+     return False, "continue acquiring"
+
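The convergence and diminishing-returns conditions depend only on the entropy trajectory, so they can be exercised without the full `BeliefTrajectory` machinery. A simplified standalone sketch (`stop_reason` is a hypothetical helper; the 0.05 and 0.1 bit thresholds are copied from the docstring above):

```python
def stop_reason(entropies, n_acquired):
    """Return a stop reason from an entropy-per-step list, or None to continue."""
    # Convergence: the last acquisition barely moved the belief.
    if len(entropies) >= 2 and n_acquired >= 1:
        last_ig = entropies[-2] - entropies[-1]
        if last_ig < 0.05:
            return "convergence"
    # Diminishing returns: two consecutive low-gain acquisitions.
    if len(entropies) >= 3 and n_acquired >= 2:
        ig1 = entropies[-3] - entropies[-2]
        ig2 = entropies[-2] - entropies[-1]
        if ig1 < 0.1 and ig2 < 0.1:
            return "diminishing_returns"
    return None

# Big first gain, then a negligible one: convergence fires.
r1 = stop_reason([2.0, 1.0, 0.99], n_acquired=2)   # "convergence"
# Entropy still dropping fast: keep acquiring.
r2 = stop_reason([2.0, 1.2], n_acquired=1)         # None
```

The confirmed-dominance condition is deliberately omitted here since it also needs the top-1/top-2 gap, not just entropies.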
+
+
+ def compute_value_of_information(
+     trajectory: BeliefTrajectory,
+     n_remaining_channels: int,
+ ) -> float:
+     """
+     Estimate the value of continuing to acquire information.
+
+     Uses the trajectory's IG history to extrapolate whether the next
+     acquisition would be worth it. Returns a score in [0, 1]:
+     - Near 0: little value in continuing (should commit)
+     - Near 1: high value in continuing (should acquire)
+
+     Method: weighted average of recent IG values, normalized by initial
+     entropy. Decays with the number of remaining channels (diminishing
+     marginal returns).
+     """
+     if not trajectory.states or n_remaining_channels == 0:
+         return 0.0
+
+     per_step_ig = trajectory.per_step_information_gain
+     if not per_step_ig:
+         return 0.5  # No history — uncertain, lean toward acquiring
+
+     initial_h = trajectory.initial_entropy
+     if initial_h < 1e-10:
+         return 0.0  # Already certain
+
+     # Exponentially weighted recent IG (most recent steps matter more)
+     weights = [0.5 ** i for i in range(len(per_step_ig))]
+     weights.reverse()  # Most recent gets the highest weight
+     weighted_ig = sum(w * ig for w, ig in zip(weights, per_step_ig))
+     weighted_ig /= sum(weights)
+
+     # Normalize by initial entropy
+     normalized_ig = weighted_ig / initial_h
+
+     # Discount by progress through the channels (diminishing returns)
+     total_channels = len(trajectory.states) + n_remaining_channels
+     progress = len(trajectory.states) / total_channels
+     discount = 1.0 - (progress * 0.5)  # Mild discount as we acquire more
+
+     voi = normalized_ig * discount
+     return max(0.0, min(1.0, voi))
+
+
+ # ============================================================
+ # Aggregate Information-Theoretic Metrics
+ # ============================================================
+
+ def compute_information_metrics(trajectories: list[BeliefTrajectory]) -> dict:
+     """
+     Compute aggregate information-theoretic metrics across cases.
+
+     Returns:
+         dict with:
+         - mean_initial_entropy: Average starting uncertainty
+         - mean_final_entropy: Average ending uncertainty
+         - mean_total_ig: Average total information gain
+         - mean_info_efficiency: Average IG / initial entropy
+         - per_channel_mean_ig: Average IG contributed by each channel
+         - entropy_reduction_curve: Mean entropy at each step
+     """
+     if not trajectories:
+         return {}
+
+     initial_entropies = [t.initial_entropy for t in trajectories]
+     final_entropies = [t.final_entropy for t in trajectories]
+     total_igs = [t.total_information_gain for t in trajectories]
+     efficiencies = [t.information_efficiency for t in trajectories]
+
+     # Per-channel IG
+     channel_igs: dict[str, list[float]] = {}
+     for t in trajectories:
+         for ch, ig in t.get_channel_information_values().items():
+             if ch not in channel_igs:
+                 channel_igs[ch] = []
+             channel_igs[ch].append(ig)
+
+     per_channel_mean_ig = {
+         ch: float(np.mean(igs)) for ch, igs in channel_igs.items()
+     }
+
+     # Entropy curve (pad shorter trajectories with their final entropy)
+     max_steps = max(len(t.states) for t in trajectories)
+     curves = []
+     for t in trajectories:
+         curve = t.entropy_trajectory
+         curve += [curve[-1]] * (max_steps - len(curve))  # Pad with final value
+         curves.append(curve)
+
+     mean_curve = list(np.mean(curves, axis=0))
+
+     return {
+         "mean_initial_entropy": float(np.mean(initial_entropies)),
+         "mean_final_entropy": float(np.mean(final_entropies)),
+         "mean_total_ig": float(np.mean(total_igs)),
+         "mean_info_efficiency": float(np.mean(efficiencies)),
+         "per_channel_mean_ig": per_channel_mean_ig,
+         "entropy_reduction_curve": mean_curve,
+         "n_cases": len(trajectories),
+     }
policy.py ADDED
@@ -0,0 +1,608 @@
+ """
+ Acquisition Policy Learning for ActiveMedAgent.
+
+ Three learned policies, all API-based or CPU-only:
+
+ 1. RewardWeightedICL: Select the best past trajectories as in-context
+    examples for the VLM. The VLM sees "here's what worked before on
+    similar cases" and makes better acquisition decisions.
+
+ 2. PolicyNetwork: A small MLP trained on CPU that predicts which channel
+    to request given a featurized state. Cheap, fast, interpretable.
+
+ 3. SelfReflectivePolicy: The VLM critiques its own past failures
+    and generates an improved acquisition strategy.
+
+ All three produce an acquisition policy that replaces the zero-shot
+ decision in agent.py.
+ """
+ import json
+ import logging
+ import random
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+
+ import config
+ from api_client import BaseVLMClient
+ from datasets.base import MedicalCase
+ from trajectory import Trajectory, TrajectoryStep
+
+ logger = logging.getLogger(__name__)
+
+
+ # ================================================================
+ # Approach 1: Reward-Weighted In-Context Learning (ICL)
+ # ================================================================
+
+ class RewardWeightedICL:
+     """
+     Learn an acquisition policy via reward-weighted few-shot prompting.
+
+     Strategy:
+     1. From collected trajectories, identify GOOD acquisition decisions
+        (positive reward) and BAD ones (negative/zero reward)
+     2. For each new case, retrieve the K most similar past cases
+        (by dataset + channel overlap + uncertainty similarity)
+     3. Construct few-shot examples showing good acquisitions
+     4. The VLM sees concrete examples of "when uncertain about X,
+        requesting Y helped" and makes better decisions
+
+     This is essentially offline policy improvement via in-context learning.
+     """
+
+     def __init__(
+         self,
+         trajectories: list[Trajectory],
+         n_examples: int = 3,
+         min_reward: float = 0.05,
+     ):
+         self.n_examples = n_examples
+         self.min_reward = min_reward
+
+         # Index good and bad acquisition decisions
+         self.good_decisions: list[dict] = []
+         self.bad_decisions: list[dict] = []
+
+         for traj in trajectories:
+             for step in traj.steps:
+                 if step.action == "COMMIT":
+                     continue
+                 decision = {
+                     "case_id": traj.case_id,
+                     "dataset": traj.dataset,
+                     "acquired_before": step.acquired_so_far,
+                     "action": step.action,
+                     "uncertainty": step.uncertainty_text,
+                     "reward": step.utility_reward,
+                     "mrr_reward": step.reward,
+                     "cost": step.acquisition_cost,
+                     "diagnosis_changed": step.diagnosis_changed,
+                     "diagnosis_improved": step.diagnosis_improved,
+                     "mrr_before": step.mrr_before,
+                     "mrr_after": step.mrr_after,
+                 }
+                 if step.utility_reward >= min_reward:
+                     self.good_decisions.append(decision)
+                 else:
+                     self.bad_decisions.append(decision)
+
+         logger.info(
+             f"RewardWeightedICL: {len(self.good_decisions)} good, "
+             f"{len(self.bad_decisions)} bad decisions indexed"
+         )
+
+     def get_few_shot_examples(
+         self,
+         case: MedicalCase,
+         acquired_so_far: list[str],
+     ) -> str:
+         """
+         Retrieve the best few-shot examples for the current case state.
+
+         Returns formatted text to prepend to the acquisition prompt.
+         """
+         # Filter to the same dataset
+         candidates = [d for d in self.good_decisions if d["dataset"] == case.dataset]
+
+         if not candidates:
+             candidates = self.good_decisions  # Fall back to cross-dataset examples
+
+         # Score by similarity to the current state
+         scored = []
+         for d in candidates:
+             similarity = self._compute_similarity(d, acquired_so_far)
+             scored.append((similarity, d))
+
+         scored.sort(key=lambda x: (-x[0], -x[1]["reward"]))
+
+         # Take the top N
+         selected = scored[: self.n_examples]
+
+         if not selected:
+             return ""
+
+         # Format as few-shot examples
+         lines = [
+             "Here are examples of helpful acquisition decisions from similar past cases:\n"
+         ]
+         for i, (_sim, d) in enumerate(selected):
+             lines.append(f"Example {i + 1}:")
+             lines.append(f"  Already acquired: {d['acquired_before'] or ['(nothing)']}")
+             lines.append(f"  Uncertainty: {d['uncertainty'][:150]}")
+             lines.append(f"  Decision: REQUEST {d['action']}")
+             lines.append(
+                 f"  Outcome: MRR improved from {d['mrr_before']:.2f} to {d['mrr_after']:.2f} "
+                 f"(reward: {d['reward']:+.3f})"
+             )
+             lines.append("")
+
+         lines.append(
+             "Learn from these examples. Prioritize channels that resolved similar uncertainties.\n"
+         )
+         return "\n".join(lines)
+
+     def _compute_similarity(self, decision: dict, acquired_so_far: list[str]) -> float:
+         """
+         Compute similarity between a past decision and the current state,
+         based on acquisition-stage overlap.
+         """
+         past_acquired = set(decision["acquired_before"])
+         current_acquired = set(acquired_so_far)
+
+         # Jaccard similarity of the acquisition state
+         if not past_acquired and not current_acquired:
+             return 1.0  # Both at the start
+         union = past_acquired | current_acquired
+         intersection = past_acquired & current_acquired
+         stage_sim = len(intersection) / max(len(union), 1)
+
+         # Bonus for the same acquisition stage (same number of channels acquired)
+         stage_match = 1.0 if len(past_acquired) == len(current_acquired) else 0.5
+
+         return stage_sim * 0.5 + stage_match * 0.5
+
+
+ # ================================================================
+ # Approach 2: Lightweight Policy Network (CPU-only)
+ # ================================================================
+
+ class PolicyNetwork:
+     """
+     Small MLP that predicts which channel to request.
+
+     State features (input):
+     - One-hot: which channels have been acquired
+     - One-hot: which dataset this is
+     - Scalar: current top-1 confidence
+     - Scalar: confidence gap (top1 - top2)
+     - Scalar: acquisition step index (0, 1, 2)
+
+     Output: probability distribution over requestable channels.
+
+     Trained with cross-entropy loss weighted by trajectory reward.
+     Runs entirely on CPU — no GPU needed. This is a <1000-parameter model.
+     """
+
+     def __init__(
+         self,
+         all_channels: list[str],
+         all_datasets: list[str],
+         hidden_dim: int = 32,
+     ):
+         self.all_channels = sorted(all_channels)
+         self.all_datasets = sorted(all_datasets)
+         self.channel_to_idx = {c: i for i, c in enumerate(self.all_channels)}
+         self.dataset_to_idx = {d: i for i, d in enumerate(self.all_datasets)}
+         self.n_channels = len(self.all_channels)
+         self.n_datasets = len(self.all_datasets)
+
+         # Feature dimension: acquired_mask + dataset_onehot + confidence + gap + step
+         self.input_dim = self.n_channels + self.n_datasets + 3
+         self.hidden_dim = hidden_dim
+         self.output_dim = self.n_channels
+
+         # Initialize weights (small random values, He-scaled, CPU numpy)
+         rng = np.random.RandomState(config.SEED)
+         scale1 = np.sqrt(2.0 / self.input_dim)
+         scale2 = np.sqrt(2.0 / hidden_dim)
+
+         self.W1 = rng.randn(self.input_dim, hidden_dim).astype(np.float32) * scale1
+         self.b1 = np.zeros(hidden_dim, dtype=np.float32)
+         self.W2 = rng.randn(hidden_dim, self.output_dim).astype(np.float32) * scale2
+         self.b2 = np.zeros(self.output_dim, dtype=np.float32)
+
+         self.trained = False
+
+     def featurize(
+         self,
+         dataset: str,
+         acquired: list[str],
+         top1_confidence: float,
+         top2_confidence: float,
+         step_idx: int,
+     ) -> np.ndarray:
+         """Convert a state to a feature vector."""
+         features = np.zeros(self.input_dim, dtype=np.float32)
+
+         # Acquired-channels mask
+         for ch in acquired:
+             if ch in self.channel_to_idx:
+                 features[self.channel_to_idx[ch]] = 1.0
+
+         # Dataset one-hot
+         offset = self.n_channels
+         if dataset in self.dataset_to_idx:
+             features[offset + self.dataset_to_idx[dataset]] = 1.0
+
+         # Scalars
+         offset += self.n_datasets
+         features[offset] = top1_confidence
+         features[offset + 1] = top1_confidence - top2_confidence  # Confidence gap
+         features[offset + 2] = step_idx / 3.0  # Normalized step
+
+         return features
+
+     def predict(
+         self,
+         features: np.ndarray,
+         available_channels: list[str],
+     ) -> dict[str, float]:
+         """
+         Forward pass: predict channel-selection probabilities.
+
+         Returns a dict mapping channel_name → probability.
+         Only available (not yet acquired) channels get nonzero probability.
+         """
+         # Forward pass: input → ReLU → softmax (masked)
+         h = np.maximum(0, features @ self.W1 + self.b1)  # ReLU
+         logits = h @ self.W2 + self.b2
+
+         # Mask unavailable channels to -inf
+         mask = np.full(self.output_dim, -1e9, dtype=np.float32)
+         for ch in available_channels:
+             if ch in self.channel_to_idx:
+                 mask[self.channel_to_idx[ch]] = 0.0
+         logits = logits + mask
+
+         # Softmax (shift by the max for numerical stability)
+         logits = logits - logits.max()
+         exp_logits = np.exp(logits)
+         probs = exp_logits / (exp_logits.sum() + 1e-8)
+
+         return {ch: float(probs[self.channel_to_idx[ch]])
+                 for ch in available_channels if ch in self.channel_to_idx}
+
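The masking trick in `predict` — adding a large negative value to the logits of unavailable channels before the softmax — can be checked standalone. A minimal numpy sketch (`masked_softmax` is a hypothetical helper mirroring the forward pass above):

```python
import numpy as np

def masked_softmax(logits, available_mask):
    """Softmax over logits; positions where available_mask is 0 get ~0 probability."""
    masked = logits + (1.0 - available_mask) * -1e9
    masked = masked - masked.max()   # shift by max for numerical stability
    exp = np.exp(masked)
    return exp / exp.sum()

logits = np.array([1.0, 2.0, 0.5, 3.0])
mask = np.array([1.0, 0.0, 1.0, 0.0])   # channels 1 and 3 already acquired
probs = masked_softmax(logits, mask)
# Masked positions collapse to ~0; the remaining mass splits between 0 and 2.
```

The max-shift makes the `-1e9` entries underflow harmlessly to `exp(...) == 0` instead of producing overflow or NaNs.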
+     def train(
+         self,
+         trajectories: list[Trajectory],
+         lr: float = 0.01,
+         n_epochs: int = 100,
+         reward_temperature: float = 1.0,
+     ):
+         """
+         Train the policy network on collected trajectories.
+
+         Uses reward-weighted cross-entropy:
+             loss = -sum(reward * log(P(action|state)))
+
+         Positive rewards encourage the action; negative rewards discourage it.
+         """
+         # Build training data
+         X = []
+         actions = []
+         rewards = []
+         available_masks = []
+
+         for traj in trajectories:
+             for step in traj.steps:
+                 if step.action == "COMMIT":
+                     continue
+                 if step.action not in self.channel_to_idx:
+                     continue
+
+                 # Extract features from the step's state
+                 top1_conf = step.differential_before[0]["confidence"] if step.differential_before else 0.5
+                 top2_conf = step.differential_before[1]["confidence"] if len(step.differential_before) > 1 else 0.0
+
+                 feat = self.featurize(
+                     dataset=traj.dataset,
+                     acquired=step.acquired_so_far,
+                     top1_confidence=top1_conf,
+                     top2_confidence=top2_conf,
+                     step_idx=step.step_idx,
+                 )
+                 X.append(feat)
+                 actions.append(self.channel_to_idx[step.action])
+
+                 # Reward shaping: normalized across trajectories below
+                 rewards.append(step.utility_reward)
+
+                 # Available-channels mask
+                 mask = np.zeros(self.output_dim, dtype=np.float32)
+                 for ch in step.available_channels:
+                     if ch in self.channel_to_idx:
+                         mask[self.channel_to_idx[ch]] = 1.0
+                 available_masks.append(mask)
+
+         if not X:
+             logger.warning("No training data available for policy network")
+             return
+
+         X = np.array(X)
+         actions = np.array(actions)
+         rewards = np.array(rewards)
+         available_masks = np.array(available_masks)
+
+         # Normalize rewards
+         if rewards.std() > 0:
+             rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
+
+         # Apply temperature
+         weights = np.exp(rewards * reward_temperature)
+         weights = weights / weights.sum() * len(weights)  # Normalize to mean=1
+
+         n = len(X)
+         logger.info(f"Training policy network on {n} state-action pairs for {n_epochs} epochs")
+
+         for epoch in range(n_epochs):
+             # Forward pass
+             h = np.maximum(0, X @ self.W1 + self.b1)
+             logits = h @ self.W2 + self.b2
+
+             # Mask unavailable channels
+             logits = logits + (1 - available_masks) * (-1e9)
+
+             # Softmax
+             logits_shifted = logits - logits.max(axis=1, keepdims=True)
+             exp_logits = np.exp(logits_shifted)
+             probs = exp_logits / (exp_logits.sum(axis=1, keepdims=True) + 1e-8)
+
+             # Cross-entropy loss (reward-weighted)
+             action_probs = probs[np.arange(n), actions]
+             loss = -np.mean(weights * np.log(action_probs + 1e-8))
+
+             # Backward pass (manual gradients)
+             # dL/d_logits = probs - one_hot(action), weighted by reward
+             grad_logits = probs.copy()
+             grad_logits[np.arange(n), actions] -= 1.0
+             grad_logits *= weights[:, np.newaxis] / n
+
+             # Gradients for W2, b2
+             grad_W2 = h.T @ grad_logits
375
+ grad_b2 = grad_logits.sum(axis=0)
376
+
377
+ # Gradient for W1, b1 (through ReLU)
378
+ grad_h = grad_logits @ self.W2.T
379
+ grad_h *= (h > 0).astype(np.float32) # ReLU derivative
380
+ grad_W1 = X.T @ grad_h
381
+ grad_b1 = grad_h.sum(axis=0)
382
+
383
+ # Update
384
+ self.W1 -= lr * grad_W1
385
+ self.b1 -= lr * grad_b1
386
+ self.W2 -= lr * grad_W2
387
+ self.b2 -= lr * grad_b2
388
+
389
+ if (epoch + 1) % 20 == 0:
390
+ # Compute accuracy
391
+ predicted = np.argmax(probs, axis=1)
392
+ accuracy = np.mean(predicted == actions)
393
+ logger.info(f" Epoch {epoch + 1}: loss={loss:.4f}, accuracy={accuracy:.3f}")
394
+
395
+ self.trained = True
396
+ logger.info("Policy network training complete")
397
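The backward pass above uses the standard softmax cross-entropy gradient (`probs - one_hot`), reweighted per sample. A self-contained sanity check of that gradient on synthetic data (a single linear layer stands in for the two-layer network; the data, dimensions, and learning rate are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 64, 8, 4                        # samples, feature dim, channels
X = rng.normal(size=(n, d))
actions = rng.integers(0, k, size=n)
weights = np.ones(n)                      # uniform reward weights for the check

W = np.zeros((d, k))
b = np.zeros(k)
losses = []
for _ in range(50):
    logits = X @ W + b
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    losses.append(-np.mean(weights * np.log(p[np.arange(n), actions] + 1e-8)))
    g = p.copy()
    g[np.arange(n), actions] -= 1.0       # dL/dlogits = p - one_hot(action)
    g *= weights[:, None] / n
    W -= 0.1 * (X.T @ g)
    b -= 0.1 * g.sum(axis=0)
```

With zero-initialized weights the first loss equals log(k) (uniform prediction), and gradient descent on this convex objective drives it down, which is a cheap way to verify the hand-derived gradient has the right sign and scale.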
+
+    def get_action(
+        self,
+        case: MedicalCase,
+        acquired: list[str],
+        differential: list[dict],
+        step_idx: int,
+    ) -> str:
+        """Select the best channel to request using the learned policy."""
+        available = [ch for ch in case.requestable_names if ch not in acquired]
+        if not available:
+            return "COMMIT"
+
+        top1_conf = differential[0]["confidence"] if differential else 0.5
+        top2_conf = differential[1]["confidence"] if len(differential) > 1 else 0.0
+
+        features = self.featurize(
+            dataset=case.dataset,
+            acquired=acquired,
+            top1_confidence=top1_conf,
+            top2_confidence=top2_conf,
+            step_idx=step_idx,
+        )
+
+        probs = self.predict(features, available)
+
+        if not probs:
+            return random.choice(available)
+
+        # Select highest probability channel
+        best_channel = max(probs, key=probs.get)
+        return best_channel
+
+    def save(self, path: Path):
+        """Save model weights."""
+        # Note: np.savez appends ".npz" when the suffix is missing, so pass a
+        # path that already ends in .npz to keep save/load paths consistent.
+        np.savez(
+            path,
+            W1=self.W1, b1=self.b1,
+            W2=self.W2, b2=self.b2,
+            channels=self.all_channels,
+            datasets=self.all_datasets,
+        )
+        logger.info(f"Saved policy network to {path}")
+
+    def load(self, path: Path):
+        """Load model weights."""
+        data = np.load(path, allow_pickle=True)
+        self.W1 = data["W1"]
+        self.b1 = data["b1"]
+        self.W2 = data["W2"]
+        self.b2 = data["b2"]
+        self.trained = True
+        logger.info(f"Loaded policy network from {path}")
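`save()`/`load()` round-trip both the weight matrices and the string metadata through a single `.npz` archive; a quick standalone sketch of that round trip (array shapes and channel names here are made up). Note that `np.savez` appends `".npz"` only when the suffix is missing, so using an explicit `.npz` path keeps the save and load paths identical:

```python
import os
import tempfile

import numpy as np

W1 = np.random.default_rng(1).normal(size=(6, 4)).astype(np.float32)
b1 = np.zeros(4, dtype=np.float32)

# Explicit .npz suffix so np.savez does not silently rename the file.
path = os.path.join(tempfile.mkdtemp(), "policy.npz")
np.savez(path, W1=W1, b1=b1, channels=np.array(["oct", "fundus"]))

data = np.load(path, allow_pickle=True)
```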
+
+
+ # ================================================================
+ # Approach 3: Self-Reflective Refinement
+ # ================================================================
+
+ class SelfReflectivePolicy:
+     """
+     The VLM critiques its own past failures and generates improved strategies.
+
+     Pipeline:
+     1. Collect cases where zero-shot acquisition was suboptimal
+        (the agent requested info that didn't help, or missed info that would have)
+     2. Show the VLM its own failure traces and ask it to generate
+        "acquisition rules" — structured if-then policies
+     3. Inject these self-generated rules into the system prompt
+     4. Re-run with the improved prompt
+
+     This is a form of self-play / self-improvement via reflection.
+     """
+
+     def __init__(self, client: BaseVLMClient, dataset_name: str):
+         self.client = client
+         self.dataset_name = dataset_name
+         self.rules: list[str] = []
+
+     def generate_rules_from_failures(
+         self,
+         trajectories: list[Trajectory],
+         n_failure_examples: int = 10,
+     ) -> list[str]:
+         """
+         Analyze failures and generate acquisition rules.
+
+         A "failure" is a case where:
+         - Agent requested a channel with zero or negative utility
+         - Agent didn't request a channel that would have helped
+         - Agent committed too early (final MRR << oracle MRR)
+         """
+         # Collect failure examples
+         failures = []
+
+         for traj in trajectories:
+             if traj.dataset != self.dataset_name:
+                 continue
+
+             # Type 1: Unhelpful acquisitions
+             for step in traj.steps:
+                 if step.action != "COMMIT" and step.utility_reward <= 0:
+                     failures.append({
+                         "type": "unhelpful_acquisition",
+                         "case_id": traj.case_id,
+                         "action": step.action,
+                         "uncertainty": step.uncertainty_text[:200],
+                         "utility_reward": step.utility_reward,
+                         "mrr_reward": step.reward,
+                         "cost": step.acquisition_cost,
+                         "available": step.available_channels,
+                     })
+
+             # Type 2: Premature commitment
+             if traj.final_mrr < traj.oracle_mrr - 0.2:
+                 failures.append({
+                     "type": "premature_commit",
+                     "case_id": traj.case_id,
+                     "acquired": [s.action for s in traj.steps if s.action != "COMMIT"],
+                     "final_mrr": traj.final_mrr,
+                     "oracle_mrr": traj.oracle_mrr,
+                     "gap": traj.oracle_mrr - traj.final_mrr,
+                 })
+
+         if not failures:
+             logger.info("No failures found — zero-shot policy may already be strong")
+             return []
+
+         # Sample failures
+         random.shuffle(failures)
+         sampled = failures[:n_failure_examples]
+
+         # Ask the VLM to analyze and generate rules
+         failure_text = json.dumps(sampled, indent=2, default=str)
+
+         prompt = f"""You are analyzing an AI medical diagnostic agent's acquisition failures on {self.dataset_name} cases.
+ The agent must decide what additional information to request (imaging modalities, clinical data, etc.) before making a diagnosis.
+
+ Here are examples of FAILED acquisition decisions:
+
+ {failure_text}
+
+ Based on these failures, generate 5-8 specific, actionable ACQUISITION RULES that would improve future decisions.
+
+ Format each rule as:
+ RULE N: IF [condition about the current state/uncertainty] THEN [specific acquisition action] BECAUSE [reasoning]
+
+ Rules should be specific to the {self.dataset_name} dataset and its available channels.
+ Focus on patterns across failures, not individual cases.
+ Be concrete — "request OCT when uncertain about subretinal fluid" is better than "request more information when uncertain."
+
+ Respond ONLY with the rules, no preamble."""
+
+         response = self.client.call_with_retry(
+             system_prompt="You are an expert in medical diagnostic AI systems.",
+             user_text=prompt,
+             images=None,
+             temperature=0.3,
+             max_tokens=2048,
+         )
+
+         # Parse rules
+         rules = []
+         for line in response.text.split("\n"):
+             line = line.strip()
+             if line.startswith("RULE") or line.startswith("Rule"):
+                 rules.append(line)
+             elif rules and line and not line.startswith("RULE"):
+                 # Continuation of previous rule
+                 rules[-1] += " " + line
+
+         self.rules = rules
+         logger.info(f"Generated {len(rules)} acquisition rules from {len(sampled)} failures")
+         for r in rules:
+             logger.info(f" {r[:120]}...")
+
+         return rules
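The parsing loop above starts a new rule on lines beginning with `RULE`/`Rule` and folds wrapped lines into the previous rule. The same logic in isolation, exercised on a hypothetical model response:

```python
def parse_rules(text: str) -> list[str]:
    """Collect 'RULE N: ...' lines, merging soft-wrapped continuation lines."""
    rules = []
    for line in text.split("\n"):
        line = line.strip()
        if line.startswith(("RULE", "Rule")):
            rules.append(line)
        elif rules and line:
            rules[-1] += " " + line  # continuation of the previous rule
    return rules

response = """RULE 1: IF top-2 diagnoses both involve retinal fluid THEN REQUEST oct
BECAUSE OCT resolves subretinal fluid directly.
RULE 2: IF top-1 confidence exceeds 0.8 THEN avoid expensive channels."""
rules = parse_rules(response)
```

One subtlety this makes visible: any preamble text before the first `RULE` line is silently dropped, which is exactly what the "Respond ONLY with the rules" instruction in the prompt guards against.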
+
+     def get_enhanced_system_prompt(self, base_prompt: str) -> str:
+         """
+         Inject learned rules into the system prompt.
+
+         This is the key mechanism: the VLM's behavior is modified
+         by giving it its own self-generated rules as instructions.
+         """
+         if not self.rules:
+             return base_prompt
+
+         rules_text = "\n".join(self.rules)
+         injection = f"""
+
+ LEARNED ACQUISITION STRATEGY (from analyzing past diagnostic cases):
+ The following rules have been learned from analyzing cases where acquisition
+ decisions were suboptimal. Apply these rules when deciding what information to request:
+
+ {rules_text}
+
+ Apply these rules in addition to your general diagnostic reasoning."""
+
+         return base_prompt + injection
+
+     def save_rules(self, path: Path):
+         """Save generated rules."""
+         with open(path, "w") as f:
+             json.dump({"dataset": self.dataset_name, "rules": self.rules}, f, indent=2)
+
+     def load_rules(self, path: Path):
+         """Load previously generated rules."""
+         with open(path) as f:
+             data = json.load(f)
+         self.rules = data["rules"]
+         logger.info(f"Loaded {len(self.rules)} rules for {self.dataset_name}")
prompts.py ADDED
@@ -0,0 +1,228 @@
+ """
+ Prompt templates for ActiveMedAgent.
+
+ Three semantically equivalent but lexically different variants (A/B/C)
+ for prompt sensitivity analysis.
+
+ Each prompt has:
+ - system_prompt: Sets the agent's role and reasoning format
+ - acquisition_prompt: Asks the agent to decide what to request next
+ - diagnosis_prompt: Asks the agent to commit to a ranked differential
+ """
+
+ # ============================================================
+ # Channel description formatters
+ # ============================================================
+
+ def format_available_channels(channels: dict, already_acquired: list[str]) -> str:
+     """Format the list of requestable channels for the prompt."""
+     lines = []
+     sortable = []
+     for name, info in channels.items():
+         if info.get("always_given"):
+             continue
+         if name in already_acquired:
+             continue
+         sortable.append((info.get("order", 999), info.get("cost", 0.0), name, info))
+     for _, _, name, info in sorted(sortable):
+         cost = float(info.get("cost", 0.0))
+         tier = info.get("tier", "unknown")
+         lines.append(
+             f" - [{name}]: {info['description']} "
+             f"(tier: {tier}, cost: ${cost:,.0f})"
+         )
+     if not lines:
+         return " (No additional information available to request.)"
+     return "\n".join(lines)
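On a toy channel dict, the formatter sorts by `(order, cost)` and emits one bullet per requestable channel. A condensed standalone copy of the same logic (the channel names, costs, and tiers here are made up for illustration):

```python
def fmt(channels: dict, acquired: list[str]) -> str:
    """Condensed version of format_available_channels for illustration."""
    rows = sorted(
        (info.get("order", 999), info.get("cost", 0.0), name, info)
        for name, info in channels.items()
        if not info.get("always_given") and name not in acquired
    )
    return "\n".join(
        f" - [{name}]: {info['description']} "
        f"(tier: {info.get('tier', 'unknown')}, cost: ${float(info.get('cost', 0.0)):,.0f})"
        for _, _, name, info in rows
    )

channels = {
    "fundus": {"description": "Color fundus photo", "order": 1, "cost": 50, "tier": "cheap"},
    "oct": {"description": "OCT B-scan", "order": 2, "cost": 1200, "tier": "imaging"},
}
out = fmt(channels, acquired=["fundus"])
```

Already-acquired channels are filtered out, and the `${cost:,.0f}` format renders thousands separators, so the single remaining line reads ` - [oct]: OCT B-scan (tier: imaging, cost: $1,200)`.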
+
+
+ def format_acquired_info(acquired_data: dict) -> str:
+     """Format all previously acquired information for context."""
+     if not acquired_data:
+         return "(No additional information acquired yet.)"
+     parts = []
+     for channel_name, content in acquired_data.items():
+         if content["type"] == "text":
+             parts.append(f"[{channel_name}]: {content['value']}")
+         elif content["type"] == "image":
+             parts.append(f"[{channel_name}]: (image provided)")
+     return "\n".join(parts)
+
+
+ # ============================================================
+ # Prompt Variant A — Clinical Framing
+ # ============================================================
+
+ VARIANT_A = {
+     "name": "clinical",
+
+     "system_prompt": """You are an experienced physician performing a diagnostic evaluation. \
+ You will be shown a medical image and possibly additional clinical information. \
+ Your goal is to arrive at the most accurate diagnosis by strategically requesting \
+ the most informative additional data.
+
+ You reason through cases using a structured clinical approach:
+ 1. OBSERVATION: Describe what you see in the available image(s) and data.
+ 2. DIFFERENTIAL: List your top 3-5 candidate diagnoses ranked by likelihood, with confidence estimates (0-1).
+ 3. UNCERTAINTY: Identify specifically what you are uncertain about — which diagnoses cannot be distinguished with current information and WHY.
+ 4. ACTION: You MUST request one additional piece of information. Choose the one that would best disambiguate your top differential diagnoses.
+
+ CRITICAL: You must ALWAYS use your remaining budget to request information. \
+ Do NOT commit early — additional information almost always improves diagnostic accuracy. \
+ Always respond in this exact structured format.""",
+
+     "acquisition_prompt": """You have {remaining_budget} request(s) remaining. You MUST use them.
+
+ Available information you can request:
+ {available_channels}
+
+ Previously acquired information:
+ {acquired_info}
+
+ Think carefully: which available channel would MOST help distinguish between your top diagnoses?
+
+ Respond in EXACTLY this format:
+ OBSERVATION: [What you observe from all currently available information]
+ DIFFERENTIAL: [Ranked list — format each as "N. DiagnosisName (confidence: X.XX)"]
+ UNCERTAINTY: [Which two diagnoses are hardest to tell apart, and what specific information would resolve it]
+ ACTION: REQUEST [channel_name]
+
+ IMPORTANT: Replace [channel_name] with exactly one of the available channel names listed above. \
+ You MUST request a channel — do not skip or commit early.""",
+
+     "diagnosis_prompt": """You strategically gathered the most relevant clinical information. \
+ Now provide your final diagnosis. Focus on the evidence you acquired — it was selected \
+ specifically to resolve diagnostic uncertainty.
+
+ Information you gathered:
+ {acquired_info}
+
+ Candidate diagnoses to rank:
+ {candidates}
+
+ Respond in the structured format:
+ OBSERVATION: [Synthesis of the key findings from your acquired information]
+ DIFFERENTIAL: [Ranked candidates — format each as "N. DiagnosisName (confidence: X.XX)"]
+ REASONING: [Key evidence from your acquired data supporting your top diagnosis and ruling out alternatives]""",
+ }
+
+
+ # ============================================================
+ # Prompt Variant B — Information-Theoretic Framing
+ # ============================================================
+
+ VARIANT_B = {
+     "name": "information_theoretic",
+
+     "system_prompt": """You are an AI diagnostic system analyzing medical data under \
+ conditions of incomplete information. You process available evidence and estimate which \
+ additional data sources would most reduce your diagnostic uncertainty.
+
+ Your reasoning follows a structured protocol:
+ 1. EVIDENCE: Catalog the findings from all available inputs.
+ 2. HYPOTHESES: Rank candidate diagnoses by posterior probability (0-1, must sum to ≤1).
+ 3. INFORMATION GAP: Identify the highest-uncertainty region of your hypothesis space.
+ 4. ACQUISITION: Select the data source with highest expected information gain, or finalize.
+
+ Always respond in this exact structured format. Be precise with probabilities.""",
+
+     "acquisition_prompt": """Analyze your current diagnostic uncertainty and determine the \
+ optimal next data acquisition. You have {remaining_budget} acquisition(s) remaining.
+
+ Requestable data sources:
+ {available_channels}
+
+ Previously acquired data:
+ {acquired_info}
+
+ Respond in the structured format:
+ EVIDENCE: [Findings extracted from all currently available data]
+ HYPOTHESES: [Ranked list — format each as "N. DiagnosisName (probability: X.XX)"]
+ INFORMATION GAP: [Which distinction between top hypotheses cannot be resolved with current data, and why]
+ ACQUISITION: REQUEST [channel_name] — [expected information gain explanation]
+
+ If your top hypothesis probability exceeds 0.8 and is well-separated from alternatives:
+ ACQUISITION: FINALIZE""",
+
+     "diagnosis_prompt": """All data acquisition is complete. Produce your final ranked \
+ hypothesis set.
+
+ Accumulated data:
+ {acquired_info}
+
+ Candidate diagnoses to rank:
+ {candidates}
+
+ Respond in the structured format:
+ EVIDENCE: [Complete synthesis of all acquired data]
+ HYPOTHESES: [Final ranked candidates — format each as "N. DiagnosisName (probability: X.XX)"]
+ JUSTIFICATION: [Evidence chain supporting top hypothesis; contradicting evidence for alternatives]""",
+ }
+
+
+ # ============================================================
+ # Prompt Variant C — Neutral/Minimal Framing
+ # ============================================================
+
+ VARIANT_C = {
+     "name": "neutral",
+
+     "system_prompt": """You are assisting with medical image analysis. Given a medical image \
+ and possibly additional information, identify the most likely diagnosis from a set of candidates.
+
+ You may request additional information before making your final decision. Structure your \
+ response as follows:
+ 1. FINDINGS: What you observe.
+ 2. RANKING: Candidate diagnoses ranked with confidence scores (0-1).
+ 3. GAPS: What you don't know that would help.
+ 4. DECISION: Request more info or commit.""",
+
+     "acquisition_prompt": """You may request one more piece of information. \
+ {remaining_budget} request(s) left.
+
+ Options:
+ {available_channels}
+
+ Information so far:
+ {acquired_info}
+
+ Respond:
+ FINDINGS: [Current observations]
+ RANKING: [Format: "N. DiagnosisName (confidence: X.XX)"]
+ GAPS: [What's missing]
+ DECISION: REQUEST [channel_name] — [reason]
+
+ Or if ready:
+ DECISION: COMMIT""",
+
+     "diagnosis_prompt": """Provide your final diagnosis ranking.
+
+ All information:
+ {acquired_info}
+
+ Candidates:
+ {candidates}
+
+ Respond:
+ FINDINGS: [Summary]
+ RANKING: [Format: "N. DiagnosisName (confidence: X.XX)"]
+ REASONING: [Brief justification]""",
+ }
+
+
+ # ============================================================
+ # Variant Registry
+ # ============================================================
+
+ PROMPT_VARIANTS = {
+     "A": VARIANT_A,
+     "B": VARIANT_B,
+     "C": VARIANT_C,
+ }
+
+
+ def get_prompt_variant(variant_id: str) -> dict:
+     """Retrieve a prompt variant by ID."""
+     if variant_id not in PROMPT_VARIANTS:
+         raise ValueError(f"Unknown prompt variant: {variant_id}. Choose from {list(PROMPT_VARIANTS.keys())}")
+     return PROMPT_VARIANTS[variant_id]
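Each variant's prompts are plain `str.format` templates, so filling the placeholders is a one-liner. A sketch with a minimal stand-in template (the real module's variants carry the full instructions, but the mechanics are identical):

```python
# Minimal stand-in for one variant; the real module defines VARIANT_A/B/C.
VARIANT = {
    "acquisition_prompt": """You have {remaining_budget} request(s) remaining.

Options:
{available_channels}

Acquired:
{acquired_info}""",
}

prompt = VARIANT["acquisition_prompt"].format(
    remaining_budget=2,
    available_channels=" - [oct]: OCT B-scan",
    acquired_info="(none yet)",
)
```

Because the templates use named placeholders, callers can build the three fields independently (e.g. via `format_available_channels` and `format_acquired_info`) and the variants stay interchangeable.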
reasoning_analysis.py ADDED
@@ -0,0 +1,612 @@
+ """
+ Reasoning Faithfulness & Acquisition Pattern Analysis.
+
+ Key analyses for ACL/EMNLP submission:
+
+ 1. Reasoning Faithfulness: Does the agent's stated reasoning match
+    actual information gain? When it says "I need X to distinguish
+    A from B", does X actually shift probability between A and B?
+
+ 2. Acquisition Order Patterns: What ordering strategies do different
+    models learn? Are they consistent? Do they match clinical guidelines?
+
+ 3. Error Analysis: When the agent commits early and is wrong, what
+    went wrong in the reasoning chain?
+
+ 4. Stopping Decision Quality: Are the agent's commit decisions well-timed?
+ """
+ import json
+ import logging
+ import re
+ from collections import Counter, defaultdict
+ from dataclasses import dataclass, field
+ from pathlib import Path
+
+ import numpy as np
+ from scipy.stats import spearmanr, kendalltau
+
+ from agent import AgentResult, AcquisitionStep
+ from datasets.base import MedicalCase
+ from information_gain import compute_entropy, compute_kl_divergence
+ from evaluation import evaluate_single_case, compute_reciprocal_rank
+
+ logger = logging.getLogger(__name__)
+
+
+ # ================================================================
+ # 1. Reasoning Faithfulness
+ # ================================================================
+
+ @dataclass
+ class FaithfulnessMetrics:
+     """Per-step reasoning faithfulness measurement."""
+     case_id: str
+     step: int
+     channel_requested: str
+     # What the agent said
+     stated_reasoning: str
+     stated_if_positive: str
+     stated_if_negative: str
+     # What actually happened
+     target_diagnosis_before: float  # Probability of stated target before
+     target_diagnosis_after: float  # Probability of stated target after
+     actual_shift: float  # Change in target probability
+     shift_direction_correct: bool  # Did it shift the way the agent predicted?
+     # Information metrics
+     entropy_before: float
+     entropy_after: float
+     actual_ig: float
+     predicted_useful: bool  # Agent thought this would help
+     actually_useful: bool  # IG > 0.05 bits
+
+
+ def compute_reasoning_faithfulness(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+ ) -> dict:
+     """
+     Measure whether the agent's stated reasoning matches what actually
+     happens when information is acquired.
+
+     For each acquisition step where the agent states expected_impact:
+     - Extract the target diagnosis (if_positive/if_negative)
+     - Compare probability of that diagnosis before and after acquisition
+     - Check if the shift matches the agent's prediction
+
+     Returns aggregate faithfulness metrics.
+     """
+     per_step_metrics = []
+     direction_correct_count = 0
+     useful_when_predicted = 0
+     total_with_impact = 0
+
+     for result, case in zip(results, cases):
+         for i, step in enumerate(result.steps):
+             if step.committed or not step.expected_impact:
+                 continue
+             if not step.differential:
+                 continue
+
+             total_with_impact += 1
+
+             # Current distribution (before receiving the info)
+             current_dist = {
+                 d.get("name", ""): d.get("confidence", 0)
+                 for d in step.differential
+             }
+
+             # Find next step's distribution (after receiving the info)
+             next_dist = None
+             if i + 1 < len(result.steps) and result.steps[i + 1].differential:
+                 next_dist = {
+                     d.get("name", ""): d.get("confidence", 0)
+                     for d in result.steps[i + 1].differential
+                 }
+
+             if next_dist is None:
+                 continue
+
+             # Get the target diagnosis from expected_impact
+             pos_target = step.expected_impact.get("if_positive", "")
+             neg_target = step.expected_impact.get("if_negative", "")
+
+             # Find probability of positive target before and after
+             pos_before = _fuzzy_lookup(current_dist, pos_target)
+             pos_after = _fuzzy_lookup(next_dist, pos_target)
+             neg_before = _fuzzy_lookup(current_dist, neg_target)
+             neg_after = _fuzzy_lookup(next_dist, neg_target)
+
+             # The agent predicted that this channel would help distinguish
+             # between pos_target and neg_target. Did the gap widen?
+             gap_before = abs(pos_before - neg_before)
+             gap_after = abs(pos_after - neg_after)
+             gap_widened = gap_after > gap_before
+
+             # Did probability shift in the stated direction?
+             actual_shift = pos_after - pos_before
+             shift_correct = gap_widened  # More discriminating = correct prediction
+
+             if shift_correct:
+                 direction_correct_count += 1
+
+             # Was the channel actually useful?
+             entropy_before = compute_entropy(current_dist)
+             entropy_after = compute_entropy(next_dist)
+             actual_ig = entropy_before - entropy_after
+             actually_useful = actual_ig > 0.05
+
+             if actually_useful:
+                 useful_when_predicted += 1
+
+             metrics = FaithfulnessMetrics(
+                 case_id=result.case_id,
+                 step=step.step,
+                 channel_requested=step.requested_channel or "",
+                 stated_reasoning=step.reasoning[:200],
+                 stated_if_positive=pos_target,
+                 stated_if_negative=neg_target,
+                 target_diagnosis_before=pos_before,
+                 target_diagnosis_after=pos_after,
+                 actual_shift=actual_shift,
+                 shift_direction_correct=shift_correct,
+                 entropy_before=entropy_before,
+                 entropy_after=entropy_after,
+                 actual_ig=actual_ig,
+                 predicted_useful=True,
+                 actually_useful=actually_useful,
+             )
+             per_step_metrics.append(metrics)
+
+     n = len(per_step_metrics)
+     return {
+         "n_steps_analyzed": n,
+         "n_with_expected_impact": total_with_impact,
+         "direction_accuracy": direction_correct_count / n if n > 0 else 0,
+         "utility_precision": useful_when_predicted / n if n > 0 else 0,
+         "mean_actual_ig": float(np.mean([m.actual_ig for m in per_step_metrics])) if per_step_metrics else 0,
+         "mean_absolute_shift": float(np.mean([abs(m.actual_shift) for m in per_step_metrics])) if per_step_metrics else 0,
+         "per_step_details": [
+             {
+                 "case_id": m.case_id,
+                 "step": m.step,
+                 "channel": m.channel_requested,
+                 "direction_correct": m.shift_direction_correct,
+                 "actual_ig": round(m.actual_ig, 4),
+                 "actually_useful": m.actually_useful,
+                 "stated_reasoning": m.stated_reasoning[:100],
+             }
+             for m in per_step_metrics
+         ],
+     }
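`compute_entropy` is imported from `information_gain` (not shown in this chunk); assuming the conventional definition — normalize the confidence dict, then Shannon entropy in bits — the `actual_ig` computed above is just the entropy drop between consecutive differentials. A sketch of that assumed definition on made-up distributions:

```python
import numpy as np

def entropy_bits(dist: dict[str, float]) -> float:
    """Shannon entropy in bits of a (possibly unnormalized) confidence dict."""
    p = np.array([v for v in dist.values() if v > 0], dtype=float)
    p = p / p.sum()
    return float(-(p * np.log2(p)).sum())

before = {"DME": 0.4, "AMD": 0.4, "CSR": 0.2}   # uncertain differential
after = {"DME": 0.85, "AMD": 0.10, "CSR": 0.05}  # after a useful acquisition
ig = entropy_bits(before) - entropy_bits(after)  # positive: acquisition helped
```

Under this convention the 0.05-bit threshold for `actually_useful` is comfortably cleared by any acquisition that meaningfully sharpens the differential, as in the toy example above.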
+
+
+ # ================================================================
+ # 2. Acquisition Order Patterns
+ # ================================================================
+
+ def analyze_acquisition_orders(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+     clinical_order: dict[str, list[str]] | None = None,
+ ) -> dict:
+     """
+     Analyze what acquisition ordering strategies the agent uses.
+
+     Returns:
+     - Most common first/second/third channel requests
+     - Order consistency across cases
+     - Correlation with clinical guideline order
+     - Correlation between acquisition order and case difficulty
+     """
+     from baselines import CLINICAL_GUIDELINE_ORDER
+     if clinical_order is None:
+         clinical_order = CLINICAL_GUIDELINE_ORDER
+
+     # Collect all acquisition sequences
+     sequences = []
+     first_requests = Counter()
+     second_requests = Counter()
+     full_sequences = Counter()
+
+     for result in results:
+         seq = tuple(result.acquired_channels)
+         sequences.append(seq)
+         full_sequences[seq] += 1
+
+         if len(seq) >= 1:
+             first_requests[seq[0]] += 1
+         if len(seq) >= 2:
+             second_requests[seq[1]] += 1
+
+     n = len(sequences)
+
+     # Consistency: what fraction of cases share the most common first request?
+     most_common_first = first_requests.most_common(1)
+     first_consistency = most_common_first[0][1] / n if most_common_first and n > 0 else 0
+
+     # Unique sequences
+     n_unique = len(full_sequences)
+
+     # Correlation with clinical guideline order
+     guideline_correlations = []
+     for result, case in zip(results, cases):
+         ds = case.dataset
+         if ds not in clinical_order:
+             continue
+
+         gl_order = clinical_order[ds]
+         agent_order = result.acquired_channels
+
+         if len(agent_order) < 2:
+             continue
+
+         # Compute rank correlation
+         # Map channels to their guideline rank
+         gl_ranks = {ch: i for i, ch in enumerate(gl_order)}
+         agent_ranks = {ch: i for i, ch in enumerate(agent_order)}
+
+         common = set(agent_order) & set(gl_order)
+         if len(common) < 2:
+             continue
+
+         gl_r = [gl_ranks.get(ch, len(gl_order)) for ch in agent_order if ch in common]
+         ag_r = list(range(len(gl_r)))
+
+         if len(gl_r) >= 2:
+             corr, pval = spearmanr(gl_r, ag_r)
+             if not np.isnan(corr):
+                 guideline_correlations.append(corr)
+
+     # Cost efficiency: does the agent prefer cheaper channels first?
+     cost_order_correlations = []
+     for result, case in zip(results, cases):
+         if len(result.acquired_channels) < 2:
+             continue
+
+         costs = [case.get_channel_cost(ch) for ch in result.acquired_channels]
+         positions = list(range(len(costs)))
+
+         if len(set(costs)) > 1:
+             corr, _ = spearmanr(costs, positions)
+             if not np.isnan(corr):
+                 cost_order_correlations.append(corr)
+
+     return {
+         "n_cases": n,
+         "n_unique_sequences": n_unique,
+         "sequence_entropy": _sequence_entropy(full_sequences, n),
+         "first_request_distribution": dict(first_requests.most_common()),
+         "first_request_consistency": first_consistency,
+         "second_request_distribution": dict(second_requests.most_common()),
+         "most_common_sequences": [
+             {"sequence": list(seq), "count": count}
+             for seq, count in full_sequences.most_common(5)
+         ],
+         "guideline_correlation": {
+             "mean": float(np.mean(guideline_correlations)) if guideline_correlations else None,
+             "std": float(np.std(guideline_correlations)) if guideline_correlations else None,
+             "n_comparable": len(guideline_correlations),
+         },
+         "cost_order_correlation": {
+             "mean": float(np.mean(cost_order_correlations)) if cost_order_correlations else None,
+             "std": float(np.std(cost_order_correlations)) if cost_order_correlations else None,
+             "interpretation": (
+                 "positive = cheaper first, negative = expensive first"
+             ),
+         },
+         "mean_channels_acquired": float(np.mean([len(s) for s in sequences])),
+     }
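The guideline correlation above reduces to a rank correlation between the agent's acquisition positions and the guideline ranks. A hand-rolled, tie-free Spearman (which matches `scipy.stats.spearmanr` on inputs like these; the channel names are made up) makes the sign convention concrete:

```python
import numpy as np

def spearman(x, y):
    """Spearman rho for tie-free sequences (rank correlation)."""
    rx = np.argsort(np.argsort(x)).astype(float)
    ry = np.argsort(np.argsort(y)).astype(float)
    rx -= rx.mean()
    ry -= ry.mean()
    return float((rx * ry).sum() / np.sqrt((rx ** 2).sum() * (ry ** 2).sum()))

guideline = ["history", "labs", "xray", "ct"]
gl_rank = {ch: i for i, ch in enumerate(guideline)}

agent_order = ["history", "labs", "ct"]      # guideline-concordant order
gl_r = [gl_rank[ch] for ch in agent_order]   # guideline ranks: [0, 1, 3]
ag_r = list(range(len(gl_r)))                # agent positions: [0, 1, 2]
rho = spearman(gl_r, ag_r)                   # +1: same relative ordering
```

A fully guideline-concordant sequence yields rho = +1 even when channels are skipped, and a fully reversed sequence yields -1, which is the interpretation the aggregate means rely on.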
+
+
+ # ================================================================
+ # 3. Error Analysis
+ # ================================================================
+
+ @dataclass
+ class ErrorCase:
+     """Detailed analysis of a single error case."""
+     case_id: str
+     ground_truth: str
+     agent_top1: str
+     agent_confidence: float
+     n_acquired: int
+     acquired_channels: list[str]
+     committed_early: bool
+     missed_channels: list[str]
+     error_type: str  # "overconfident_early", "wrong_after_all", "insufficient_info"
+     reasoning_chain: list[str]
+     entropy_at_commit: float
+     final_ig_trend: str  # "increasing", "decreasing", "plateau"
+
+
+ def analyze_errors(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+ ) -> dict:
+     """
+     Detailed error analysis: when and why the agent gets cases wrong.
+
+     Categorizes errors into:
+     1. Overconfident early commit — committed before gathering enough info
+     2. Wrong after all info — had all info but still wrong (reasoning failure)
+     3. Insufficient info — didn't have the right channels (missing key evidence)
+     """
+     errors = []
+     correct_count = 0
+     total = len(results)
+
+     for result, case in zip(results, cases):
+         if not result.final_ranking:
+             continue
+
+         top = result.final_ranking[0]
+         top_name = top.get("name", "").strip().lower()
+         gt = case.ground_truth.strip().lower()
+         correct = top_name == gt or top_name in gt or gt in top_name
+
+         if correct:
+             correct_count += 1
+             continue
+
+         # Classify error type
+         all_requestable = set(case.requestable_channels.keys())
+         acquired = set(result.acquired_channels)
+         missed = list(all_requestable - acquired)
+
+         if result.committed_early and missed:
+             error_type = "overconfident_early"
+         elif not missed:
+             error_type = "wrong_after_all"
+         else:
+             error_type = "insufficient_info"
362
+
363
+ # Extract reasoning chain
364
+ reasoning_chain = []
365
+ for step in result.steps:
366
+ if step.reasoning:
367
+ reasoning_chain.append(
368
+ f"Step {step.step}: {step.reasoning[:150]}"
369
+ )
370
+
371
+ # Entropy trend
372
+ entropies = [s.entropy for s in result.steps if s.entropy > 0]
373
+ if len(entropies) >= 2:
374
+ diffs = [entropies[i+1] - entropies[i] for i in range(len(entropies)-1)]
375
+ if all(d <= 0 for d in diffs):
376
+ trend = "decreasing"
377
+ elif all(d >= 0 for d in diffs):
378
+ trend = "increasing"
379
+ else:
380
+ trend = "non_monotonic"
381
+ else:
382
+ trend = "insufficient_data"
383
+
384
+ entropy_at_commit = entropies[-1] if entropies else 0.0
385
+
386
+ error = ErrorCase(
387
+ case_id=result.case_id,
388
+ ground_truth=case.ground_truth,
389
+ agent_top1=top.get("name", ""),
390
+ agent_confidence=top.get("confidence", 0),
391
+ n_acquired=len(result.acquired_channels),
392
+ acquired_channels=result.acquired_channels,
393
+ committed_early=result.committed_early,
394
+ missed_channels=missed,
395
+ error_type=error_type,
396
+ reasoning_chain=reasoning_chain,
397
+ entropy_at_commit=entropy_at_commit,
398
+ final_ig_trend=trend,
399
+ )
400
+ errors.append(error)
401
+
402
+ # Aggregate by error type
403
+ type_counts = Counter(e.error_type for e in errors)
404
+ n_errors = len(errors)
405
+
406
+ # Confidence distribution for errors vs correct
407
+ error_confidences = [e.agent_confidence for e in errors]
408
+
409
+ return {
410
+ "n_total": total,
411
+ "n_correct": correct_count,
412
+ "n_errors": n_errors,
413
+ "accuracy": correct_count / total if total > 0 else 0,
414
+ "error_type_distribution": {
415
+ "overconfident_early": type_counts.get("overconfident_early", 0),
416
+ "wrong_after_all": type_counts.get("wrong_after_all", 0),
417
+ "insufficient_info": type_counts.get("insufficient_info", 0),
418
+ },
419
+ "error_type_rates": {
420
+ etype: count / n_errors if n_errors > 0 else 0
421
+ for etype, count in type_counts.items()
422
+ },
423
+ "mean_error_confidence": float(np.mean(error_confidences)) if error_confidences else 0,
424
+ "mean_error_channels_acquired": float(np.mean([e.n_acquired for e in errors])) if errors else 0,
425
+ "entropy_at_commit": {
426
+ "mean": float(np.mean([e.entropy_at_commit for e in errors])) if errors else 0,
427
+ "std": float(np.std([e.entropy_at_commit for e in errors])) if errors else 0,
428
+ },
429
+ "ig_trend_distribution": dict(Counter(e.final_ig_trend for e in errors)),
430
+ "per_case_errors": [
431
+ {
432
+ "case_id": e.case_id,
433
+ "ground_truth": e.ground_truth,
434
+ "predicted": e.agent_top1,
435
+ "confidence": e.agent_confidence,
436
+ "error_type": e.error_type,
437
+ "n_acquired": e.n_acquired,
438
+ "missed": e.missed_channels,
439
+ "committed_early": e.committed_early,
440
+ "entropy_at_commit": round(e.entropy_at_commit, 3),
441
+ }
442
+ for e in errors
443
+ ],
444
+ }
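The three-way error taxonomy is a small decision rule over which channels were missed and whether the agent stopped early. A minimal sketch (channel names are hypothetical):

```python
# Sketch of the error-type classification used above: the error label
# depends only on early commitment and which channels went unrequested.
def classify_error(committed_early, acquired, requestable):
    missed = set(requestable) - set(acquired)
    if committed_early and missed:
        return "overconfident_early"   # stopped before seeing everything
    if not missed:
        return "wrong_after_all"       # saw everything, still wrong
    return "insufficient_info"         # budget/loop ended with channels unseen

print(classify_error(True, {"history"}, {"history", "ct_scan"}))
# overconfident_early
```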
+
+
+ # ================================================================
+ # 4. Stopping Decision Quality
+ # ================================================================
+
+ def analyze_stopping_decisions(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+ ) -> dict:
+     """
+     Analyze whether the agent's commit decisions are well-timed.
+
+     Compares:
+     - Cases where agent committed early and was correct (good early stop)
+     - Cases where agent committed early and was wrong (premature stop)
+     - Cases that used all channels (necessary thoroughness vs wasted budget)
+     """
+     early_correct = []
+     early_wrong = []
+     full_correct = []
+     full_wrong = []
+
+     for result, case in zip(results, cases):
+         if not result.final_ranking:
+             continue
+
+         top = result.final_ranking[0]
+         top_name = top.get("name", "").strip().lower()
+         gt = case.ground_truth.strip().lower()
+         correct = top_name == gt or top_name in gt or gt in top_name
+         n_requestable = len(case.requestable_channels)
+         n_acquired = len(result.acquired_channels)
+
+         entry = {
+             "case_id": result.case_id,
+             "confidence": top.get("confidence", 0),
+             "n_acquired": n_acquired,
+             "n_available": n_requestable,
+             "fraction_used": n_acquired / n_requestable if n_requestable > 0 else 1,
+             "cost": result.acquisition_cost,
+         }
+
+         if result.committed_early:
+             if correct:
+                 early_correct.append(entry)
+             else:
+                 early_wrong.append(entry)
+         else:
+             if correct:
+                 full_correct.append(entry)
+             else:
+                 full_wrong.append(entry)
+
+     def _summarize(entries):
+         if not entries:
+             return {"count": 0}
+         return {
+             "count": len(entries),
+             "mean_confidence": float(np.mean([e["confidence"] for e in entries])),
+             "mean_channels": float(np.mean([e["n_acquired"] for e in entries])),
+             "mean_fraction_used": float(np.mean([e["fraction_used"] for e in entries])),
+             "mean_cost": float(np.mean([e["cost"] for e in entries])),
+         }
+
+     total = len(results)
+     early_rate = (len(early_correct) + len(early_wrong)) / total if total > 0 else 0
+     early_precision = (
+         len(early_correct) / (len(early_correct) + len(early_wrong))
+         if (len(early_correct) + len(early_wrong)) > 0 else 0
+     )
+
+     return {
+         "n_total": total,
+         "early_commit_rate": early_rate,
+         "early_commit_precision": early_precision,
+         "early_correct": _summarize(early_correct),
+         "early_wrong": _summarize(early_wrong),
+         "full_correct": _summarize(full_correct),
+         "full_wrong": _summarize(full_wrong),
+         "cost_savings_from_early_commit": {
+             "mean_cost_early": float(np.mean(
+                 [e["cost"] for e in early_correct + early_wrong]
+             )) if (early_correct or early_wrong) else 0,
+             "mean_cost_full": float(np.mean(
+                 [e["cost"] for e in full_correct + full_wrong]
+             )) if (full_correct or full_wrong) else 0,
+         },
+     }
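The two headline numbers from this analysis — early-commit rate and early-commit precision — can be sketched on their own, assuming only a list of (committed_early, correct) pairs:

```python
# Sketch of the stopping-quality summary: how often the agent stops
# early, and how often those early stops are correct.
def early_commit_stats(outcomes):
    """outcomes: list of (committed_early: bool, correct: bool) per case."""
    early = [correct for committed, correct in outcomes if committed]
    rate = len(early) / len(outcomes) if outcomes else 0.0
    precision = sum(early) / len(early) if early else 0.0
    return rate, precision

# Four hypothetical cases: two early commits, one of which was right.
rate, prec = early_commit_stats(
    [(True, True), (True, False), (False, True), (False, False)]
)
print(rate, prec)  # 0.5 0.5
```

High rate with low precision indicates the commit criterion is too aggressive; low rate with high precision suggests budget is being wasted on cases already solved.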
+
+
+ # ================================================================
+ # Full Analysis Pipeline
+ # ================================================================
+
+ def run_reasoning_analysis(
+     results: list[AgentResult],
+     cases: list[MedicalCase],
+     save_dir: Path | None = None,
+ ) -> dict:
+     """Run all reasoning analyses and return combined results."""
+     logger.info("Running reasoning analysis...")
+
+     faithfulness = compute_reasoning_faithfulness(results, cases)
+     logger.info(
+         f"  Faithfulness: direction_accuracy={faithfulness['direction_accuracy']:.3f}, "
+         f"utility_precision={faithfulness['utility_precision']:.3f}"
+     )
+
+     orders = analyze_acquisition_orders(results, cases)
+     logger.info(
+         f"  Order patterns: {orders['n_unique_sequences']} unique sequences, "
+         f"first_consistency={orders['first_request_consistency']:.3f}"
+     )
+
+     errors = analyze_errors(results, cases)
+     logger.info(
+         f"  Errors: {errors['n_errors']}/{errors['n_total']} "
+         f"({errors['error_type_distribution']})"
+     )
+
+     stopping = analyze_stopping_decisions(results, cases)
+     logger.info(
+         f"  Stopping: early_rate={stopping['early_commit_rate']:.3f}, "
+         f"early_precision={stopping['early_commit_precision']:.3f}"
+     )
+
+     output = {
+         "reasoning_faithfulness": faithfulness,
+         "acquisition_orders": orders,
+         "error_analysis": errors,
+         "stopping_decisions": stopping,
+     }
+
+     if save_dir:
+         save_dir.mkdir(parents=True, exist_ok=True)
+         # Round-trip through JSON to drop non-serializable details for a compact save
+         compact = json.loads(json.dumps(output, default=str))
+         with open(save_dir / "reasoning_analysis.json", "w") as f:
+             json.dump(compact, f, indent=2)
+         logger.info(f"  Saved to {save_dir / 'reasoning_analysis.json'}")
+
+     return output
+
+
+ # ================================================================
+ # Helpers
+ # ================================================================
+
+ def _fuzzy_lookup(dist: dict, target: str) -> float:
+     """Look up a diagnosis probability with fuzzy name matching."""
+     target_lower = target.lower().strip()
+     for name, prob in dist.items():
+         if target_lower in name.lower() or name.lower() in target_lower:
+             return prob
+     return 0.0
+
+
+ def _sequence_entropy(counter: Counter, total: int) -> float:
+     """Shannon entropy of sequence distribution (diversity measure)."""
+     if total == 0:
+         return 0.0
+     entropy = 0.0
+     for count in counter.values():
+         p = count / total
+         if p > 0:
+             entropy -= p * np.log2(p)
+     return float(entropy)
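The `_sequence_entropy` helper is a plain Shannon entropy over the empirical distribution of acquisition sequences. A self-contained sketch (using `math.log2` in place of `np.log2`; the `Counter` contents are made up):

```python
# Shannon entropy (in bits) of a distribution of acquisition sequences.
import math
from collections import Counter

def sequence_entropy(counter, total):
    if total == 0:
        return 0.0
    return -sum(
        (c / total) * math.log2(c / total)
        for c in counter.values()
        if c > 0
    )

# Two orderings, each seen twice: a uniform two-way distribution.
seqs = Counter({("history", "labs"): 2, ("labs", "history"): 2})
print(sequence_entropy(seqs, 4))  # 1.0 -> two equally likely orders
```

Entropy 0 means the agent always acquires in the same order; higher values mean its ordering adapts (or is inconsistent) across cases.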
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio>=6.0.0
+ numpy
+ Pillow
+ scipy
+ openai
+ anthropic
+ together
+ python-dotenv
tools.py ADDED
@@ -0,0 +1,185 @@
+ """
+ Tool definitions for the ActiveMedAgent tool-use architecture.
+
+ Instead of parsing free-form text with regex, the agent makes structured
+ tool calls through the VLM's native function-calling interface. This:
+ 1. Eliminates brittle parsing heuristics
+ 2. Makes the agent a genuine tool-using system (not text completion + post-hoc extraction)
+ 3. Provides formally verifiable action traces
+ 4. Enables grounded information-theoretic analysis via structured probability reports
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any
+
+
+ # ============================================================
+ # Tool Call Data Structures
+ # ============================================================
+
+ @dataclass
+ class ToolCall:
+     """A single tool call extracted from a VLM response."""
+     tool_name: str
+     arguments: dict[str, Any]
+     call_id: str = ""
+
+
+ @dataclass
+ class ToolResult:
+     """Result returned to the VLM after executing a tool."""
+     call_id: str
+     content: str
+     images: list[str] | None = None  # base64-encoded images
+
+
+ # ============================================================
+ # Tool Definitions (canonical format — translated per backend)
+ # ============================================================
+
+ AGENT_TOOLS = [
+     {
+         "name": "request_information",
+         "description": (
+             "Request one additional information channel to reduce diagnostic uncertainty "
+             "while avoiding unnecessary resource use. Call this when you need more data "
+             "to distinguish between competing diagnoses and the expected benefit justifies "
+             "the channel's cost. "
+             "You must specify which channel to acquire and why it would resolve your "
+             "current uncertainty."
+         ),
+         "parameters": {
+             "type": "object",
+             "properties": {
+                 "channel_name": {
+                     "type": "string",
+                     "description": "Exact name of the channel to request (from the available list)",
+                 },
+                 "reasoning": {
+                     "type": "string",
+                     "description": "Why this channel best resolves your current diagnostic uncertainty",
+                 },
+                 "current_differential": {
+                     "type": "array",
+                     "description": "Your current ranked differential diagnosis with calibrated probabilities (must sum to 1.0)",
+                     "items": {
+                         "type": "object",
+                         "properties": {
+                             "name": {"type": "string", "description": "Diagnosis name"},
+                             "probability": {
+                                 "type": "number",
+                                 "description": "Posterior probability (0-1), all must sum to 1.0",
+                             },
+                         },
+                         "required": ["name", "probability"],
+                     },
+                 },
+                 "expected_impact": {
+                     "type": "object",
+                     "description": "What you expect this information to reveal",
+                     "properties": {
+                         "if_positive": {
+                             "type": "string",
+                             "description": "Which diagnosis becomes most likely if this channel shows positive/abnormal findings",
+                         },
+                         "if_negative": {
+                             "type": "string",
+                             "description": "Which diagnosis becomes most likely if this channel shows negative/normal findings",
+                         },
+                     },
+                     "required": ["if_positive", "if_negative"],
+                 },
+             },
+             "required": ["channel_name", "reasoning", "current_differential", "expected_impact"],
+         },
+     },
+     {
+         "name": "commit_diagnosis",
+         "description": (
+             "Commit to a final ranked diagnosis. Call this ONLY when you have exhausted "
+             "the clinically useful information OR when your top diagnosis has probability "
+             ">= 0.85 and is well-separated from alternatives. Prefer committing when "
+             "remaining channels are unlikely to change management enough to justify cost."
+         ),
+         "parameters": {
+             "type": "object",
+             "properties": {
+                 "ranked_diagnoses": {
+                     "type": "array",
+                     "description": "Final ranked list of all candidate diagnoses with calibrated probabilities summing to 1.0",
+                     "items": {
+                         "type": "object",
+                         "properties": {
+                             "name": {"type": "string"},
+                             "confidence": {
+                                 "type": "number",
+                                 "description": "Posterior probability (0-1)",
+                             },
+                             "key_evidence": {
+                                 "type": "string",
+                                 "description": "Most important evidence supporting or refuting this diagnosis",
+                             },
+                         },
+                         "required": ["name", "confidence", "key_evidence"],
+                     },
+                 },
+                 "reasoning": {
+                     "type": "string",
+                     "description": "Final diagnostic reasoning chain",
+                 },
+             },
+             "required": ["ranked_diagnoses", "reasoning"],
+         },
+     },
+ ]
+
+
+ # ============================================================
+ # Schema Translation
+ # ============================================================
+
+ def to_openai_tools(tools: list[dict] | None = None) -> list[dict]:
+     """Convert canonical tool definitions to OpenAI function-calling format."""
+     if tools is None:
+         tools = AGENT_TOOLS
+     return [
+         {
+             "type": "function",
+             "function": {
+                 "name": t["name"],
+                 "description": t["description"],
+                 "parameters": t["parameters"],
+             },
+         }
+         for t in tools
+     ]
+
+
+ def to_anthropic_tools(tools: list[dict] | None = None) -> list[dict]:
+     """Convert canonical tool definitions to Anthropic tool-use format."""
+     if tools is None:
+         tools = AGENT_TOOLS
+     return [
+         {
+             "name": t["name"],
+             "description": t["description"],
+             "input_schema": t["parameters"],
+         }
+         for t in tools
+     ]
+
+
+ def constrain_tools_for_step(budget_remaining: int, allow_commit: bool = True) -> list[dict]:
+     """
+     Return the appropriate tool subset for the current step.
+
+     - If budget > 0 and channels available: both request_information and commit_diagnosis
+     - If budget == 0 or forced final: only commit_diagnosis
+     """
+     if budget_remaining <= 0:
+         return [t for t in AGENT_TOOLS if t["name"] == "commit_diagnosis"]
+     tools = list(AGENT_TOOLS)
+     if not allow_commit:
+         tools = [t for t in tools if t["name"] != "commit_diagnosis"]
+     return tools
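The two translators differ only in envelope shape: OpenAI wraps the schema under `function.parameters`, while Anthropic puts it at the top level as `input_schema`. A sketch with a trimmed-down canonical tool (the schema here is abbreviated, not the full definition above):

```python
# Sketch of the two wire formats produced from one canonical tool.
canonical = {
    "name": "request_information",
    "description": "Request one additional information channel.",
    "parameters": {"type": "object", "properties": {}, "required": []},
}

# OpenAI function-calling envelope: schema nested under "function".
openai_tool = {
    "type": "function",
    "function": {
        "name": canonical["name"],
        "description": canonical["description"],
        "parameters": canonical["parameters"],
    },
}

# Anthropic tool-use envelope: same schema, renamed to "input_schema".
anthropic_tool = {
    "name": canonical["name"],
    "description": canonical["description"],
    "input_schema": canonical["parameters"],
}

print(openai_tool["function"]["name"], anthropic_tool["name"])
# request_information request_information
```

Keeping one canonical definition and translating at the edge means the agent loop never branches on backend.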
trajectory.py ADDED
@@ -0,0 +1,338 @@
+ """
+ Trajectory Collection for ActiveMedAgent.
+
+ Phase 1 of the training pipeline:
+ 1. Run zero-shot agent on all cases
+ 2. Record full (state, action, reward) trajectories
+ 3. Compute per-step rewards: did the acquisition improve the diagnosis?
+ 4. Save trajectory dataset for Phase 2 policy learning
+
+ Each trajectory step records:
+ - state: current uncertainty, differential, acquired channels so far
+ - action: which channel was requested
+ - reward: MRR improvement after receiving the requested info
+ - outcome: final diagnosis correctness
+ """
+ import json
+ import logging
+ import random
+ from dataclasses import dataclass, field, asdict
+ from pathlib import Path
+
+ import numpy as np
+ from tqdm import tqdm
+
+ import config
+ from api_client import BaseVLMClient, create_client
+ from agent import ActiveMedAgent, AgentResult
+ from datasets.base import MedicalCase
+ from evaluation import compute_reciprocal_rank
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class TrajectoryStep:
+     """A single step in an acquisition trajectory."""
+     step_idx: int
+     # State representation
+     acquired_so_far: list[str]
+     available_channels: list[str]
+     uncertainty_text: str
+     differential_before: list[dict]  # Ranking before this acquisition
+     mrr_before: float
+
+     # Action
+     action: str  # Channel name requested (or "COMMIT")
+
+     # Outcome (computed after the action)
+     differential_after: list[dict]  # Ranking after receiving the info
+     mrr_after: float
+     reward: float  # MRR improvement: mrr_after - mrr_before
+     acquisition_cost: float = 0.0
+     normalized_cost: float = 0.0
+     utility_reward: float = 0.0  # Cost-aware reward used for policy learning
+     diagnosis_changed: bool = False  # Did top-1 change?
+     diagnosis_improved: bool = False  # Did it change to the correct answer?
+
+
+ @dataclass
+ class Trajectory:
+     """Complete trajectory for one case."""
+     case_id: str
+     dataset: str
+     ground_truth: str
+     candidates: list[str]
+     steps: list[TrajectoryStep] = field(default_factory=list)
+     passive_mrr: float = 0.0
+     oracle_mrr: float = 0.0
+     final_mrr: float = 0.0
+     total_reward: float = 0.0
+     total_utility_reward: float = 0.0
+     success: bool = False  # Did the agent get top-1 correct?
+
+
+ class TrajectoryCollector:
+     """
+     Collect acquisition trajectories with per-step rewards.
+
+     Unlike the basic agent.diagnose(), this collector runs the agent
+     step-by-step, evaluating the diagnosis after EACH acquisition
+     to compute fine-grained reward signals.
+
+     Uses the tool-use agent architecture: runs the full agent for
+     acquisition decisions, then evaluates intermediate states via
+     the agent's get_diagnosis_at_state() helper.
+     """
+
+     def __init__(
+         self,
+         client: BaseVLMClient,
+         prompt_variant: str = "A",
+         budget: int = 3,
+     ):
+         self.client = client
+         self.prompt_variant = prompt_variant
+         self.budget = budget
+
+     def collect_trajectory(self, case: MedicalCase) -> Trajectory:
+         """
+         Collect a full trajectory with per-step rewards for one case.
+
+         Strategy:
+         1. Get passive baseline (image-only diagnosis)
+         2. Get oracle ceiling (all-info diagnosis)
+         3. Run the active agent and record its decisions
+         4. For each acquisition step, evaluate the intermediate
+            diagnosis to compute per-step MRR reward
+         """
+         traj = Trajectory(
+             case_id=case.case_id,
+             dataset=case.dataset,
+             ground_truth=case.ground_truth,
+             candidates=case.candidates,
+         )
+
+         # ---- Evaluation agent (budget=0, just for scoring) ----
+         eval_agent = ActiveMedAgent(
+             self.client, self.prompt_variant, budget=0
+         )
+
+         # ---- Get passive baseline (MRR with no acquisition) ----
+         passive_result = eval_agent.diagnose_passive(case)
+         passive_mrr = compute_reciprocal_rank(
+             passive_result.final_ranking, case.ground_truth, case.candidates
+         )
+         traj.passive_mrr = passive_mrr
+
+         # ---- Get oracle ceiling (MRR with all info) ----
+         oracle_result = eval_agent.diagnose_oracle(case)
+         oracle_mrr = compute_reciprocal_rank(
+             oracle_result.final_ranking, case.ground_truth, case.candidates
+         )
+         traj.oracle_mrr = oracle_mrr
+
+         # ---- Run the active agent to get its acquisition decisions ----
+         active_agent = ActiveMedAgent(
+             self.client, self.prompt_variant, budget=self.budget
+         )
+         active_result = active_agent.diagnose(case)
+
+         # ---- Evaluate each intermediate state ----
+         current_mrr = passive_mrr
+         current_ranking = passive_result.final_ranking
+         acquired_so_far = []
+
+         for step_idx, step in enumerate(active_result.steps):
+             if step.committed:
+                 # Agent committed early — record and stop
+                 traj_step = TrajectoryStep(
+                     step_idx=step_idx,
+                     acquired_so_far=list(acquired_so_far),
+                     available_channels=[
+                         n for n in case.requestable_names
+                         if n not in acquired_so_far
+                     ],
+                     uncertainty_text=step.reasoning or "",
+                     differential_before=current_ranking,
+                     mrr_before=current_mrr,
+                     action="COMMIT",
+                     differential_after=current_ranking,
+                     mrr_after=current_mrr,
+                     reward=0.0,
+                     acquisition_cost=0.0,
+                     normalized_cost=0.0,
+                     utility_reward=0.0,
+                     diagnosis_changed=False,
+                     diagnosis_improved=False,
+                 )
+                 traj.steps.append(traj_step)
+                 break
+
+             channel = step.requested_channel
+             if not channel:
+                 continue
+
+             available = [
+                 n for n in case.requestable_names
+                 if n not in acquired_so_far
+             ]
+
+             # Record the state BEFORE this acquisition
+             before_ranking = current_ranking
+             before_mrr = current_mrr
+
+             # Execute the acquisition
+             acquired_so_far.append(channel)
+
+             # Evaluate the diagnosis AFTER this acquisition
+             after_ranking, _ = eval_agent.get_diagnosis_at_state(
+                 case, list(acquired_so_far)
+             )
+             after_mrr = compute_reciprocal_rank(
+                 after_ranking, case.ground_truth, case.candidates
+             )
+
+             # Compute reward
+             reward = after_mrr - before_mrr
+             channel_cost = case.get_channel_cost(channel)
+             max_requestable_cost = max(case.get_max_requestable_cost(), 1.0)
+             normalized_cost = channel_cost / max_requestable_cost
+             utility_reward = reward - (
+                 config.COST_PENALTY_LAMBDA * normalized_cost
+             )
+
+             # Did diagnosis change?
+             top1_before = before_ranking[0]["name"] if before_ranking else ""
+             top1_after = after_ranking[0]["name"] if after_ranking else ""
+             diagnosis_changed = top1_before.lower() != top1_after.lower()
+
+             gt_lower = case.ground_truth.lower()
+             diagnosis_improved = diagnosis_changed and (
+                 gt_lower in top1_after.lower()
+                 or top1_after.lower() in gt_lower
+             )
+
+             traj_step = TrajectoryStep(
+                 step_idx=step_idx,
+                 acquired_so_far=list(acquired_so_far[:-1]),
+                 available_channels=available,
+                 uncertainty_text=step.reasoning or "",
+                 differential_before=before_ranking,
+                 mrr_before=before_mrr,
+                 action=channel,
+                 differential_after=after_ranking,
+                 mrr_after=after_mrr,
+                 reward=reward,
+                 acquisition_cost=channel_cost,
+                 normalized_cost=normalized_cost,
+                 utility_reward=utility_reward,
+                 diagnosis_changed=diagnosis_changed,
+                 diagnosis_improved=diagnosis_improved,
+             )
+             traj.steps.append(traj_step)
+
+             # Update state for next step
+             current_mrr = after_mrr
+             current_ranking = after_ranking
+
+         # ---- Finalize trajectory ----
+         traj.final_mrr = current_mrr
+         traj.total_reward = sum(s.reward for s in traj.steps)
+         traj.total_utility_reward = sum(s.utility_reward for s in traj.steps)
+         traj.success = (current_mrr == 1.0)
+
+         return traj
+
+
+     def collect_dataset(
+         self,
+         cases: list[MedicalCase],
+         max_cases: int | None = None,
+         save_path: Path | None = None,
+     ) -> list[Trajectory]:
+         """Collect trajectories for all cases."""
+         if max_cases:
+             cases = cases[:max_cases]
+
+         trajectories = []
+         for case in tqdm(cases, desc="Collecting trajectories", ncols=80):
+             try:
+                 traj = self.collect_trajectory(case)
+                 trajectories.append(traj)
+             except Exception as e:
+                 logger.error(f"Failed on {case.case_id}: {e}")
+                 continue
+
+         # Save
+         if save_path:
+             save_path = Path(save_path)
+             save_path.parent.mkdir(parents=True, exist_ok=True)
+             with open(save_path, "w") as f:
+                 json.dump(
+                     [asdict(t) for t in trajectories],
+                     f, indent=2, default=str,
+                 )
+             logger.info(f"Saved {len(trajectories)} trajectories to {save_path}")
+
+         # Report statistics
+         self._report_stats(trajectories)
+
+         return trajectories
+
+
+     def _report_stats(self, trajectories: list[Trajectory]):
+         """Log summary statistics of collected trajectories."""
+         n = len(trajectories)
+         if n == 0:
+             return
+
+         logger.info(f"\n{'=' * 50}")
+         logger.info(f"Trajectory Collection Summary (n={n})")
+         logger.info(f"{'=' * 50}")
+
+         success_rate = np.mean([t.success for t in trajectories])
+         avg_steps = np.mean([len(t.steps) for t in trajectories])
+         avg_reward = np.mean([t.total_reward for t in trajectories])
+         avg_utility = np.mean([t.total_utility_reward for t in trajectories])
+         avg_passive_mrr = np.mean([t.passive_mrr for t in trajectories])
+         avg_final_mrr = np.mean([t.final_mrr for t in trajectories])
+         avg_oracle_mrr = np.mean([t.oracle_mrr for t in trajectories])
+
+         logger.info(f"  Success rate: {success_rate:.3f}")
+         logger.info(f"  Avg steps taken: {avg_steps:.1f}")
+         logger.info(f"  Avg total reward: {avg_reward:.3f}")
+         logger.info(f"  Avg utility reward: {avg_utility:.3f}")
+         logger.info(
+             f"  MRR: passive={avg_passive_mrr:.3f} -> "
+             f"active={avg_final_mrr:.3f} -> oracle={avg_oracle_mrr:.3f}"
+         )
+
+         # Per-action reward statistics
+         all_steps = [
+             s for t in trajectories for s in t.steps
+             if s.action != "COMMIT"
+         ]
+         if all_steps:
+             action_rewards = {}
+             for s in all_steps:
+                 action_rewards.setdefault(s.action, []).append(s.utility_reward)
+
+             logger.info("\n  Per-channel utility statistics:")
+             for action, rewards in sorted(
+                 action_rewards.items(), key=lambda x: -np.mean(x[1])
+             ):
+                 logger.info(
+                     f"    {action:<25} mean_utility={np.mean(rewards):+.3f} "
+                     f"n={len(rewards)} "
+                     f"positive_rate={np.mean([r > 0 for r in rewards]):.2f}"
+                 )
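The per-step signal recorded in each `TrajectoryStep` is the MRR improvement minus a cost penalty on the normalized channel cost. A self-contained sketch (the `COST_PENALTY_LAMBDA` value of 0.5 here is illustrative only; the real value lives in `config`):

```python
# Sketch of the cost-aware per-step reward used for policy learning:
# utility = (MRR after - MRR before) - lambda * (cost / max cost).
COST_PENALTY_LAMBDA = 0.5  # illustrative; the actual knob is config.COST_PENALTY_LAMBDA

def utility_reward(mrr_before, mrr_after, channel_cost, max_cost):
    normalized = channel_cost / max(max_cost, 1.0)  # clamp denominator like the collector
    return (mrr_after - mrr_before) - COST_PENALTY_LAMBDA * normalized

# A channel that moves the true diagnosis from rank 2 (MRR 0.5) to rank 1
# (MRR 1.0) at 20% of the max cost: 0.5 gain minus 0.1 penalty.
print(utility_reward(0.5, 1.0, 2.0, 10.0))
```

An acquisition that changes nothing (MRR delta of zero) gets a strictly negative utility, which is exactly the pressure that teaches the policy to skip uninformative channels.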