# yuxbox's picture
# Upload folder using huggingface_hub
# e794581 verified
"""
Interactive Demo for ActiveMedAgent.
A Gradio-based UI that lets users:
- Select from pre-built demo cases OR enter a custom clinical scenario
- Upload medical images (optional)
- Watch the agent's step-by-step reasoning, information acquisition, and
entropy reduction in real time
- No budget constraint β€” the agent acquires as many channels as it needs
Usage:
python app.py
python app.py --backend openai
python app.py --backend anthropic --port 7861
"""
import argparse
import json
import logging
import sys
import time
import math
from pathlib import Path
from dataclasses import dataclass, field
import numpy as np
import gradio as gr
from PIL import Image
sys.path.insert(0, str(Path(__file__).resolve().parent))
import config
from api_client import create_client, encode_image_to_base64, encode_pil_image_to_base64
from agent import ActiveMedAgent, AgentResult, AcquisitionStep, SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_CONDENSED, SYSTEM_PROMPT_FINAL
from datasets.base import MedicalCase, ChannelData
from tools import AGENT_TOOLS, constrain_tools_for_step, ToolCall
from information_gain import (
BeliefState, BeliefTrajectory,
compute_entropy, compute_kl_divergence,
estimate_expected_information_gain,
should_commit, compute_value_of_information,
)
from prompts import format_available_channels, format_acquired_info
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
# ============================================================
# Backend Availability Detection
# ============================================================
def _detect_available_backends() -> list[str]:
    """Return the backend names whose API keys look configured.

    A key counts as configured when it is non-empty and not one of the
    placeholder values shipped in the config template.
    """
    backends = []
    openai_key = config.OPENAI_API_KEY
    if openai_key and openai_key != "sk-...":
        backends.append("openai")
    anthropic_key = config.ANTHROPIC_API_KEY
    if anthropic_key and anthropic_key != "sk-ant-...":
        backends.append("anthropic")
    if config.TOGETHER_API_KEY:
        backends.append("together")
    return backends
# Snapshot of configured backends, taken once at import time.
# NOTE(review): not referenced anywhere else in this file — the UI builds its
# own hardcoded backend list in create_app(); confirm before removing.
AVAILABLE_BACKENDS = _detect_available_backends()
# ============================================================
# Simulation Mode β€” works without API keys
# ============================================================
def _simulate_agent_on_case(case: MedicalCase) -> AgentResult:
    """
    Run a simulated agent that demonstrates the full pipeline
    with realistic-looking reasoning traces. No API keys needed.

    The simulation acquires every requestable channel in order, evolving a
    synthetic belief distribution that gradually concentrates on the
    ground-truth candidate, then appends a final commit step.

    Args:
        case: The medical case to "diagnose".

    Returns:
        A fully-populated AgentResult with steps, belief trajectory, and costs.
    """
    import random
    # Seed both RNGs so the demo trace is reproducible. np.random drives the
    # initial Dirichlet draw below; leaving it unseeded made runs vary even
    # though random.seed(42) signalled determinism was intended.
    random.seed(42)
    np.random.seed(42)
    result = AgentResult(
        case_id=case.case_id,
        dataset=case.dataset,
        prompt_variant="A",
        backend="simulated (no API key)",
        budget=len(case.requestable_channels),
    )
    trajectory = BeliefTrajectory(case_id=case.case_id)
    acquired = []
    n_candidates = len(case.candidates)
    # Generate initial uniform-ish distribution.
    probs = np.random.dirichlet(np.ones(n_candidates) * 2.0).tolist()
    probs.sort(reverse=True)
    # Index of the ground-truth candidate; its mass is boosted each step so it
    # ends up on top by the final commit.
    gt_idx = case.ground_truth_rank
    requestable_names = list(case.requestable_channels.keys())
    for step_idx, ch_name in enumerate(requestable_names):
        # Evolve the distribution — gradually concentrate on the correct answer.
        progress = (step_idx + 1) / len(requestable_names)
        new_probs = []
        for i in range(n_candidates):
            if i == gt_idx:
                new_probs.append(probs[i] + 0.15 * progress + random.uniform(0, 0.05))
            else:
                new_probs.append(max(0.01, probs[i] - 0.04 * progress + random.uniform(-0.02, 0.02)))
        total = sum(new_probs)
        probs = [p / total for p in new_probs]
        distribution = {case.candidates[i]: probs[i] for i in range(n_candidates)}
        sorted_dist = sorted(distribution.items(), key=lambda x: -x[1])
        # Entropy before this acquisition; pad the very first step so the
        # first information gain comes out positive.
        prev_entropy = trajectory.states[-1].entropy if trajectory.states else compute_entropy(distribution) + 0.3
        belief = BeliefState(
            step=step_idx,
            distribution=distribution,
            channel_acquired=ch_name,
        )
        trajectory.states.append(belief)
        ig = prev_entropy - belief.entropy
        kl = abs(ig) * 1.2 + random.uniform(0, 0.1)  # synthetic, correlated with IG
        top_two = sorted_dist[:2]
        reasoning_templates = [
            f"Need to distinguish between {top_two[0][0]} ({top_two[0][1]:.0%}) and {top_two[1][0]} ({top_two[1][1]:.0%}). "
            f"Requesting {ch_name} to resolve this uncertainty.",
            f"Current top diagnosis is {top_two[0][0]} at {top_two[0][1]:.0%} but {top_two[1][0]} cannot be ruled out. "
            f"The {ch_name} channel should provide discriminating evidence.",
            f"Diagnostic uncertainty remains high (H={belief.entropy:.2f} bits). "
            f"The {ch_name} data is expected to significantly narrow the differential.",
        ]
        step = AcquisitionStep(
            step=step_idx,
            tool_call=ToolCall(tool_name="request_information", arguments={
                "channel_name": ch_name,
                "reasoning": reasoning_templates[step_idx % len(reasoning_templates)],
            }),
            requested_channel=ch_name,
            reasoning=reasoning_templates[step_idx % len(reasoning_templates)],
            differential=[
                {"name": name, "confidence": prob, "rank": i + 1}
                for i, (name, prob) in enumerate(sorted_dist)
            ],
            committed=False,
            raw_response="(simulated)",
            latency_ms=random.uniform(800, 3000),
            entropy=belief.entropy,
            information_gain=ig,
            kl_divergence=kl,
            expected_impact={
                "if_positive": sorted_dist[0][0],
                "if_negative": sorted_dist[1][0],
            },
        )
        result.steps.append(step)
        acquired.append(ch_name)
    # Final commit step: concentrate most of the mass on the ground truth.
    final_probs = []
    for i in range(n_candidates):
        if i == gt_idx:
            final_probs.append(0.65 + random.uniform(0, 0.15))
        else:
            final_probs.append(random.uniform(0.02, 0.12))
    total = sum(final_probs)
    final_probs = [p / total for p in final_probs]
    final_dist = {case.candidates[i]: final_probs[i] for i in range(n_candidates)}
    sorted_final = sorted(final_dist.items(), key=lambda x: -x[1])
    final_belief = BeliefState(
        step=len(requestable_names),
        distribution=final_dist,
        channel_acquired=None,
    )
    trajectory.states.append(final_belief)
    final_ranking = [
        {
            "name": name,
            "confidence": prob,
            "rank": i + 1,
            "key_evidence": "Supported by evidence from acquired channels" if i == 0 else "Less consistent with findings",
        }
        for i, (name, prob) in enumerate(sorted_final)
    ]
    commit_step = AcquisitionStep(
        step=len(requestable_names),
        tool_call=ToolCall(tool_name="commit_diagnosis", arguments={}),
        requested_channel=None,
        reasoning=f"After acquiring all available channels, the evidence strongly supports {sorted_final[0][0]}. "
                  f"Entropy reduced to {final_belief.entropy:.2f} bits. Committing diagnosis.",
        differential=final_ranking,
        committed=True,
        raw_response="(simulated)",
        latency_ms=random.uniform(500, 2000),
        entropy=final_belief.entropy,
        information_gain=trajectory.states[-2].entropy - final_belief.entropy if len(trajectory.states) >= 2 else 0,
        kl_divergence=0.0,
    )
    result.steps.append(commit_step)
    result.committed_early = False
    result.final_ranking = final_ranking
    result.acquired_channels = acquired
    result.belief_trajectory = trajectory
    result.acquisition_cost = case.get_acquisition_cost(acquired)
    result.total_case_cost = case.get_total_cost(acquired)
    result.total_latency_ms = sum(s.latency_ms for s in result.steps)
    result.total_input_tokens = 0
    result.total_output_tokens = 0
    return result
# ============================================================
# Synthetic Demo Cases
# ============================================================
def _make_dummy_image(width=224, height=224, color=(180, 60, 60)) -> str:
    """Create a flat-color RGB image with pixel noise, base64-encoded.

    Used to stand in for real medical imagery in the demo cases.
    """
    base = np.array(Image.new("RGB", (width, height), color))
    jitter = np.random.randint(-20, 20, base.shape, dtype=np.int16)
    noisy = np.clip(base.astype(np.int16) + jitter, 0, 255).astype(np.uint8)
    return encode_pil_image_to_base64(Image.fromarray(noisy))
# Pre-built demonstration cases for Tab 1. Each entry maps a display name to:
#   "description": markdown shown when the case is selected
#   "case": a zero-arg factory (lambda) that builds a fresh MedicalCase on
#           demand, so the dummy images are regenerated each run
DEMO_CASES = {
    # Respiratory case: three free text channels, plus requestable exam/labs
    # and an expensive CT image channel.
    "NEJM: Pulmonary Fibrosis": {
        "description": (
            "A 58-year-old man with progressive dyspnea and dry cough over 3 months. "
            "30-pack-year smoking history, takes lisinopril for hypertension."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_nejm_ipf",
            dataset="nejm",
            initial_channels={
                "demographics": ChannelData(
                    name="demographics", channel_type="text",
                    description="Patient age, sex, and ethnicity",
                    value="A 58-year-old man", always_given=True, cost=0.0, tier="free",
                ),
                "chief_complaint": ChannelData(
                    name="chief_complaint", channel_type="text",
                    description="Presenting symptoms and duration",
                    value="Progressive dyspnea and dry cough over the past 3 months.",
                    always_given=True, cost=0.0, tier="free",
                ),
                "medical_history": ChannelData(
                    name="medical_history", channel_type="text",
                    description="Past medical conditions, medications, family and social history",
                    value="30-pack-year smoking history. No prior lung disease. Takes lisinopril for hypertension.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "exam_findings": ChannelData(
                    name="exam_findings", channel_type="text",
                    description="Physical examination results and observations",
                    value="Bibasilar crackles on auscultation. No clubbing. Oxygen saturation 92% on room air.",
                    cost=75.0, tier="cheap",
                ),
                "investigations": ChannelData(
                    name="investigations", channel_type="text",
                    description="Laboratory values, prior imaging results, and test outcomes",
                    value="PFTs show restrictive pattern with reduced DLCO. CT chest shows bilateral ground-glass opacities with honeycombing in the lower lobes.",
                    cost=250.0, tier="moderate",
                ),
                "image": ChannelData(
                    name="image", channel_type="image",
                    description="The primary diagnostic image (chest CT)",
                    value=_make_dummy_image(300, 300, (200, 200, 210)),
                    cost=800.0, tier="expensive",
                ),
            },
            candidates=[
                "A. Idiopathic pulmonary fibrosis",
                "B. Hypersensitivity pneumonitis",
                "C. Sarcoidosis",
                "D. Lung adenocarcinoma",
                "E. ACE-inhibitor induced cough with incidental CT findings",
            ],
            ground_truth="A. Idiopathic pulmonary fibrosis",
            ground_truth_rank=0,
        ),
    },
    # Dermatology case: image-first (the 30cm photo is always given), with
    # progressively more expensive close-up and dermoscopy channels.
    "Dermatology: Pigmented Lesion": {
        "description": (
            "A 62-year-old woman presents with a pigmented lesion on her left forearm. "
            "The lesion is 8mm x 6mm. Clinical photograph provided."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_midas_001",
            dataset="midas",
            initial_channels={
                "clinical_30cm": ChannelData(
                    name="clinical_30cm", channel_type="image",
                    description="Clinical photograph at 30cm distance",
                    value=_make_dummy_image(224, 224, (180, 120, 100)),
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "patient_demographics": ChannelData(
                    name="patient_demographics", channel_type="text",
                    description="Patient age, sex, and Fitzpatrick skin type",
                    value="Age: 62; Sex: Female; Fitzpatrick skin type: III",
                    cost=0.0, tier="free",
                ),
                "lesion_metadata": ChannelData(
                    name="lesion_metadata", channel_type="text",
                    description="Anatomic location, lesion length and width",
                    value="Anatomic location: Left forearm; Lesion length: 8mm; Lesion width: 6mm",
                    cost=25.0, tier="cheap",
                ),
                "clinical_15cm": ChannelData(
                    name="clinical_15cm", channel_type="image",
                    description="Clinical photograph at 15cm distance (closer view)",
                    value=_make_dummy_image(224, 224, (170, 110, 90)),
                    cost=50.0, tier="moderate",
                ),
                "dermoscopy": ChannelData(
                    name="dermoscopy", channel_type="image",
                    description="Dermoscopic image showing subsurface skin structures",
                    value=_make_dummy_image(224, 224, (100, 80, 60)),
                    cost=250.0, tier="expensive",
                ),
            },
            candidates=[
                "Melanoma in situ",
                "Dysplastic nevus",
                "Basal cell carcinoma",
                "Seborrheic keratosis",
                "Solar lentigo",
            ],
            ground_truth="Dysplastic nevus",
            ground_truth_rank=1,
        ),
    },
    # Ophthalmology case: candidates are biomarker sets rather than diagnoses.
    "Ophthalmology: Retinal Biomarkers (OLIVES)": {
        "description": (
            "A patient with diabetic macular edema (DME), 4 prior anti-VEGF injections, "
            "32 weeks in treatment. Fundus photograph provided."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_olives_P01",
            dataset="olives",
            initial_channels={
                "disease_context": ChannelData(
                    name="disease_context", channel_type="text",
                    description="Disease type and treatment context",
                    value="Disease: Diabetic Macular Edema (DME). Prior anti-VEGF injections: 4. Weeks in treatment: 32.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "clinical_measurements": ChannelData(
                    name="clinical_measurements", channel_type="text",
                    description="Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
                    value="BCVA: 20/60 (logMAR 0.48); CST: 385 um",
                    cost=20.0, tier="cheap",
                ),
                "biomarker_hints": ChannelData(
                    name="biomarker_hints", channel_type="text",
                    description="Expert-graded presence of fundus-visible retinal biomarkers",
                    value="Hard Exudates: Present; Hemorrhage: Present; Microaneurysms: Present; Cotton Wool Spots: Not detected",
                    cost=100.0, tier="moderate",
                ),
                "oct_scan": ChannelData(
                    name="oct_scan", channel_type="image",
                    description="OCT B-scan showing retinal cross-section",
                    value=_make_dummy_image(512, 128, (60, 60, 60)),
                    cost=300.0, tier="expensive",
                ),
                "additional_oct": ChannelData(
                    name="additional_oct", channel_type="image",
                    description="Additional OCT B-scans from different retinal locations",
                    value=_make_dummy_image(512, 128, (50, 50, 55)),
                    cost=150.0, tier="very_expensive",
                ),
            },
            candidates=[
                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Fluid Srf, Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Dril, Ez Disruption, Fluid Irf, Shrm",
                "No biomarkers detected",
            ],
            ground_truth="Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
            ground_truth_rank=0,
        ),
    },
    # Second NEJM case: classic pulmonary-embolism presentation.
    "NEJM: Cardiac Case": {
        "description": (
            "A 45-year-old woman presents with sudden onset chest pain and shortness "
            "of breath. She recently completed a long international flight."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_nejm_pe",
            dataset="nejm",
            initial_channels={
                "demographics": ChannelData(
                    name="demographics", channel_type="text",
                    description="Patient age, sex, and ethnicity",
                    value="A 45-year-old woman", always_given=True, cost=0.0, tier="free",
                ),
                "chief_complaint": ChannelData(
                    name="chief_complaint", channel_type="text",
                    description="Presenting symptoms and duration",
                    value="Sudden onset chest pain and shortness of breath, started 2 hours ago after returning from a 14-hour international flight.",
                    always_given=True, cost=0.0, tier="free",
                ),
                "medical_history": ChannelData(
                    name="medical_history", channel_type="text",
                    description="Past medical conditions, medications, family and social history",
                    value="On oral contraceptives for 5 years. BMI 32. No prior VTE. Mother had DVT at age 50.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "exam_findings": ChannelData(
                    name="exam_findings", channel_type="text",
                    description="Physical examination results and observations",
                    value="Tachycardic (HR 110), tachypneic (RR 24), SpO2 89% on room air. Right calf swollen and tender. JVP elevated. Loud P2 on cardiac auscultation.",
                    cost=75.0, tier="cheap",
                ),
                "investigations": ChannelData(
                    name="investigations", channel_type="text",
                    description="Laboratory values, imaging results, and test outcomes",
                    value="D-dimer: 4200 ng/mL (markedly elevated). Troponin I: 0.15 ng/mL (mildly elevated). ABG: pH 7.48, PaO2 62 mmHg, PaCO2 28 mmHg. ECG: S1Q3T3 pattern, right axis deviation. CT pulmonary angiography: bilateral pulmonary emboli with right heart strain.",
                    cost=250.0, tier="moderate",
                ),
                "image": ChannelData(
                    name="image", channel_type="image",
                    description="CT Pulmonary Angiography image",
                    value=_make_dummy_image(300, 300, (100, 100, 120)),
                    cost=800.0, tier="expensive",
                ),
            },
            candidates=[
                "A. Pulmonary embolism",
                "B. Acute myocardial infarction",
                "C. Tension pneumothorax",
                "D. Aortic dissection",
                "E. Acute pericarditis",
            ],
            ground_truth="A. Pulmonary embolism",
            ground_truth_rank=0,
        ),
    },
}
# ============================================================
# Custom Case Builder
# ============================================================
def build_custom_case(
    scenario_text: str,
    candidates_text: str,
    channel_1_name: str, channel_1_type: str, channel_1_value: str,
    channel_2_name: str, channel_2_type: str, channel_2_value: str,
    channel_3_name: str, channel_3_type: str, channel_3_value: str,
    uploaded_image=None,
) -> MedicalCase:
    """Assemble a MedicalCase from the custom-case form inputs.

    Blank candidate lists fall back to generic placeholders; channels with a
    blank name or blank content are skipped. Also registers a channel config
    under the "custom" dataset key so the agent can look channels up.
    """
    candidates = [line.strip() for line in candidates_text.strip().split("\n") if line.strip()]
    if not candidates:
        candidates = ["Diagnosis A", "Diagnosis B", "Diagnosis C"]
    # Always-given channels: the scenario text, plus the optional image upload.
    initial_channels = {
        "clinical_scenario": ChannelData(
            name="clinical_scenario", channel_type="text",
            description="The presenting clinical scenario",
            value=scenario_text,
            always_given=True, cost=0.0, tier="free",
        ),
    }
    if uploaded_image is not None:
        initial_channels["uploaded_image"] = ChannelData(
            name="uploaded_image", channel_type="image",
            description="Uploaded medical image",
            value=encode_pil_image_to_base64(Image.fromarray(uploaded_image)),
            always_given=True, cost=0.0, tier="free",
        )
    requestable = {}
    channel_specs = (
        (channel_1_name, channel_1_type, channel_1_value),
        (channel_2_name, channel_2_type, channel_2_value),
        (channel_3_name, channel_3_type, channel_3_value),
    )
    for raw_name, ctype, raw_value in channel_specs:
        label = raw_name.strip()
        content = raw_value.strip()
        if not (label and content):
            continue
        key = label.lower().replace(" ", "_")
        requestable[key] = ChannelData(
            name=key, channel_type=ctype.lower(),
            description=label,
            value=content,
            cost=100.0, tier="moderate",
        )
    # Register channel config so the agent can look it up.
    custom_config = {}
    for ch_key, ch in initial_channels.items():
        custom_config[ch_key] = {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": True,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": 0,
        }
    for position, (ch_key, ch) in enumerate(requestable.items(), start=1):
        custom_config[ch_key] = {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": False,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": position,
        }
    config.CHANNEL_CONFIGS["custom"] = custom_config
    return MedicalCase(
        case_id="custom_case",
        dataset="custom",
        initial_channels=initial_channels,
        requestable_channels=requestable,
        candidates=candidates,
        ground_truth=candidates[0] if candidates else "",
        ground_truth_rank=0,
    )
# ============================================================
# Formatting Helpers
# ============================================================
def format_step_markdown(step_idx: int, step: AcquisitionStep, cumulative_cost: float) -> str:
    """Render one acquisition (or commit) step as a markdown section."""
    md = []
    if step.committed:
        md.append(f"### Step {step_idx + 1}: COMMITTED TO DIAGNOSIS")
        md.append("")
        md.append(f"**Reasoning:** {step.reasoning}")
        md.append("")
        if step.differential:
            md.append("**Final Ranking:**")
            for entry in step.differential:
                conf = entry.get("confidence", 0)
                md.append(f"- **{entry['name']}** β€” {conf:.1%} {render_bar(conf)}")
                evidence = entry.get("key_evidence", "")
                if evidence:
                    md.append(f" - *Evidence:* {evidence}")
    else:
        md.append(f"### Step {step_idx + 1}: Requested `{step.requested_channel}`")
        md.append("")
        md.append(f"**Reasoning:** {step.reasoning}")
        md.append("")
        if step.differential:
            md.append("**Current Differential:**")
            for entry in step.differential:
                conf = entry.get("confidence", 0)
                md.append(f"- {entry['name']} β€” {conf:.1%} {render_bar(conf)}")
        if step.expected_impact:
            md.append("")
            md.append("**Expected Impact:**")
            md.append(f"- If positive/abnormal: *{step.expected_impact.get('if_positive', 'N/A')}*")
            md.append(f"- If negative/normal: *{step.expected_impact.get('if_negative', 'N/A')}*")
    # Metrics footer shared by request and commit steps. Zero-valued gain/KL
    # lines are deliberately suppressed (falsy check).
    md.append("")
    md.append("**Information Metrics:**")
    md.append(f"- Entropy: **{step.entropy:.3f}** bits")
    if step.information_gain:
        md.append(f"- Information Gain: **{step.information_gain:.3f}** bits")
    if step.kl_divergence:
        md.append(f"- KL Divergence: **{step.kl_divergence:.3f}** bits")
    md.append(f"- Latency: {step.latency_ms:.0f}ms")
    md.append(f"- Cumulative Cost: ${cumulative_cost:,.0f}")
    md.append("")
    md.append("---")
    return "\n".join(md)
def render_bar(value: float, width: int = 20) -> str:
    """Render *value* (expected in [0, 1]) as a fixed-width text bar in backticks."""
    n_filled = int(value * width)
    return "".join(("`", "\u2588" * n_filled, "\u2591" * (width - n_filled), "`"))
def format_entropy_table(trajectory: BeliefTrajectory) -> str:
    """Render the belief trajectory as a markdown table plus summary totals."""
    if not trajectory or not trajectory.states:
        return "*No belief trajectory recorded.*"
    rows = [
        "| Step | Channel | Entropy (bits) | Info Gain | Cumulative IG |",
        "|------|---------|---------------|-----------|---------------|",
    ]
    running_ig = 0.0
    previous = None
    for idx, state in enumerate(trajectory.states):
        label = state.channel_acquired or "initial/commit"
        # Gain is the entropy drop relative to the previous state (0 for the first row).
        gain = (previous.entropy - state.entropy) if previous is not None else 0.0
        running_ig += gain
        rows.append(
            f"| {idx} | {label} | {state.entropy:.3f} | "
            f"{gain:+.3f} | {running_ig:.3f} |"
        )
        previous = state
    rows.append("")
    rows.append(f"**Information Efficiency:** {trajectory.information_efficiency:.1%}")
    rows.append(f"**Total Information Gain:** {trajectory.total_information_gain:.3f} bits")
    return "\n".join(rows)
def format_summary(result: AgentResult, case: MedicalCase) -> str:
    """Render the end-of-run summary: correctness, acquisition order, costs, tokens."""
    out = ["## Summary", ""]
    if result.final_ranking:
        top = result.final_ranking[0]
        predicted = top["name"].strip().lower()
        truth = case.ground_truth.strip().lower()
        # Fuzzy match: handle "Pulmonary embolism" vs "A. Pulmonary embolism"
        is_correct = predicted == truth or predicted in truth or truth in predicted
        out.append(f"**Top Diagnosis:** {top['name']} ({top['confidence']:.1%})")
        out.append(f"**Ground Truth:** {case.ground_truth}")
        out.append(f"**Result:** {'correct' if is_correct else 'incorrect'}")
    else:
        out.append("*No diagnosis produced.*")
    out.append("")
    out.append(f"**Channels Acquired:** {len(result.acquired_channels)} / {len(case.requestable_channels)}")
    if result.acquired_channels:
        out.append(f"**Acquisition Order:** {' -> '.join(result.acquired_channels)}")
    out.append(f"**Committed Early:** {'Yes' if result.committed_early else 'No'}")
    out.append(f"**Total Acquisition Cost:** ${result.acquisition_cost:,.0f}")
    out.append(f"**Total Case Cost:** ${result.total_case_cost:,.0f}")
    out.append(f"**Total Latency:** {result.total_latency_ms:,.0f}ms")
    out.append(f"**Tokens Used:** {result.total_input_tokens:,} in / {result.total_output_tokens:,} out")
    return "\n".join(out)
# ============================================================
# Main Agent Runner (for Gradio)
# ============================================================
def run_agent_on_case(
    case: MedicalCase,
    backend: str,
    context_mode: str,
    user_api_key: str = "",
) -> tuple[str, str, str]:
    """
    Execute the agent (real or simulated) on *case* and render the results.

    Returns:
        A 3-tuple of markdown strings:
        (step-by-step reasoning trace, entropy trajectory table, summary).
        On client/agent failure, the first element carries the error message
        and the other two are empty strings.
    """
    if backend == "simulated (no API key)":
        result = _simulate_agent_on_case(case)
        model_name = "simulated"
    else:
        client_kwargs = {}
        key = user_api_key.strip() if user_api_key else ""
        if key:
            client_kwargs["api_key"] = key
        try:
            client = create_client(backend, **client_kwargs)
        except Exception as e:
            return (
                f"**Error creating {backend} client:** {e}\n\n"
                "Make sure you enter your API key above, or set it in environment variables. "
                "Or select **simulated (no API key)** to see a demo trace.",
                "", "",
            )
        agent = ActiveMedAgent(
            client,
            prompt_variant="A",
            budget=None,  # NO BUDGET CONSTRAINT
            context_mode=None if context_mode == "adaptive" else context_mode,
        )
        try:
            result = agent.diagnose(case)
        except Exception as e:
            return f"**Error running agent:** {e}", "", ""
        model_name = client.model
    # Assemble the step-by-step reasoning trace.
    trace = [
        "# Agent Reasoning Trace\n",
        f"**Case:** {case.case_id} | **Dataset:** {case.dataset} | **Backend:** {model_name}\n",
        f"**Candidates:** {', '.join(case.candidates)}\n",
    ]
    initial_info = format_acquired_info(case.get_text_context([]))
    trace.append(f"**Initial Information:**\n{initial_info}\n")
    trace.append("---\n")
    running_cost = case.get_initial_cost()
    for idx, step in enumerate(result.steps):
        if step.requested_channel:
            running_cost += case.get_channel_cost(step.requested_channel)
        trace.append(format_step_markdown(idx, step, running_cost))
    steps_md = "\n".join(trace)
    # Entropy trajectory (empty string when none was recorded).
    entropy_md = format_entropy_table(result.belief_trajectory) if result.belief_trajectory else ""
    summary_md = format_summary(result, case)
    return steps_md, entropy_md, summary_md
# ============================================================
# Gradio Event Handlers
# ============================================================
def on_demo_case_selected(case_name: str) -> tuple[str, str]:
    """Populate the description and candidate boxes for a chosen demo case."""
    if case_name not in DEMO_CASES:
        return "", ""
    entry = DEMO_CASES[case_name]
    case = entry["case"]()
    channel_lines = "\n".join(
        f"- **{name}** ({ch.tier}, ${ch.cost:,.0f}): {ch.description}"
        for name, ch in case.requestable_channels.items()
    )
    description = f"{entry['description']}\n\n**Available channels to acquire:**\n{channel_lines}"
    return description, "\n".join(case.candidates)
def run_demo_case(case_name: str, backend: str, context_mode: str, user_api_key: str = ""):
    """Run the agent on one of the pre-built demo cases."""
    if case_name not in DEMO_CASES:
        return "Please select a demo case.", "", ""
    fresh_case = DEMO_CASES[case_name]["case"]()
    return run_agent_on_case(fresh_case, backend, context_mode, user_api_key)
def run_custom_case(
    scenario: str, candidates: str,
    ch1_name: str, ch1_type: str, ch1_value: str,
    ch2_name: str, ch2_type: str, ch2_value: str,
    ch3_name: str, ch3_type: str, ch3_value: str,
    uploaded_image,
    backend: str, context_mode: str,
    user_api_key: str = "",
):
    """Build a case from the custom form and run the agent on it."""
    if not scenario.strip():
        return "Please enter a clinical scenario.", "", ""
    custom_case = build_custom_case(
        scenario, candidates,
        ch1_name, ch1_type, ch1_value,
        ch2_name, ch2_type, ch2_value,
        ch3_name, ch3_type, ch3_value,
        uploaded_image,
    )
    return run_agent_on_case(custom_case, backend, context_mode, user_api_key)
# ============================================================
# Gradio UI
# ============================================================
def create_app():
    """Build and return the Gradio Blocks UI (the caller launches it).

    Fix: `theme` and `css` are constructor options of gr.Blocks, not of
    Blocks.launch() β€” they were previously (incorrectly) passed to launch()
    in main(), so the `.header-text` / `.reasoning-box` CSS classes used by
    the Markdown components below never took effect.
    """
    with gr.Blocks(
        title="ActiveMedAgent Interactive Demo",
        theme=gr.themes.Soft(),
        css="""
        .reasoning-box { font-size: 14px; }
        .header-text { text-align: center; margin-bottom: 10px; }
        """,
    ) as app:
        gr.Markdown(
            """
            # ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis
            **Interactive Demo** β€” Watch the agent reason step-by-step, acquire information channels,
            and track entropy reduction. **No budget constraint** β€” the agent decides when to stop.
            """,
            elem_classes="header-text",
        )
        # Build backend choices: always show real backends (user can paste their own key)
        all_backends = ["openai", "anthropic", "together"]
        backend_choices = ["simulated (no API key)"] + all_backends
        default_backend = "openai"
        with gr.Row():
            backend = gr.Dropdown(
                choices=backend_choices,
                value=default_backend,
                label="VLM Backend",
                info="Select 'simulated' to see the demo without API keys",
                scale=1,
            )
            context_mode = gr.Dropdown(
                choices=["adaptive", "full", "condensed"],
                value="adaptive",
                label="Context Mode",
                info="How the agent manages conversation history",
                scale=1,
            )
            user_api_key = gr.Textbox(
                label="API Key (optional)",
                placeholder="sk-... (paste your OpenAI/Anthropic key)",
                type="password",
                info="Enter your own API key to use a real VLM backend",
                scale=1,
            )
        with gr.Tabs():
            # ---- Tab 1: Demo Cases ----
            with gr.TabItem("Demo Cases"):
                gr.Markdown("Select a pre-built clinical scenario and run the agent.")
                with gr.Row():
                    case_selector = gr.Dropdown(
                        choices=list(DEMO_CASES.keys()),
                        label="Select Case",
                        scale=2,
                    )
                    run_demo_btn = gr.Button("Run Agent", variant="primary", scale=1)
                case_description = gr.Markdown(label="Case Description")
                case_candidates = gr.Textbox(label="Candidate Diagnoses", lines=3, interactive=False)
                case_selector.change(
                    fn=on_demo_case_selected,
                    inputs=[case_selector],
                    outputs=[case_description, case_candidates],
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        demo_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        demo_summary = gr.Markdown(label="Summary")
                        demo_entropy = gr.Markdown(label="Entropy Trajectory")
                run_demo_btn.click(
                    fn=run_demo_case,
                    inputs=[case_selector, backend, context_mode, user_api_key],
                    outputs=[demo_steps, demo_entropy, demo_summary],
                )
            # ---- Tab 2: Custom Case ----
            with gr.TabItem("Custom Case"):
                gr.Markdown(
                    "Define your own clinical scenario, candidate diagnoses, "
                    "and information channels the agent can request."
                )
                with gr.Row():
                    with gr.Column():
                        custom_scenario = gr.Textbox(
                            label="Clinical Scenario",
                            placeholder="A 35-year-old woman presents with...",
                            lines=4,
                        )
                        custom_candidates = gr.Textbox(
                            label="Candidate Diagnoses (one per line)",
                            placeholder="A. Diagnosis one\nB. Diagnosis two\nC. Diagnosis three",
                            lines=5,
                        )
                        custom_image = gr.Image(
                            label="Upload Medical Image (optional)",
                            type="numpy",
                        )
                    with gr.Column():
                        gr.Markdown("### Requestable Information Channels")
                        gr.Markdown("Define up to 3 channels the agent can request.")
                        with gr.Group():
                            gr.Markdown("**Channel 1:**")
                            ch1_name = gr.Textbox(label="Name", value="Exam Findings", scale=1)
                            ch1_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch1_value = gr.Textbox(label="Content (what the agent receives)", lines=2,
                                                   placeholder="Physical exam: temperature 38.5C, ...")
                        with gr.Group():
                            gr.Markdown("**Channel 2:**")
                            ch2_name = gr.Textbox(label="Name", value="Lab Results", scale=1)
                            ch2_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch2_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="WBC 12,000, CRP elevated, ...")
                        with gr.Group():
                            gr.Markdown("**Channel 3:**")
                            ch3_name = gr.Textbox(label="Name", value="Imaging", scale=1)
                            ch3_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch3_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="CT scan shows...")
                run_custom_btn = gr.Button("Run Agent on Custom Case", variant="primary")
                with gr.Row():
                    with gr.Column(scale=2):
                        custom_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        custom_summary = gr.Markdown(label="Summary")
                        custom_entropy = gr.Markdown(label="Entropy Trajectory")
                run_custom_btn.click(
                    fn=run_custom_case,
                    inputs=[
                        custom_scenario, custom_candidates,
                        ch1_name, ch1_type, ch1_value,
                        ch2_name, ch2_type, ch2_value,
                        ch3_name, ch3_type, ch3_value,
                        custom_image,
                        backend, context_mode,
                        user_api_key,
                    ],
                    outputs=[custom_steps, custom_entropy, custom_summary],
                )
            # ---- Tab 3: How It Works ----
            with gr.TabItem("How It Works"):
                gr.Markdown("""
                ## ActiveMedAgent Architecture
                ### Tool-Use Acquisition Loop
                The agent uses native VLM function calling (not regex parsing) with two tools:
                1. **`request_information`** β€” Request one data channel, providing reasoning, current differential with calibrated probabilities, and expected impact
                2. **`commit_diagnosis`** β€” Submit final ranked diagnosis when confident
                ### No Budget Constraint
                The agent acquires as many channels as it needs (0 to all). It stops when:
                - It calls `commit_diagnosis` (self-determined confidence)
                - Information-theoretic stopping criteria trigger (convergence, confirmed dominance, or diminishing returns)
                - All channels are exhausted
                ### Information-Theoretic Metrics
                At each step, the system tracks:
                - **Shannon Entropy** H(p) β€” diagnostic uncertainty in bits
                - **Information Gain** β€” entropy reduction from each acquisition
                - **KL Divergence** β€” how much the belief distribution shifted
                - **Expected Information Gain (EIG)** β€” predicted value of the next channel
                - **Value of Information (VoI)** β€” whether continuing to acquire is worthwhile
                ### Context Management
                - **Full Mode**: Multi-turn conversation with complete history (for capable models)
                - **Condensed Mode**: Fresh single-turn call each step with compressed state log (for weaker models)
                - **Adaptive**: Auto-selects based on model capability
                ### Stopping Criteria
                1. **Convergence**: Last acquisition < 0.05 bits of IG
                2. **Confirmed Dominance**: Top diagnosis > 90% probability with > 40% gap (after 2+ acquisitions)
                3. **Diminishing Returns**: Last 2 acquisitions both < 0.1 bits IG
                """)
    return app
# ============================================================
# Entry Point
# ============================================================
def main():
    """Parse CLI arguments and launch the Gradio app.

    Fix: `theme=` and `css=` were being passed to Blocks.launch(), which does
    not accept them (they are gr.Blocks constructor parameters), raising a
    TypeError at startup. Styling now belongs in create_app()'s gr.Blocks(...).
    """
    parser = argparse.ArgumentParser(description="ActiveMedAgent Interactive Demo")
    parser.add_argument("--port", type=int, default=7860, help="Port to serve on")
    # NOTE(review): --backend is parsed but not currently used to pre-select
    # the UI's backend dropdown β€” confirm whether it should be wired through.
    parser.add_argument("--backend", default="openai", choices=["openai", "anthropic", "together"])
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    args = parser.parse_args()
    app = create_app()
    app.launch(
        server_port=args.port,
        share=args.share,
    )
if __name__ == "__main__":
    main()