Spaces:
Sleeping
Sleeping
| """ | |
| Interactive Demo for ActiveMedAgent. | |
| A Gradio-based UI that lets users: | |
| - Select from pre-built demo cases OR enter a custom clinical scenario | |
| - Upload medical images (optional) | |
| - Watch the agent's step-by-step reasoning, information acquisition, and | |
| entropy reduction in real time | |
| - No budget constraint — the agent acquires as many channels as it needs | |
| Usage: | |
| python app.py | |
| python app.py --backend openai | |
| python app.py --backend anthropic --port 7861 | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| import time | |
| import math | |
| from pathlib import Path | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| import config | |
| from api_client import create_client, encode_image_to_base64, encode_pil_image_to_base64 | |
| from agent import ActiveMedAgent, AgentResult, AcquisitionStep, SYSTEM_PROMPT_FULL, SYSTEM_PROMPT_CONDENSED, SYSTEM_PROMPT_FINAL | |
| from datasets.base import MedicalCase, ChannelData | |
| from tools import AGENT_TOOLS, constrain_tools_for_step, ToolCall | |
| from information_gain import ( | |
| BeliefState, BeliefTrajectory, | |
| compute_entropy, compute_kl_divergence, | |
| estimate_expected_information_gain, | |
| should_commit, compute_value_of_information, | |
| ) | |
| from prompts import format_available_channels, format_acquired_info | |
# Configure root logging once at import time; every module logger created via
# logging.getLogger(__name__) inherits this level and format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
| # ============================================================ | |
| # Backend Availability Detection | |
| # ============================================================ | |
def _detect_available_backends() -> list[str]:
    """Return the backend names whose API keys appear to be configured.

    A key equal to its placeholder value ("sk-..." / "sk-ant-...") is
    treated as not configured.
    """
    backends: list[str] = []
    openai_key = config.OPENAI_API_KEY
    if openai_key and openai_key != "sk-...":
        backends.append("openai")
    anthropic_key = config.ANTHROPIC_API_KEY
    if anthropic_key and anthropic_key != "sk-ant-...":
        backends.append("anthropic")
    if config.TOGETHER_API_KEY:
        backends.append("together")
    return backends
| AVAILABLE_BACKENDS = _detect_available_backends() | |
| # ============================================================ | |
| # Simulation Mode β works without API keys | |
| # ============================================================ | |
def _simulate_agent_on_case(case: MedicalCase) -> AgentResult:
    """
    Run a simulated agent that demonstrates the full pipeline
    with realistic-looking reasoning traces. No API keys needed.

    The belief distribution starts near-uniform and is nudged toward the
    ground-truth candidate at every acquisition step, producing a plausible
    entropy-reducing trace that ends with a commit step.
    """
    import random

    # Seeded *local* RNGs make the demo fully deterministic without mutating
    # global RNG state. (Previously only the stdlib `random` module was
    # seeded; the numpy Dirichlet draw below was left nondeterministic.)
    rng = random.Random(42)
    np_rng = np.random.default_rng(42)

    result = AgentResult(
        case_id=case.case_id,
        dataset=case.dataset,
        prompt_variant="A",
        backend="simulated (no API key)",
        budget=len(case.requestable_channels),
    )
    trajectory = BeliefTrajectory(case_id=case.case_id)
    acquired = []
    n_candidates = len(case.candidates)

    # Initial near-uniform distribution, sorted descending so the top
    # candidate is ambiguous at first.
    probs = np_rng.dirichlet(np.ones(n_candidates) * 2.0).tolist()
    probs.sort(reverse=True)
    # The simulation drives this candidate index to the top by the end.
    gt_idx = case.ground_truth_rank
    requestable_names = list(case.requestable_channels.keys())

    for step_idx, ch_name in enumerate(requestable_names):
        # Evolve the distribution: gradually concentrate mass on the
        # ground-truth answer, with a little noise on every candidate.
        progress = (step_idx + 1) / len(requestable_names)
        new_probs = []
        for i in range(n_candidates):
            if i == gt_idx:
                new_probs.append(probs[i] + 0.15 * progress + rng.uniform(0, 0.05))
            else:
                new_probs.append(max(0.01, probs[i] - 0.04 * progress + rng.uniform(-0.02, 0.02)))
        total = sum(new_probs)
        probs = [p / total for p in new_probs]
        distribution = {case.candidates[i]: probs[i] for i in range(n_candidates)}
        sorted_dist = sorted(distribution.items(), key=lambda x: -x[1])

        # For the first step there is no prior state; pretend the prior
        # entropy was slightly higher so the step shows a positive gain.
        prev_entropy = trajectory.states[-1].entropy if trajectory.states else compute_entropy(distribution) + 0.3
        belief = BeliefState(
            step=step_idx,
            distribution=distribution,
            channel_acquired=ch_name,
        )
        trajectory.states.append(belief)
        ig = prev_entropy - belief.entropy
        # Rough stand-in for a real KL divergence, correlated with |IG|.
        kl = abs(ig) * 1.2 + rng.uniform(0, 0.1)

        top_two = sorted_dist[:2]
        reasoning_templates = [
            f"Need to distinguish between {top_two[0][0]} ({top_two[0][1]:.0%}) and {top_two[1][0]} ({top_two[1][1]:.0%}). "
            f"Requesting {ch_name} to resolve this uncertainty.",
            f"Current top diagnosis is {top_two[0][0]} at {top_two[0][1]:.0%} but {top_two[1][0]} cannot be ruled out. "
            f"The {ch_name} channel should provide discriminating evidence.",
            f"Diagnostic uncertainty remains high (H={belief.entropy:.2f} bits). "
            f"The {ch_name} data is expected to significantly narrow the differential.",
        ]
        # Cycle through the templates; compute once so the tool call and the
        # step record always carry the same text.
        reasoning = reasoning_templates[step_idx % len(reasoning_templates)]
        step = AcquisitionStep(
            step=step_idx,
            tool_call=ToolCall(tool_name="request_information", arguments={
                "channel_name": ch_name,
                "reasoning": reasoning,
            }),
            requested_channel=ch_name,
            reasoning=reasoning,
            differential=[
                {"name": name, "confidence": prob, "rank": i + 1}
                for i, (name, prob) in enumerate(sorted_dist)
            ],
            committed=False,
            raw_response="(simulated)",
            latency_ms=rng.uniform(800, 3000),
            entropy=belief.entropy,
            information_gain=ig,
            kl_divergence=kl,
            expected_impact={
                "if_positive": sorted_dist[0][0],
                "if_negative": sorted_dist[1][0],
            },
        )
        result.steps.append(step)
        acquired.append(ch_name)

    # Final commit step: concentrate most of the mass on the ground truth.
    final_probs = []
    for i in range(n_candidates):
        if i == gt_idx:
            final_probs.append(0.65 + rng.uniform(0, 0.15))
        else:
            final_probs.append(rng.uniform(0.02, 0.12))
    total = sum(final_probs)
    final_probs = [p / total for p in final_probs]
    final_dist = {case.candidates[i]: final_probs[i] for i in range(n_candidates)}
    sorted_final = sorted(final_dist.items(), key=lambda x: -x[1])
    final_belief = BeliefState(
        step=len(requestable_names),
        distribution=final_dist,
        channel_acquired=None,
    )
    trajectory.states.append(final_belief)
    final_ranking = [
        {
            "name": name,
            "confidence": prob,
            "rank": i + 1,
            "key_evidence": "Supported by evidence from acquired channels" if i == 0 else "Less consistent with findings",
        }
        for i, (name, prob) in enumerate(sorted_final)
    ]
    commit_step = AcquisitionStep(
        step=len(requestable_names),
        tool_call=ToolCall(tool_name="commit_diagnosis", arguments={}),
        requested_channel=None,
        reasoning=f"After acquiring all available channels, the evidence strongly supports {sorted_final[0][0]}. "
                  f"Entropy reduced to {final_belief.entropy:.2f} bits. Committing diagnosis.",
        differential=final_ranking,
        committed=True,
        raw_response="(simulated)",
        latency_ms=rng.uniform(500, 2000),
        entropy=final_belief.entropy,
        information_gain=trajectory.states[-2].entropy - final_belief.entropy if len(trajectory.states) >= 2 else 0,
        kl_divergence=0.0,
    )
    result.steps.append(commit_step)

    # Aggregate bookkeeping (token counts are zero: no model was called).
    result.committed_early = False
    result.final_ranking = final_ranking
    result.acquired_channels = acquired
    result.belief_trajectory = trajectory
    result.acquisition_cost = case.get_acquisition_cost(acquired)
    result.total_case_cost = case.get_total_cost(acquired)
    result.total_latency_ms = sum(s.latency_ms for s in result.steps)
    result.total_input_tokens = 0
    result.total_output_tokens = 0
    return result
| # ============================================================ | |
| # Synthetic Demo Cases | |
| # ============================================================ | |
def _make_dummy_image(width=224, height=224, color=(180, 60, 60)) -> str:
    """Return a base64-encoded flat-color image with speckle noise.

    Serves as a stand-in for real medical imagery in the demo cases.
    """
    base = np.array(Image.new("RGB", (width, height), color)).astype(np.int16)
    speckle = np.random.randint(-20, 20, base.shape, dtype=np.int16)
    noisy = np.clip(base + speckle, 0, 255).astype(np.uint8)
    return encode_pil_image_to_base64(Image.fromarray(noisy))
# Pre-built demo scenarios shown in the "Demo Cases" tab.
# Each entry maps a display name to:
#   - "description": markdown blurb shown when the case is selected
#   - "case": a zero-argument factory (lambda), so every run constructs a
#     fresh MedicalCase (and freshly generated dummy images)
DEMO_CASES = {
    "NEJM: Pulmonary Fibrosis": {
        "description": (
            "A 58-year-old man with progressive dyspnea and dry cough over 3 months. "
            "30-pack-year smoking history, takes lisinopril for hypertension."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_nejm_ipf",
            dataset="nejm",
            # Free channels the agent sees before any acquisition.
            initial_channels={
                "demographics": ChannelData(
                    name="demographics", channel_type="text",
                    description="Patient age, sex, and ethnicity",
                    value="A 58-year-old man", always_given=True, cost=0.0, tier="free",
                ),
                "chief_complaint": ChannelData(
                    name="chief_complaint", channel_type="text",
                    description="Presenting symptoms and duration",
                    value="Progressive dyspnea and dry cough over the past 3 months.",
                    always_given=True, cost=0.0, tier="free",
                ),
                "medical_history": ChannelData(
                    name="medical_history", channel_type="text",
                    description="Past medical conditions, medications, family and social history",
                    value="30-pack-year smoking history. No prior lung disease. Takes lisinopril for hypertension.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            # Channels the agent may request, priced by cost tier.
            requestable_channels={
                "exam_findings": ChannelData(
                    name="exam_findings", channel_type="text",
                    description="Physical examination results and observations",
                    value="Bibasilar crackles on auscultation. No clubbing. Oxygen saturation 92% on room air.",
                    cost=75.0, tier="cheap",
                ),
                "investigations": ChannelData(
                    name="investigations", channel_type="text",
                    description="Laboratory values, prior imaging results, and test outcomes",
                    value="PFTs show restrictive pattern with reduced DLCO. CT chest shows bilateral ground-glass opacities with honeycombing in the lower lobes.",
                    cost=250.0, tier="moderate",
                ),
                "image": ChannelData(
                    name="image", channel_type="image",
                    description="The primary diagnostic image (chest CT)",
                    value=_make_dummy_image(300, 300, (200, 200, 210)),
                    cost=800.0, tier="expensive",
                ),
            },
            candidates=[
                "A. Idiopathic pulmonary fibrosis",
                "B. Hypersensitivity pneumonitis",
                "C. Sarcoidosis",
                "D. Lung adenocarcinoma",
                "E. ACE-inhibitor induced cough with incidental CT findings",
            ],
            ground_truth="A. Idiopathic pulmonary fibrosis",
            # Index of ground_truth within candidates.
            ground_truth_rank=0,
        ),
    },
    "Dermatology: Pigmented Lesion": {
        "description": (
            "A 62-year-old woman presents with a pigmented lesion on her left forearm. "
            "The lesion is 8mm x 6mm. Clinical photograph provided."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_midas_001",
            dataset="midas",
            initial_channels={
                "clinical_30cm": ChannelData(
                    name="clinical_30cm", channel_type="image",
                    description="Clinical photograph at 30cm distance",
                    value=_make_dummy_image(224, 224, (180, 120, 100)),
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "patient_demographics": ChannelData(
                    name="patient_demographics", channel_type="text",
                    description="Patient age, sex, and Fitzpatrick skin type",
                    value="Age: 62; Sex: Female; Fitzpatrick skin type: III",
                    cost=0.0, tier="free",
                ),
                "lesion_metadata": ChannelData(
                    name="lesion_metadata", channel_type="text",
                    description="Anatomic location, lesion length and width",
                    value="Anatomic location: Left forearm; Lesion length: 8mm; Lesion width: 6mm",
                    cost=25.0, tier="cheap",
                ),
                "clinical_15cm": ChannelData(
                    name="clinical_15cm", channel_type="image",
                    description="Clinical photograph at 15cm distance (closer view)",
                    value=_make_dummy_image(224, 224, (170, 110, 90)),
                    cost=50.0, tier="moderate",
                ),
                "dermoscopy": ChannelData(
                    name="dermoscopy", channel_type="image",
                    description="Dermoscopic image showing subsurface skin structures",
                    value=_make_dummy_image(224, 224, (100, 80, 60)),
                    cost=250.0, tier="expensive",
                ),
            },
            candidates=[
                "Melanoma in situ",
                "Dysplastic nevus",
                "Basal cell carcinoma",
                "Seborrheic keratosis",
                "Solar lentigo",
            ],
            ground_truth="Dysplastic nevus",
            ground_truth_rank=1,
        ),
    },
    "Ophthalmology: Retinal Biomarkers (OLIVES)": {
        "description": (
            "A patient with diabetic macular edema (DME), 4 prior anti-VEGF injections, "
            "32 weeks in treatment. Fundus photograph provided."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_olives_P01",
            dataset="olives",
            initial_channels={
                "disease_context": ChannelData(
                    name="disease_context", channel_type="text",
                    description="Disease type and treatment context",
                    value="Disease: Diabetic Macular Edema (DME). Prior anti-VEGF injections: 4. Weeks in treatment: 32.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "clinical_measurements": ChannelData(
                    name="clinical_measurements", channel_type="text",
                    description="Best Corrected Visual Acuity (BCVA) and Central Subfield Thickness (CST)",
                    value="BCVA: 20/60 (logMAR 0.48); CST: 385 um",
                    cost=20.0, tier="cheap",
                ),
                "biomarker_hints": ChannelData(
                    name="biomarker_hints", channel_type="text",
                    description="Expert-graded presence of fundus-visible retinal biomarkers",
                    value="Hard Exudates: Present; Hemorrhage: Present; Microaneurysms: Present; Cotton Wool Spots: Not detected",
                    cost=100.0, tier="moderate",
                ),
                "oct_scan": ChannelData(
                    name="oct_scan", channel_type="image",
                    description="OCT B-scan showing retinal cross-section",
                    value=_make_dummy_image(512, 128, (60, 60, 60)),
                    cost=300.0, tier="expensive",
                ),
                # NOTE(review): tier is "very_expensive" yet cost (150) is
                # below oct_scan's 300 ("expensive") — confirm intended.
                "additional_oct": ChannelData(
                    name="additional_oct", channel_type="image",
                    description="Additional OCT B-scans from different retinal locations",
                    value=_make_dummy_image(512, 128, (50, 50, 55)),
                    cost=150.0, tier="very_expensive",
                ),
            },
            candidates=[
                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Fluid Srf, Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Hard Exudates, Hemorrhage, Microaneurysms",
                "Present biomarkers: Dril, Ez Disruption, Fluid Irf, Shrm",
                "No biomarkers detected",
            ],
            ground_truth="Present biomarkers: Dril, Drt Me, Ez Disruption, Fluid Irf, Hard Exudates, Hemorrhage, Microaneurysms",
            ground_truth_rank=0,
        ),
    },
    "NEJM: Cardiac Case": {
        "description": (
            "A 45-year-old woman presents with sudden onset chest pain and shortness "
            "of breath. She recently completed a long international flight."
        ),
        "case": lambda: MedicalCase(
            case_id="demo_nejm_pe",
            dataset="nejm",
            initial_channels={
                "demographics": ChannelData(
                    name="demographics", channel_type="text",
                    description="Patient age, sex, and ethnicity",
                    value="A 45-year-old woman", always_given=True, cost=0.0, tier="free",
                ),
                "chief_complaint": ChannelData(
                    name="chief_complaint", channel_type="text",
                    description="Presenting symptoms and duration",
                    value="Sudden onset chest pain and shortness of breath, started 2 hours ago after returning from a 14-hour international flight.",
                    always_given=True, cost=0.0, tier="free",
                ),
                "medical_history": ChannelData(
                    name="medical_history", channel_type="text",
                    description="Past medical conditions, medications, family and social history",
                    value="On oral contraceptives for 5 years. BMI 32. No prior VTE. Mother had DVT at age 50.",
                    always_given=True, cost=0.0, tier="free",
                ),
            },
            requestable_channels={
                "exam_findings": ChannelData(
                    name="exam_findings", channel_type="text",
                    description="Physical examination results and observations",
                    value="Tachycardic (HR 110), tachypneic (RR 24), SpO2 89% on room air. Right calf swollen and tender. JVP elevated. Loud P2 on cardiac auscultation.",
                    cost=75.0, tier="cheap",
                ),
                "investigations": ChannelData(
                    name="investigations", channel_type="text",
                    description="Laboratory values, imaging results, and test outcomes",
                    value="D-dimer: 4200 ng/mL (markedly elevated). Troponin I: 0.15 ng/mL (mildly elevated). ABG: pH 7.48, PaO2 62 mmHg, PaCO2 28 mmHg. ECG: S1Q3T3 pattern, right axis deviation. CT pulmonary angiography: bilateral pulmonary emboli with right heart strain.",
                    cost=250.0, tier="moderate",
                ),
                "image": ChannelData(
                    name="image", channel_type="image",
                    description="CT Pulmonary Angiography image",
                    value=_make_dummy_image(300, 300, (100, 100, 120)),
                    cost=800.0, tier="expensive",
                ),
            },
            candidates=[
                "A. Pulmonary embolism",
                "B. Acute myocardial infarction",
                "C. Tension pneumothorax",
                "D. Aortic dissection",
                "E. Acute pericarditis",
            ],
            ground_truth="A. Pulmonary embolism",
            ground_truth_rank=0,
        ),
    },
}
| # ============================================================ | |
| # Custom Case Builder | |
| # ============================================================ | |
def build_custom_case(
    scenario_text: str,
    candidates_text: str,
    channel_1_name: str, channel_1_type: str, channel_1_value: str,
    channel_2_name: str, channel_2_type: str, channel_2_value: str,
    channel_3_name: str, channel_3_type: str, channel_3_value: str,
    uploaded_image=None,
) -> MedicalCase:
    """Assemble a MedicalCase from the custom-case form inputs.

    Also registers channel metadata under the "custom" dataset key in
    config.CHANNEL_CONFIGS so the agent can look the channels up.
    """
    candidates = [line.strip() for line in candidates_text.strip().split("\n") if line.strip()]
    if not candidates:
        candidates = ["Diagnosis A", "Diagnosis B", "Diagnosis C"]

    initial_channels = {
        "clinical_scenario": ChannelData(
            name="clinical_scenario", channel_type="text",
            description="The presenting clinical scenario",
            value=scenario_text,
            always_given=True, cost=0.0, tier="free",
        ),
    }
    if uploaded_image is not None:
        # Gradio hands us a numpy array; embed it as base64.
        encoded = encode_pil_image_to_base64(Image.fromarray(uploaded_image))
        initial_channels["uploaded_image"] = ChannelData(
            name="uploaded_image", channel_type="image",
            description="Uploaded medical image",
            value=encoded, always_given=True, cost=0.0, tier="free",
        )

    raw_channels = [
        (channel_1_name, channel_1_type, channel_1_value),
        (channel_2_name, channel_2_type, channel_2_value),
        (channel_3_name, channel_3_type, channel_3_value),
    ]
    requestable = {}
    for raw_name, raw_type, raw_value in raw_channels:
        raw_name = raw_name.strip()
        raw_value = raw_value.strip()
        if not (raw_name and raw_value):
            continue  # skip half-filled channel rows
        key = raw_name.lower().replace(" ", "_")
        requestable[key] = ChannelData(
            name=key, channel_type=raw_type.lower(),
            description=raw_name,
            value=raw_value,
            cost=100.0, tier="moderate",
        )

    # Register channel config so the agent can look it up.
    custom_config = {
        ch_name: {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": True,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": 0,
        }
        for ch_name, ch in initial_channels.items()
    }
    for i, (ch_name, ch) in enumerate(requestable.items()):
        custom_config[ch_name] = {
            "description": ch.description,
            "type": ch.channel_type,
            "always_given": False,
            "tier": ch.tier,
            "cost": ch.cost,
            "order": i + 1,
        }
    config.CHANNEL_CONFIGS["custom"] = custom_config

    return MedicalCase(
        case_id="custom_case",
        dataset="custom",
        initial_channels=initial_channels,
        requestable_channels=requestable,
        candidates=candidates,
        ground_truth=candidates[0] if candidates else "",
        ground_truth_rank=0,
    )
| # ============================================================ | |
| # Formatting Helpers | |
| # ============================================================ | |
def format_step_markdown(step_idx: int, step: AcquisitionStep, cumulative_cost: float) -> str:
    """Render one acquisition (or commit) step as a markdown section."""
    md: list[str] = []
    emit = md.append
    if step.committed:
        emit(f"### Step {step_idx + 1}: COMMITTED TO DIAGNOSIS")
        emit("")
        emit(f"**Reasoning:** {step.reasoning}")
        emit("")
        if step.differential:
            emit("**Final Ranking:**")
            for entry in step.differential:
                conf = entry.get("confidence", 0)
                emit(f"- **{entry['name']}** β {conf:.1%} {render_bar(conf)}")
                evidence = entry.get("key_evidence", "")
                if evidence:
                    emit(f" - *Evidence:* {evidence}")
    else:
        emit(f"### Step {step_idx + 1}: Requested `{step.requested_channel}`")
        emit("")
        emit(f"**Reasoning:** {step.reasoning}")
        emit("")
        if step.differential:
            emit("**Current Differential:**")
            for entry in step.differential:
                conf = entry.get("confidence", 0)
                emit(f"- {entry['name']} β {conf:.1%} {render_bar(conf)}")
        if step.expected_impact:
            emit("")
            emit("**Expected Impact:**")
            pos = step.expected_impact.get("if_positive", "N/A")
            neg = step.expected_impact.get("if_negative", "N/A")
            emit(f"- If positive/abnormal: *{pos}*")
            emit(f"- If negative/normal: *{neg}*")
    emit("")
    emit("**Information Metrics:**")
    emit(f"- Entropy: **{step.entropy:.3f}** bits")
    # NOTE: truthiness suppresses a gain/KL of exactly 0.0 (and None);
    # this appears deliberate to keep the panel compact.
    if step.information_gain:
        emit(f"- Information Gain: **{step.information_gain:.3f}** bits")
    if step.kl_divergence:
        emit(f"- KL Divergence: **{step.kl_divergence:.3f}** bits")
    emit(f"- Latency: {step.latency_ms:.0f}ms")
    emit(f"- Cumulative Cost: ${cumulative_cost:,.0f}")
    emit("")
    emit("---")
    return "\n".join(md)
def render_bar(value: float, width: int = 20) -> str:
    """Return a fixed-width text progress bar wrapped in backticks.

    *value* is a fraction in [0, 1]; it is truncated (not rounded) to a
    whole number of filled cells.
    """
    n_filled = int(value * width)
    cells = "\u2588" * n_filled + "\u2591" * (width - n_filled)
    return f"`{cells}`"
def format_entropy_table(trajectory: BeliefTrajectory) -> str:
    """Render the belief trajectory as a markdown entropy table."""
    if not trajectory or not trajectory.states:
        return "*No belief trajectory recorded.*"
    rows = [
        "| Step | Channel | Entropy (bits) | Info Gain | Cumulative IG |",
        "|------|---------|---------------|-----------|---------------|",
    ]
    cumulative_ig = 0.0
    prev_entropy = None
    for i, state in enumerate(trajectory.states):
        label = state.channel_acquired or "initial/commit"
        # Gain is the entropy drop relative to the previous state (0 for the first row).
        ig = (prev_entropy - state.entropy) if prev_entropy is not None else 0.0
        cumulative_ig += ig
        rows.append(
            f"| {i} | {label} | {state.entropy:.3f} | "
            f"{ig:+.3f} | {cumulative_ig:.3f} |"
        )
        prev_entropy = state.entropy
    rows.append("")
    rows.append(f"**Information Efficiency:** {trajectory.information_efficiency:.1%}")
    rows.append(f"**Total Information Gain:** {trajectory.total_information_gain:.3f} bits")
    return "\n".join(rows)
def format_summary(result: AgentResult, case: MedicalCase) -> str:
    """Build the markdown summary panel for a finished run."""
    parts = ["## Summary", ""]
    ranking = result.final_ranking
    if ranking:
        top = ranking[0]
        predicted = top["name"].strip().lower()
        truth = case.ground_truth.strip().lower()
        # Fuzzy match: handle "Pulmonary embolism" vs "A. Pulmonary embolism"
        is_match = predicted == truth or predicted in truth or truth in predicted
        parts.append(f"**Top Diagnosis:** {top['name']} ({top['confidence']:.1%})")
        parts.append(f"**Ground Truth:** {case.ground_truth}")
        parts.append(f"**Result:** {'correct' if is_match else 'incorrect'}")
    else:
        parts.append("*No diagnosis produced.*")
    parts.append("")
    parts.append(f"**Channels Acquired:** {len(result.acquired_channels)} / {len(case.requestable_channels)}")
    if result.acquired_channels:
        parts.append(f"**Acquisition Order:** {' -> '.join(result.acquired_channels)}")
    parts.append(f"**Committed Early:** {'Yes' if result.committed_early else 'No'}")
    parts.append(f"**Total Acquisition Cost:** ${result.acquisition_cost:,.0f}")
    parts.append(f"**Total Case Cost:** ${result.total_case_cost:,.0f}")
    parts.append(f"**Total Latency:** {result.total_latency_ms:,.0f}ms")
    parts.append(f"**Tokens Used:** {result.total_input_tokens:,} in / {result.total_output_tokens:,} out")
    return "\n".join(parts)
| # ============================================================ | |
| # Main Agent Runner (for Gradio) | |
| # ============================================================ | |
def run_agent_on_case(
    case: MedicalCase,
    backend: str,
    context_mode: str,
    user_api_key: str = "",
) -> tuple[str, str, str]:
    """
    Run the agent on a case and return formatted markdown outputs.
    Returns: (steps_markdown, entropy_table, summary_markdown)
    """
    if backend == "simulated (no API key)":
        model_name = "simulated"
        result = _simulate_agent_on_case(case)
    else:
        # A key typed into the UI overrides any environment configuration.
        key = user_api_key.strip() if user_api_key else ""
        try:
            client = create_client(backend, **({"api_key": key} if key else {}))
        except Exception as e:
            return (
                f"**Error creating {backend} client:** {e}\n\n"
                "Make sure you enter your API key above, or set it in environment variables. "
                "Or select **simulated (no API key)** to see a demo trace.",
                "", "",
            )
        agent = ActiveMedAgent(
            client,
            prompt_variant="A",
            budget=None,  # NO BUDGET CONSTRAINT
            context_mode=context_mode if context_mode != "adaptive" else None,
        )
        try:
            result = agent.diagnose(case)
        except Exception as e:
            return f"**Error running agent:** {e}", "", ""
        model_name = client.model

    # Header of the step-by-step reasoning trace.
    trace = [
        "# Agent Reasoning Trace\n",
        f"**Case:** {case.case_id} | **Dataset:** {case.dataset} | **Backend:** {model_name}\n",
        f"**Candidates:** {', '.join(case.candidates)}\n",
        f"**Initial Information:**\n{format_acquired_info(case.get_text_context([]))}\n",
        "---\n",
    ]
    # One markdown section per step, with a running cost total.
    running_cost = case.get_initial_cost()
    for i, step in enumerate(result.steps):
        if step.requested_channel:
            running_cost += case.get_channel_cost(step.requested_channel)
        trace.append(format_step_markdown(i, step, running_cost))
    steps_md = "\n".join(trace)

    entropy_md = format_entropy_table(result.belief_trajectory) if result.belief_trajectory else ""
    summary_md = format_summary(result, case)
    return steps_md, entropy_md, summary_md
| # ============================================================ | |
| # Gradio Event Handlers | |
| # ============================================================ | |
def on_demo_case_selected(case_name: str) -> tuple[str, str]:
    """When a demo case is selected, show its description and candidates."""
    if case_name not in DEMO_CASES:
        return "", ""
    entry = DEMO_CASES[case_name]
    case = entry["case"]()
    channel_lines = "\n".join(
        f"- **{name}** ({ch.tier}, ${ch.cost:,.0f}): {ch.description}"
        for name, ch in case.requestable_channels.items()
    )
    description = f"{entry['description']}\n\n**Available channels to acquire:**\n{channel_lines}"
    return description, "\n".join(case.candidates)
def run_demo_case(case_name: str, backend: str, context_mode: str, user_api_key: str = ""):
    """Run agent on a selected demo case."""
    entry = DEMO_CASES.get(case_name)
    if entry is None:
        return "Please select a demo case.", "", ""
    return run_agent_on_case(entry["case"](), backend, context_mode, user_api_key)
def run_custom_case(
    scenario: str, candidates: str,
    ch1_name: str, ch1_type: str, ch1_value: str,
    ch2_name: str, ch2_type: str, ch2_value: str,
    ch3_name: str, ch3_type: str, ch3_value: str,
    uploaded_image,
    backend: str, context_mode: str,
    user_api_key: str = "",
):
    """Run agent on a custom user-defined case."""
    if not scenario.strip():
        return "Please enter a clinical scenario.", "", ""
    custom_case = build_custom_case(
        scenario, candidates,
        ch1_name, ch1_type, ch1_value,
        ch2_name, ch2_type, ch2_value,
        ch3_name, ch3_type, ch3_value,
        uploaded_image,
    )
    return run_agent_on_case(custom_case, backend, context_mode, user_api_key)
| # ============================================================ | |
| # Gradio UI | |
| # ============================================================ | |
def create_app():
    """Build the Gradio Blocks UI for the ActiveMedAgent interactive demo.

    Layout:
      * Global controls: VLM backend selector, context-mode selector, and an
        optional user-supplied API key.
      * Tab 1 "Demo Cases": run the agent on a pre-built clinical case.
      * Tab 2 "Custom Case": user-defined scenario, candidate diagnoses, up
        to three information channels, and an optional uploaded image.
      * Tab 3 "How It Works": static architecture explainer.

    Returns:
        gr.Blocks: the assembled (not yet launched) application; call
        ``.launch()`` on the result to serve it.
    """
    with gr.Blocks(
        title="ActiveMedAgent Interactive Demo",
        # `theme` and `css` are gr.Blocks() constructor arguments, NOT
        # Blocks.launch() arguments. The css rules here back the
        # `elem_classes` used on components below.
        theme=gr.themes.Soft(),
        css="""
        .reasoning-box { font-size: 14px; }
        .header-text { text-align: center; margin-bottom: 10px; }
        """,
    ) as app:
        gr.Markdown(
            """
            # ActiveMedAgent: Learned Information Acquisition for Medical Diagnosis
            **Interactive Demo** β Watch the agent reason step-by-step, acquire information channels,
            and track entropy reduction. **No budget constraint** β the agent decides when to stop.
            """,
            elem_classes="header-text",
        )
        # Build backend choices: always show real backends (user can paste their own key)
        all_backends = ["openai", "anthropic", "together"]
        backend_choices = ["simulated (no API key)"] + all_backends
        default_backend = "openai"
        with gr.Row():
            backend = gr.Dropdown(
                choices=backend_choices,
                value=default_backend,
                label="VLM Backend",
                info="Select 'simulated' to see the demo without API keys",
                scale=1,
            )
            context_mode = gr.Dropdown(
                choices=["adaptive", "full", "condensed"],
                value="adaptive",
                label="Context Mode",
                info="How the agent manages conversation history",
                scale=1,
            )
            user_api_key = gr.Textbox(
                label="API Key (optional)",
                placeholder="sk-... (paste your OpenAI/Anthropic key)",
                type="password",
                info="Enter your own API key to use a real VLM backend",
                scale=1,
            )
        with gr.Tabs():
            # ---- Tab 1: Demo Cases ----
            with gr.TabItem("Demo Cases"):
                gr.Markdown("Select a pre-built clinical scenario and run the agent.")
                with gr.Row():
                    case_selector = gr.Dropdown(
                        choices=list(DEMO_CASES.keys()),
                        label="Select Case",
                        scale=2,
                    )
                    run_demo_btn = gr.Button("Run Agent", variant="primary", scale=1)
                case_description = gr.Markdown(label="Case Description")
                case_candidates = gr.Textbox(label="Candidate Diagnoses", lines=3, interactive=False)
                # Preview the description and candidates as soon as a case is
                # picked, before the user hits "Run Agent".
                case_selector.change(
                    fn=on_demo_case_selected,
                    inputs=[case_selector],
                    outputs=[case_description, case_candidates],
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        demo_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        demo_summary = gr.Markdown(label="Summary")
                        demo_entropy = gr.Markdown(label="Entropy Trajectory")
                run_demo_btn.click(
                    fn=run_demo_case,
                    inputs=[case_selector, backend, context_mode, user_api_key],
                    outputs=[demo_steps, demo_entropy, demo_summary],
                )
            # ---- Tab 2: Custom Case ----
            with gr.TabItem("Custom Case"):
                gr.Markdown(
                    "Define your own clinical scenario, candidate diagnoses, "
                    "and information channels the agent can request."
                )
                with gr.Row():
                    with gr.Column():
                        custom_scenario = gr.Textbox(
                            label="Clinical Scenario",
                            placeholder="A 35-year-old woman presents with...",
                            lines=4,
                        )
                        custom_candidates = gr.Textbox(
                            label="Candidate Diagnoses (one per line)",
                            placeholder="A. Diagnosis one\nB. Diagnosis two\nC. Diagnosis three",
                            lines=5,
                        )
                        custom_image = gr.Image(
                            label="Upload Medical Image (optional)",
                            type="numpy",
                        )
                    with gr.Column():
                        gr.Markdown("### Requestable Information Channels")
                        gr.Markdown("Define up to 3 channels the agent can request.")
                        with gr.Group():
                            gr.Markdown("**Channel 1:**")
                            ch1_name = gr.Textbox(label="Name", value="Exam Findings", scale=1)
                            ch1_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch1_value = gr.Textbox(label="Content (what the agent receives)", lines=2,
                                                   placeholder="Physical exam: temperature 38.5C, ...")
                        with gr.Group():
                            gr.Markdown("**Channel 2:**")
                            ch2_name = gr.Textbox(label="Name", value="Lab Results", scale=1)
                            ch2_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch2_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="WBC 12,000, CRP elevated, ...")
                        with gr.Group():
                            gr.Markdown("**Channel 3:**")
                            ch3_name = gr.Textbox(label="Name", value="Imaging", scale=1)
                            ch3_type = gr.Dropdown(choices=["text", "image"], value="text", label="Type")
                            ch3_value = gr.Textbox(label="Content", lines=2,
                                                   placeholder="CT scan shows...")
                run_custom_btn = gr.Button("Run Agent on Custom Case", variant="primary")
                with gr.Row():
                    with gr.Column(scale=2):
                        custom_steps = gr.Markdown(
                            label="Reasoning Steps",
                            elem_classes="reasoning-box",
                        )
                    with gr.Column(scale=1):
                        custom_summary = gr.Markdown(label="Summary")
                        custom_entropy = gr.Markdown(label="Entropy Trajectory")
                run_custom_btn.click(
                    fn=run_custom_case,
                    inputs=[
                        custom_scenario, custom_candidates,
                        ch1_name, ch1_type, ch1_value,
                        ch2_name, ch2_type, ch2_value,
                        ch3_name, ch3_type, ch3_value,
                        custom_image,
                        backend, context_mode,
                        user_api_key,
                    ],
                    outputs=[custom_steps, custom_entropy, custom_summary],
                )
            # ---- Tab 3: How It Works ----
            with gr.TabItem("How It Works"):
                gr.Markdown("""
                ## ActiveMedAgent Architecture
                ### Tool-Use Acquisition Loop
                The agent uses native VLM function calling (not regex parsing) with two tools:
                1. **`request_information`** β Request one data channel, providing reasoning, current differential with calibrated probabilities, and expected impact
                2. **`commit_diagnosis`** β Submit final ranked diagnosis when confident
                ### No Budget Constraint
                The agent acquires as many channels as it needs (0 to all). It stops when:
                - It calls `commit_diagnosis` (self-determined confidence)
                - Information-theoretic stopping criteria trigger (convergence, confirmed dominance, or diminishing returns)
                - All channels are exhausted
                ### Information-Theoretic Metrics
                At each step, the system tracks:
                - **Shannon Entropy** H(p) β diagnostic uncertainty in bits
                - **Information Gain** β entropy reduction from each acquisition
                - **KL Divergence** β how much the belief distribution shifted
                - **Expected Information Gain (EIG)** β predicted value of the next channel
                - **Value of Information (VoI)** β whether continuing to acquire is worthwhile
                ### Context Management
                - **Full Mode**: Multi-turn conversation with complete history (for capable models)
                - **Condensed Mode**: Fresh single-turn call each step with compressed state log (for weaker models)
                - **Adaptive**: Auto-selects based on model capability
                ### Stopping Criteria
                1. **Convergence**: Last acquisition < 0.05 bits of IG
                2. **Confirmed Dominance**: Top diagnosis > 90% probability with > 40% gap (after 2+ acquisitions)
                3. **Diminishing Returns**: Last 2 acquisitions both < 0.1 bits IG
                """)
    return app
| # ============================================================ | |
| # Entry Point | |
| # ============================================================ | |
def main():
    """CLI entry point: parse arguments, build the demo app, and serve it.

    Flags:
        --port     TCP port for the Gradio server (default 7860).
        --backend  Preferred VLM backend name.
        --share    If set, create a public Gradio share link.
    """
    parser = argparse.ArgumentParser(description="ActiveMedAgent Interactive Demo")
    parser.add_argument("--port", type=int, default=7860, help="Port to serve on")
    # NOTE(review): --backend is parsed but never passed to create_app() or
    # used below; the in-UI backend dropdown defaults independently. Confirm
    # whether this flag should pre-select the dropdown.
    parser.add_argument("--backend", default="openai", choices=["openai", "anthropic", "together"])
    parser.add_argument("--share", action="store_true", help="Create a public Gradio link")
    args = parser.parse_args()

    app = create_app()
    # Bug fix: `theme` and `css` are gr.Blocks() constructor options, not
    # Blocks.launch() options β passing them to launch() raises TypeError
    # on modern Gradio, so they were removed from this call.
    app.launch(
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main()