yipengsun and Claude Opus 4.5 committed
Commit c0fff99 · 0 Parent(s)

Initial commit: Diagnostic Devil's Advocate project

Multi-agent medical diagnosis system using MedGemma, MedSigLIP, and MedASR
with LangGraph orchestration and Gradio UI.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,42 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg-info/
dist/
build/
*.egg

# Virtual environments
.venv/
venv/
env/

# IDE
.vscode/
.idea/
*.swp
*.swo

# Jupyter
.ipynb_checkpoints/

# Environment variables
.env
.env.*

# OS
.DS_Store
Thumbs.db

# Model weights / large files
*.bin
*.pt
*.pth
*.onnx
*.safetensors

# Logs
*.log
logs/
README.md ADDED
@@ -0,0 +1,268 @@
---
title: Diagnostic Devil's Advocate
emoji: 🩺
colorFrom: red
colorTo: blue
sdk: gradio
sdk_version: "5.12.0"
app_file: app.py
pinned: false
license: apache-2.0
tags:
  - medgemma
  - medical-imaging
  - multi-agent
  - cognitive-bias
  - radiology
---

<div align="center">

# 🩺 Diagnostic Devil's Advocate

### AI-Powered Cognitive Debiasing for Clinical Diagnosis

**A multi-agent system that challenges medical diagnoses to catch what doctors might miss.**

[![MedGemma](https://img.shields.io/badge/MedGemma-4B%20%7C%2027B-4285F4?style=for-the-badge&logo=google&logoColor=white)](https://huggingface.co/google/medgemma-1.5-4b-it)
[![MedSigLIP](https://img.shields.io/badge/MedSigLIP-448-34A853?style=for-the-badge&logo=google&logoColor=white)](https://huggingface.co/google/medsiglip-448)
[![LangGraph](https://img.shields.io/badge/LangGraph-Agent%20Pipeline-1C3C3C?style=for-the-badge&logo=langchain&logoColor=white)](https://langchain-ai.github.io/langgraph/)
[![Gradio](https://img.shields.io/badge/Gradio-UI-F97316?style=for-the-badge&logo=gradio&logoColor=white)](https://gradio.app)

[Live Demo](#getting-started) &bull; [Architecture](#architecture) &bull; [Demo Cases](#demo-cases) &bull; [Technical Details](#technical-details)

---

</div>

## The Problem

> *Diagnostic errors affect an estimated **12 million** adults annually in the U.S. alone, with cognitive biases — [anchoring](https://en.wikipedia.org/wiki/Anchoring_(cognitive_bias)), [premature closure](https://en.wikipedia.org/wiki/Premature_closure), [confirmation bias](https://en.wikipedia.org/wiki/Confirmation_bias) — implicated in up to **74%** of cases.* ([Singh et al., BMJ Quality & Safety, 2014](https://qualitysafety.bmj.com/content/23/9/727))

Doctors are not wrong because they lack knowledge. They are wrong because the human brain takes shortcuts — and in medicine, shortcuts kill. A physician who sees "young patient + chest pain after trauma" anchors on **rib contusion** and stops looking. The pneumothorax on the X-ray goes unseen. The patient deteriorates.

**Diagnostic Devil's Advocate** is a system that acts as an adversarial second opinion. It does not replace the physician — it challenges them. It asks: *"Have you considered what happens if you're wrong?"*

## How It Works

The system runs a **4-agent pipeline** orchestrated by [LangGraph](https://langchain-ai.github.io/langgraph/) where each agent has a distinct adversarial role. Every agent analyzes **both the medical image and the full clinical context** (history, vitals, labs, exam findings) — because some dangerous conditions (aortic dissection, pulmonary embolism) may show subtle or no imaging signs but have obvious clinical red flags. Critically, the first agent does this **without seeing the doctor's diagnosis**, preventing the AI itself from being [anchored](https://en.wikipedia.org/wiki/Anchoring_(cognitive_bias)).

### The Four Agents

| Agent | Role | Model | Key Design Choice |
|:------|:-----|:------|:------------------|
| **Diagnostician** | Independent image + clinical analysis | [MedGemma 4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) (multimodal) | **Blinded** — never sees the doctor's diagnosis. Tags each finding as `imaging`, `clinical`, or `both` to distinguish evidence sources. |
| **Bias Detector** | Compare doctor vs. AI findings | [MedGemma](https://huggingface.co/google/medgemma-1.5-4b-it) 4B/27B + [MedSigLIP](https://huggingface.co/google/medsiglip-448) | Uses **zero-shot image classification** to verify radiological signs. Flags clinical red flags ignored by either assessment. |
| **Devil's Advocate** | Adversarial challenge | [MedGemma](https://huggingface.co/google/medgemma-27b-text-it) 4B/27B | Deliberately contrarian — uses both imaging and clinical evidence to argue for **[must-not-miss diagnoses](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6775443/)** |
| **Consultant** | Synthesize final report | [MedGemma](https://huggingface.co/google/medgemma-27b-text-it) 4B/27B | Writes as a **collegial consultant**: *"Have you considered..."* not *"You are wrong."* |
58
+
59
+ ## Architecture
60
+
61
+ The pipeline is orchestrated by [LangGraph](https://langchain-ai.github.io/langgraph/) as a linear `StateGraph`:
62
+
63
+ **Gradio UI** (image upload, diagnosis input, clinical context, [MedASR](https://huggingface.co/google/medasr) voice input)
64
+ → **Diagnostician** — receives image + clinical context but **NOT** the doctor's diagnosis; tags findings by source (`imaging` / `clinical` / `both`)
65
+ → **Bias Detector** — now receives the doctor's diagnosis, compares it against independent findings using image, clinical data, and [MedSigLIP](https://huggingface.co/google/medsiglip-448) sign verification
66
+ → **Devil's Advocate** — challenges the working diagnosis using both imaging and clinical evidence for must-not-miss alternatives
67
+ → **Consultant** — synthesizes a collegial consultation note
68
+ → **Output** (consultation report, alternative diagnoses, recommended workup)
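
The linear flow above can be mimicked in a framework-free sketch. The agent functions here are hypothetical stand-ins for the real modules in `agents/` (which are wired into a LangGraph `StateGraph` in `agents/graph.py`); only the node order and the pass-the-state-forward contract come from this section.

```python
# Framework-free sketch of the four-node linear pipeline. The agent bodies
# are placeholders, not the project's actual implementations.
from typing import Callable

PipelineState = dict  # the project uses a TypedDict (agents/state.py)

def diagnostician(state: PipelineState) -> PipelineState:
    # Blinded: reads image + clinical context only, never the doctor's diagnosis.
    state["diagnostician_output"] = {"findings": "independent findings"}
    return state

def bias_detector(state: PipelineState) -> PipelineState:
    state["bias_detector_output"] = {"identified_biases": []}
    return state

def devils_advocate(state: PipelineState) -> PipelineState:
    state["devils_advocate_output"] = {"must_not_miss": []}
    return state

def consultant(state: PipelineState) -> PipelineState:
    state["consultant_output"] = {"consultation_note": "note"}
    return state

PIPELINE: list[Callable[[PipelineState], PipelineState]] = [
    diagnostician, bias_detector, devils_advocate, consultant,
]

def run_pipeline(state: PipelineState) -> PipelineState:
    # Linear graph: each node's output state feeds the next node.
    for node in PIPELINE:
        state = node(state)
        if state.get("error"):  # short-circuit on agent failure
            break
    return state
```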

### MedSigLIP Sign Verification

The Bias Detector doesn't just rely on text reasoning — it uses [**MedSigLIP-448**](https://huggingface.co/google/medsiglip-448) for objective visual verification. For each radiological sign mentioned by the Diagnostician (e.g., "pleural effusion", "cardiomegaly", "pneumothorax"), MedSigLIP performs [zero-shot binary classification](https://huggingface.co/tasks/zero-shot-image-classification): it compares the logits of `"chest radiograph showing [sign]"` vs `"normal chest radiograph with no [sign]"`. A logit difference > 2 is classified as "likely present", grounding the bias analysis in **visual evidence** rather than pure language reasoning.
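
The decision rule can be isolated from model inference as a small sketch. `build_prompts` and `classify_sign` are hypothetical helper names; the prompt template and the logit margin of 2 come from this section, and the "inconclusive" label mirrors the confidence value used in `agents/bias_detector.py`.

```python
# Zero-shot binary verification rule for one radiological sign.
def build_prompts(sign: str) -> tuple[str, str]:
    """Positive/negative text prompts compared by MedSigLIP."""
    return (
        f"chest radiograph showing {sign}",
        f"normal chest radiograph with no {sign}",
    )

def classify_sign(pos_logit: float, neg_logit: float, margin: float = 2.0) -> str:
    """Logit difference > margin -> 'likely present'; symmetric rule below margin."""
    diff = pos_logit - neg_logit
    if diff > margin:
        return "likely present"
    if diff < -margin:
        return "likely absent"
    return "inconclusive"
```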

## Demo Cases

Three composite clinical scenarios covering the most dangerous diagnostic error patterns:

<table>
<tr>
<td width="33%" valign="top">

### Case 1: Missed Pneumothorax
**🏷️ TRAUMA**

32M, motorcycle collision. Doctor diagnoses **rib contusion**, discharges patient. Supine CXR actually shows a **left pneumothorax** with rib fractures.

**Bias**: [Satisfaction of search](https://radiopaedia.org/articles/satisfaction-of-search) — found the rib fractures, stopped looking.

</td>
<td width="33%" valign="top">

### Case 2: Aortic Dissection → "GERD"
**🏷️ VASCULAR**

58M, hypertensive, tearing chest pain. Doctor diagnoses **acid reflux**, prescribes antacids. Blood pressure asymmetry (178/102 R vs 146/88 L) and D-dimer 4,850 suggest **Stanford type B dissection**.

**Bias**: [Anchoring](https://en.wikipedia.org/wiki/Anchoring_(cognitive_bias)) + [availability heuristic](https://en.wikipedia.org/wiki/Availability_heuristic) — common diagnosis assumed first.

</td>
<td width="33%" valign="top">

### Case 3: Postpartum PE → "Anxiety"
**🏷️ POSTPARTUM**

29F, day 5 post C-section, dyspnea and tachycardia. Doctor orders **psychiatric consult**. SpO2 91%, ABG shows respiratory alkalosis — classic **pulmonary embolism**.

**Bias**: [Premature closure](https://en.wikipedia.org/wiki/Premature_closure) + [framing effect](https://en.wikipedia.org/wiki/Framing_effect_(psychology)) — young woman = anxiety.

</td>
</tr>
</table>

> All cases are educational composites synthesized from published literature. See [`data/demo_cases/SOURCES.md`](data/demo_cases/SOURCES.md) for full citations.

## Technical Details

### Model Stack

| Model | Parameters | Role | Loading |
|:------|:----------|:-----|:--------|
| [MedGemma 1.5 4B-IT](https://huggingface.co/google/medgemma-1.5-4b-it) | 4B | Multimodal image+text analysis | 4-bit quantized (~4GB VRAM) or BF16 (~8GB) |
| [MedGemma 27B Text-IT](https://huggingface.co/google/medgemma-27b-text-it) | 27B | Advanced clinical reasoning | BF16 (~54GB VRAM), A100 only |
| [MedSigLIP-448](https://huggingface.co/google/medsiglip-448) | 0.9B | Zero-shot sign verification | FP32 (~3GB VRAM) |
| [MedASR](https://huggingface.co/google/medasr) | 105M | Medical speech-to-text | FP32 (~0.5GB VRAM) |

### Hardware Profiles

| Environment | GPU | Configuration | VRAM Usage |
|:------------|:----|:-------------|:-----------|
| **Local dev** | RTX 4070 12GB | 4B 4-bit + MedSigLIP + MedASR | ~7.5 GB |
| **School HPC** | A100 80GB | 4B BF16 + **27B BF16** + MedSigLIP + MedASR | ~66 GB |
| **HF Space** | T4 16GB | 4B 4-bit + MedSigLIP + MedASR | ~7.5 GB |
| **Kaggle** | T4 16GB | 4B 4-bit + MedSigLIP | ~7 GB |

All models load locally via [Transformers](https://huggingface.co/docs/transformers) with optional [4-bit quantization](https://huggingface.co/docs/bitsandbytes) — **zero API costs, fully offline-capable**.
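
For illustration, the quantized load path looks roughly like the standard Transformers + bitsandbytes pattern below. This is a configuration sketch only (it needs a GPU and access to the gated weights), and the exact auto-class and arguments used in `models/medgemma_client.py` may differ.

```python
# Sketch of the optional 4-bit load (assumed standard Transformers API;
# the auto-class name can vary with the installed transformers version).
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

MODEL_ID = "google/medgemma-1.5-4b-it"  # id taken from the table above

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,  # drop this for the BF16 path
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
```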

### Key Technical Decisions

- **Blinded Diagnostician**: The first agent never sees the doctor's diagnosis. This prevents the AI from anchoring on the same conclusion, enabling genuine independent analysis.

- **Dual-source analysis (imaging + clinical)**: All agents analyze both the medical image and the full clinical context (vitals, labs, risk factors). Each Diagnostician finding is tagged with its source (`imaging`, `clinical`, or `both`). This is critical because many must-not-miss diagnoses — aortic dissection (BP asymmetry), pulmonary embolism (low SpO2, elevated D-dimer) — may have subtle or absent imaging signs but glaring clinical red flags.

- **Structured JSON output**: All agents output structured JSON parsed by [`json_repair`](https://github.com/mangiucugna/json_repair), which handles LLM output quirks (missing commas, truncation, markdown wrapping).

- **Thinking token stripping**: MedGemma wraps internal reasoning in `<unused94>...<unused95>` tags ([model card](https://huggingface.co/google/medgemma-27b-text-it#thinking-mode)). These are stripped via regex before display.

- **Adaptive model routing**: `generate_text()` automatically routes to 27B when `USE_27B=true`, else falls back to 4B. `generate_with_image()` always uses 4B (only model with vision).

- **Collegial tone**: The Consultant is prompted to write as a consulting colleague, not a critic. Research shows physicians respond better to [collaborative challenge than confrontation](https://pubmed.ncbi.nlm.nih.gov/28493811/).

- **Prompt Repetition**: All agents use the prompt repetition technique from [*"Prompt Repetition Improves Non-Reasoning LLMs"*](https://arxiv.org/abs/2512.14982) (Google Research, 2025). The user prompt is repeated with a transition phrase (`<query> Let me repeat the request: <query>`), which won **47 out of 70** benchmark-model combinations with **zero losses** — at nearly zero cost (only increases prefill tokens, no extra generation). Controllable via `ENABLE_PROMPT_REPETITION` env var.
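
The repetition trick from the last bullet is small enough to sketch in full; `maybe_repeat` is an illustrative name, not necessarily the project's, but the transition phrase and the env-var gate come from this section.

```python
# Prompt repetition: repeat the user prompt after a transition phrase,
# gated by the ENABLE_PROMPT_REPETITION env var (default: enabled).
import os

TRANSITION = " Let me repeat the request: "

def maybe_repeat(prompt: str) -> str:
    """Return the (possibly repeated) prompt sent to the model."""
    enabled = os.getenv("ENABLE_PROMPT_REPETITION", "true").lower() == "true"
    if not enabled:
        return prompt
    # Only prefill grows; the generation budget is unchanged.
    return f"{prompt}{TRANSITION}{prompt}"
```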

## Getting Started

### Prerequisites

- Python 3.11+
- CUDA-capable GPU (12GB+ VRAM)
- [Hugging Face account](https://huggingface.co) with access to gated models (MedGemma, MedSigLIP, MedASR)

### Installation

```bash
# Clone the repository
git clone https://huggingface.co/spaces/YOUR_USERNAME/diagnostic-devils-advocate
cd diagnostic-devils-advocate

# Install dependencies
pip install -r requirements.txt

# Login to Hugging Face (required for gated models)
huggingface-cli login
```

### Running

```bash
# Standard launch (4B quantized, 12GB GPU)
python app.py

# With 27B reasoning model (A100 80GB required)
USE_27B=true QUANTIZE_4B=false python app.py

# Disable voice input
ENABLE_MEDASR=false python app.py
```

The app launches at `http://localhost:7860`.

### Environment Variables

| Variable | Default | Description |
|:---------|:--------|:------------|
| `USE_27B` | `false` | Enable 27B model for text-only agents |
| `QUANTIZE_4B` | `true` | 4-bit quantize the 4B model |
| `ENABLE_MEDASR` | `true` | Enable voice input via MedASR |
| `HF_TOKEN` | — | Hugging Face token (or use `huggingface-cli login`) |
| `ENABLE_PROMPT_REPETITION` | `true` | [Prompt repetition](https://arxiv.org/abs/2512.14982) for improved output quality |
| `MODEL_LOCAL_DIR` | — | Local directory for pre-downloaded models |
| `DEVICE` | `cuda` | Compute device |
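
A minimal sketch of how `config.py` might parse these flags. The boolean parsing rule is an assumption; only the variable names and defaults come from the table above.

```python
# Hypothetical env parsing along the lines of config.py.
import os

def env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment ('true'/'1'/'yes' => True)."""
    return os.getenv(name, str(default)).strip().lower() in {"1", "true", "yes"}

USE_27B = env_bool("USE_27B", False)          # route text agents to 27B
QUANTIZE_4B = env_bool("QUANTIZE_4B", True)   # 4-bit load for the 4B model
ENABLE_MEDASR = env_bool("ENABLE_MEDASR", True)
DEVICE = os.getenv("DEVICE", "cuda")
```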

## Project Structure

```
diagnostic-devils-advocate/
├── app.py                    # Gradio entry point
├── config.py                 # Model selection & environment config
├── requirements.txt

├── agents/
│   ├── state.py              # LangGraph TypedDict state definitions
│   ├── prompts.py            # All agent prompt templates
│   ├── graph.py              # LangGraph StateGraph pipeline
│   ├── output_parser.py      # JSON parsing with json_repair
│   ├── diagnostician.py      # Agent 1: Blinded image + clinical analysis
│   ├── bias_detector.py      # Agent 2: Bias detection + MedSigLIP
│   ├── devil_advocate.py     # Agent 3: Adversarial challenge
│   └── consultant.py         # Agent 4: Consultation note synthesis

├── models/
│   ├── medgemma_client.py    # MedGemma 4B/27B inference client
│   ├── medsiglip_client.py   # MedSigLIP zero-shot classification
│   ├── medasr_client.py      # MedASR speech-to-text
│   └── utils.py              # Image preprocessing, token stripping

├── ui/
│   ├── components.py         # Gradio layout & progress visualization
│   ├── callbacks.py          # UI event handlers & pipeline integration
│   └── css.py                # Custom styling (responsive design)

├── data/
│   └── demo_cases/           # 3 composite clinical scenarios
│       └── SOURCES.md        # Full literature citations

└── tests/
    ├── test_smoke.py         # Import & build verification
    ├── test_output_parser.py # JSON repair tests
    └── test_pipeline_mock.py # Integration tests with mocked models
```

## Testing

```bash
python -m pytest tests/ -v
```

## Disclaimer

> **This is a research prototype built for the MedGemma Impact Challenge. It is NOT intended for clinical decision-making.** All demo cases are educational composites. Medical images are sourced from the University of Saskatchewan Teaching Collection (CC-BY-NC-SA 4.0).

## References

- Singh H, et al. "The frequency of diagnostic errors in outpatient care." [*BMJ Quality & Safety*, 2014](https://qualitysafety.bmj.com/content/23/9/727)
- Graber ML, et al. "Cognitive interventions to reduce diagnostic error." [*BMJ Quality & Safety*, 2012](https://qualitysafety.bmj.com/content/21/7/535)
- Croskerry P. "The importance of cognitive errors in diagnosis." [*Academic Medicine*, 2003](https://pubmed.ncbi.nlm.nih.gov/12915371/)
- Ball CG, et al. "Incidence, risk factors, and outcomes for occult pneumothoraces." [*J Trauma*, 2005](https://pubmed.ncbi.nlm.nih.gov/16374282/)
- Hansen MS, et al. "Frequency of misdiagnosis of acute aortic dissection." [*Am J Cardiol*, 2007](https://pubmed.ncbi.nlm.nih.gov/17350380/)
- Ivgi M, et al. "Prompt Repetition Improves Non-Reasoning LLMs." [*arXiv:2512.14982*](https://arxiv.org/abs/2512.14982), Google Research, 2025
- Google Health AI. [Health AI Developer Foundations (HAI-DEF)](https://developers.google.com/health-ai)
- Yang J, et al. [MedGemma: Medical AI model](https://huggingface.co/collections/google/health-ai-developer-foundations-68544906f8a0a10f7d30ade8) — Hugging Face Collection

---

<div align="center">

Built with [Google Health AI Developer Foundations](https://developers.google.com/health-ai) for the [MedGemma Impact Challenge](https://www.kaggle.com/competitions/medgemma-impact-challenge)

</div>
agents/__init__.py ADDED
File without changes
agents/bias_detector.py ADDED
@@ -0,0 +1,135 @@
"""
Bias Detector agent: compares doctor's diagnosis with independent analysis to identify cognitive biases.
Runs MedSigLIP sign verification on imaging findings mentioned by the Diagnostician.
Outputs structured JSON.
"""

import re
import logging

from agents.state import PipelineState
from agents.prompts import BIAS_DETECTOR_SYSTEM, BIAS_DETECTOR_USER
from agents.output_parser import parse_json_response
from models import medgemma_client, medsiglip_client

logger = logging.getLogger(__name__)

# Common imaging signs that SigLIP can meaningfully evaluate on chest X-ray.
# These are visual patterns, not abstract diagnoses.
_KNOWN_SIGNS = [
    "pleural effusion", "consolidation", "infiltrates", "pneumothorax",
    "widened mediastinum", "cardiomegaly", "pulmonary edema", "atelectasis",
    "rib fracture", "subcutaneous emphysema", "hilar enlargement",
    "hyperinflation", "pleural thickening", "lung opacity", "air bronchogram",
    "mediastinal shift", "tracheal deviation", "cephalization",
]


def _extract_signs(findings: object) -> list[str]:
    """Extract imaging signs mentioned in the Diagnostician's findings.

    Matches against known radiological signs rather than parsing diagnoses.
    """
    if isinstance(findings, list):
        chunks: list[str] = []
        for item in findings:
            if isinstance(item, dict):
                chunks.append(str(item.get("finding", "")))
                chunks.append(str(item.get("description", "")))
            else:
                chunks.append(str(item))
        findings_text = "\n".join(chunks)
    else:
        findings_text = str(findings)

    findings_lower = findings_text.lower()
    found = []
    for sign in _KNOWN_SIGNS:
        if sign in findings_lower:
            found.append(sign)

    # Also extract any explicit "abnormal" findings with simple patterns,
    # e.g., "visible pleural line", "blunted costophrenic angle".
    extra_patterns = [
        r'(?:visible|subtle|small|large|bilateral|unilateral|left|right)\s+([\w\s]{5,30}?)(?:\.|,|;|\n)',
    ]
    for pat in extra_patterns:
        for m in re.findall(pat, findings_lower):
            cleaned = m.strip()
            if cleaned not in found and len(cleaned) > 5:
                found.append(cleaned)

    # Deduplicate, limit to 8.
    seen = set()
    unique = []
    for s in found:
        if s not in seen:
            seen.add(s)
            unique.append(s)
    return unique[:8]


def run(state: PipelineState) -> PipelineState:
    """Run the Bias Detector agent."""
    state["current_step"] = "bias_detector"
    clinical = state["clinical_input"]
    diag_out = state.get("diagnostician_output")

    if diag_out is None:
        state["error"] = "Diagnostician output missing."
        return state

    try:
        # 1. MedSigLIP: verify imaging signs mentioned in findings.
        sign_verification = []
        image = clinical.get("image")
        if image is not None:
            signs = _extract_signs(diag_out.get("findings_list") or diag_out.get("findings", ""))
            logger.info("Extracted signs for SigLIP verification: %s", signs)
            if signs:
                sign_verification = medsiglip_client.verify_findings(
                    image,
                    signs,
                    modality=clinical.get("modality"),
                )

        # 2. MedGemma: cognitive bias analysis (with image if available).
        diagnostician_analysis = diag_out.get("analysis") or diag_out.get("findings", "")
        prompt = BIAS_DETECTOR_USER.format(
            doctor_diagnosis=clinical["doctor_diagnosis"],
            clinical_context=clinical["clinical_context"],
            diagnostician_findings=diagnostician_analysis,
            consistency_check=_format_sign_verification(sign_verification),
        )
        if image is not None:
            raw = medgemma_client.generate_with_image(prompt, image, system_prompt=BIAS_DETECTOR_SYSTEM)
        else:
            raw = medgemma_client.generate_text(prompt, system_prompt=BIAS_DETECTOR_SYSTEM)
        parsed = parse_json_response(raw)
        state["bias_detector_output"] = {
            "identified_biases": parsed.get("identified_biases", []),
            "discrepancy_summary": parsed.get("discrepancy_summary", ""),
            "missed_findings": parsed.get("missed_findings", []),
            "consistency_check": sign_verification,
        }

    except Exception as e:
        logger.exception("Bias Detector agent failed")
        state["error"] = f"Bias Detector error: {e}"

    return state


def _format_sign_verification(results: list[dict]) -> str:
    """Format sign verification results as text for the MedGemma prompt."""
    if not results:
        return "No image verification available."

    # Only include non-inconclusive results.
    meaningful = [r for r in results if r.get("confidence") != "inconclusive"]
    if not meaningful:
        return "Image verification inconclusive for all findings."

    lines = ["Image sign verification (MedSigLIP):"]
    for r in meaningful:
        lines.append(f"- {r['sign']}: {r['confidence']}")
    return "\n".join(lines)
agents/consultant.py ADDED
@@ -0,0 +1,94 @@
"""
Consultant agent: synthesizes all upstream outputs into a collegial debiasing report.
Outputs structured JSON.
"""

import json
import logging

from agents.state import PipelineState
from agents.prompts import CONSULTANT_SYSTEM, CONSULTANT_USER
from agents.output_parser import parse_json_response
from models import medgemma_client

logger = logging.getLogger(__name__)


def _format_bias_report(bias_out: dict) -> str:
    """Format bias detector output as text for the Consultant prompt."""
    parts = []
    if bias_out.get("discrepancy_summary"):
        parts.append(f"Discrepancy: {bias_out['discrepancy_summary']}")
    for b in bias_out.get("identified_biases", []):
        parts.append(f"- [{b.get('severity', '?').upper()}] {b.get('type', '')}: {b.get('evidence', '')}")
    if bias_out.get("missed_findings"):
        parts.append(f"Missed: {', '.join(bias_out['missed_findings'])}")
    return "\n".join(parts) if parts else "No bias data."


def _format_da_report(da_out: dict) -> str:
    """Format devil's advocate output as text for the Consultant prompt."""
    parts = []
    for c in da_out.get("challenges", []):
        parts.append(f"Challenge: {c.get('claim', '')} → {c.get('counter_evidence', '')}")
    for m in da_out.get("must_not_miss", []):
        parts.append(f"MUST-NOT-MISS: {m.get('diagnosis', '')} — {m.get('why_dangerous', '')}")
    if da_out.get("recommended_workup"):
        items = [str(w) if not isinstance(w, dict) else w.get("test", str(w)) for w in da_out["recommended_workup"]]
        parts.append("Workup: " + ", ".join(items))
    return "\n".join(parts) if parts else "No challenges raised."


def run(state: PipelineState) -> PipelineState:
    """Run the Consultant agent."""
    state["current_step"] = "consultant"
    clinical = state["clinical_input"]
    diag_out = state.get("diagnostician_output")
    bias_out = state.get("bias_detector_output")
    da_out = state.get("devils_advocate_output")

    if diag_out is None or bias_out is None or da_out is None:
        state["error"] = "Missing upstream agent outputs."
        return state

    try:
        diagnostician_analysis = diag_out.get("analysis") or diag_out.get("findings", "")
        prompt = CONSULTANT_USER.format(
            doctor_diagnosis=clinical["doctor_diagnosis"],
            clinical_context=clinical["clinical_context"],
            diagnostician_findings=diagnostician_analysis,
            bias_report=_format_bias_report(bias_out),
            devil_advocate_report=_format_da_report(da_out),
            similar_cases="Not available.",
        )
        raw = medgemma_client.generate_text(prompt, system_prompt=CONSULTANT_SYSTEM)
        parsed = parse_json_response(raw)

        alternative_diagnoses = parsed.get("alternative_diagnoses", [])
        if isinstance(alternative_diagnoses, str):
            try:
                alternative_diagnoses = json.loads(alternative_diagnoses)
            except json.JSONDecodeError:
                alternative_diagnoses = []
        if not isinstance(alternative_diagnoses, list):
            alternative_diagnoses = []

        immediate_actions = parsed.get("immediate_actions", [])
        if isinstance(immediate_actions, str):
            immediate_actions = [immediate_actions]
        if not isinstance(immediate_actions, list):
            immediate_actions = []
        immediate_actions = [str(x).strip() for x in immediate_actions if str(x).strip()]

        state["consultant_output"] = {
            "consultation_note": parsed.get("consultation_note", ""),
            "alternative_diagnoses": alternative_diagnoses,
            "immediate_actions": immediate_actions,
            "confidence_note": parsed.get("confidence_note", ""),
        }

    except Exception as e:
        logger.exception("Consultant agent failed")
        state["error"] = f"Consultant error: {e}"

    return state
agents/devil_advocate.py ADDED
@@ -0,0 +1,313 @@
"""
Devil's Advocate agent: adversarial challenge to the working diagnosis.
Deliberately contrarian — focuses on must-not-miss diagnoses.
Uses MedGemma 4B (multimodal) to independently examine the image.
Outputs structured JSON.
"""

import json
import logging
from collections.abc import Mapping

from agents.state import PipelineState
from agents.prompts import DEVIL_ADVOCATE_SYSTEM, DEVIL_ADVOCATE_USER
from agents.output_parser import parse_json_response
from models import medgemma_client

logger = logging.getLogger(__name__)


_DA_SCHEMA_KEYS = ("challenges", "must_not_miss", "recommended_workup")
_DA_WRAPPER_KEYS = (
    "devils_advocate_output",
    "devil_advocate_output",
    "devil_advocate",
    "output",
    "response",
    "result",
    "data",
)
_DA_SYNONYMS: dict[str, str] = {
    # must-not-miss
    "must_not_miss_diagnoses": "must_not_miss",
    "must_not_miss_differentials": "must_not_miss",
    "dangerous_alternatives": "must_not_miss",
    "critical_differentials": "must_not_miss",
    # workup
    "workup": "recommended_workup",
    "recommended_tests": "recommended_workup",
    "recommended_actions": "recommended_workup",
    "next_steps": "recommended_workup",
    # challenges
    "challenge": "challenges",
    "concerns": "challenges",
    "counterarguments": "challenges",
}


def _format_bias_summary(bias_out: dict) -> str:
    """Format bias detector output for the Devil's Advocate prompt."""
    parts = []
    if bias_out.get("discrepancy_summary"):
        parts.append(bias_out["discrepancy_summary"])
    for b in bias_out.get("identified_biases", []):
        parts.append(f"- {b.get('type', 'unknown')}: {b.get('evidence', '')} (severity: {b.get('severity', '?')})")
    if bias_out.get("missed_findings"):
        parts.append("Missed findings: " + ", ".join(bias_out["missed_findings"]))
    return "\n".join(parts) if parts else "No bias analysis available."


def _unwrap_da_payload(parsed: dict) -> dict:
    """Unwrap common container shapes: {"output": {...}}, {"result": {...}}, etc."""
    if any(k in parsed for k in _DA_SCHEMA_KEYS):
        return parsed

    for key in _DA_WRAPPER_KEYS:
        inner = parsed.get(key)
        if isinstance(inner, Mapping) and any(k in inner for k in _DA_SCHEMA_KEYS):
            return dict(inner)

    # If there's a single nested object, unwrap it if it contains DA keys.
    if len(parsed) == 1:
        only_value = next(iter(parsed.values()))
        if isinstance(only_value, Mapping) and any(k in only_value for k in _DA_SCHEMA_KEYS):
            return dict(only_value)

    # One-level scan for any nested object that contains DA keys.
    for value in parsed.values():
        if isinstance(value, Mapping) and any(k in value for k in _DA_SCHEMA_KEYS):
            return dict(value)

    return parsed


def _coerce_da_schema(parsed: dict) -> dict:
    """Best-effort normalization when the model returns an unexpected top-level JSON shape."""
    if not isinstance(parsed, dict):
        return {}

    parsed = _unwrap_da_payload(parsed)
    if not isinstance(parsed, dict):
        return {}

    # Map common synonym keys onto the expected schema.
    coerced = dict(parsed)
    for src, dst in _DA_SYNONYMS.items():
        if src in coerced and dst not in coerced:
            coerced[dst] = coerced[src]

    if any(k in coerced for k in _DA_SCHEMA_KEYS):
        return coerced

    items = coerced.get("items")
    if not isinstance(items, list) or not items:
        return coerced

    # If the model returned just a list of strings, treat it as a workup list.
    if all(isinstance(x, str) for x in items):
        return {"recommended_workup": items}

    dict_items = [x for x in items if isinstance(x, dict)]
    if len(dict_items) != len(items):
        return parsed

    keys: set[str] = set()
    for d in dict_items[:5]:
        keys.update(d.keys())

    if "claim" in keys or "counter_evidence" in keys:
        return {"challenges": dict_items}
    if {"why_dangerous", "supporting_signs", "rule_out_test"} & keys or "diagnosis" in keys:
        return {"must_not_miss": dict_items}

    return coerced


def _normalize_challenges(value: object) -> list[dict[str, str]]:
    if value is None:
        return []

    items = [value] if isinstance(value, Mapping) else value
    if isinstance(items, str):
        s = items.strip()
        return [{"claim": s, "counter_evidence": ""}] if s else []
    if not isinstance(items, list):
        return []

    out: list[dict[str, str]] = []
    for item in items:
        if item is None:
            continue
        if isinstance(item, Mapping):
            d = dict(item)
            claim = str(d.get("claim") or d.get("challenge") or d.get("concern") or "").strip()
            counter = str(
                d.get("counter_evidence")
                or d.get("counterevidence")
                or d.get("counter_argument")
                or d.get("counterargument")
                or d.get("counter")
                or d.get("evidence_against")
                or ""
            ).strip()
            if claim or counter:
                out.append({"claim": claim, "counter_evidence": counter})
            continue

        s = str(item).strip()
        if s:
            out.append({"claim": s, "counter_evidence": ""})

    return out


def _normalize_must_not_miss(value: object) -> list[dict[str, str]]:
    if value is None:
        return []

    items = [value] if isinstance(value, Mapping) else value
    if isinstance(items, str):
        s = items.strip()
        return [{"diagnosis": s}] if s else []
    if not isinstance(items, list):
        return []

    out: list[dict[str, str]] = []
    for item in items:
        if item is None:
            continue
        if isinstance(item, Mapping):
179
+ d = dict(item)
180
+ diagnosis = str(d.get("diagnosis") or d.get("dx") or d.get("differential") or "").strip()
181
+ why = str(d.get("why_dangerous") or d.get("why") or d.get("danger") or "").strip()
182
+ signs = str(d.get("supporting_signs") or d.get("evidence") or d.get("support") or "").strip()
183
+ test = str(d.get("rule_out_test") or d.get("test") or d.get("rule_out") or "").strip()
184
+ if diagnosis or why or signs or test:
185
+ out.append(
186
+ {
187
+ "diagnosis": diagnosis,
188
+ "why_dangerous": why,
189
+ "supporting_signs": signs,
190
+ "rule_out_test": test,
191
+ }
192
+ )
193
+ continue
194
+
195
+ s = str(item).strip()
196
+ if s:
197
+ out.append({"diagnosis": s})
198
+
199
+ return out
200
+
201
+
202
+ def run(state: PipelineState) -> PipelineState:
203
+ """Run the Devil's Advocate agent."""
204
+ state["current_step"] = "devil_advocate"
205
+ clinical = state["clinical_input"]
206
+ diag_out = state.get("diagnostician_output")
207
+ bias_out = state.get("bias_detector_output")
208
+
209
+ image = clinical.get("image")
210
+
211
+ if diag_out is None or bias_out is None:
212
+ state["error"] = "Missing upstream agent outputs."
213
+ return state
214
+
215
+ if image is None:
216
+ state["error"] = "No image provided for Devil's Advocate."
217
+ return state
218
+
219
+ try:
220
+ diagnostician_analysis = diag_out.get("analysis") or diag_out.get("findings", "")
221
+ prompt = DEVIL_ADVOCATE_USER.format(
222
+ doctor_diagnosis=clinical["doctor_diagnosis"],
223
+ clinical_context=clinical["clinical_context"],
224
+ diagnostician_findings=diagnostician_analysis,
225
+ bias_summary=_format_bias_summary(bias_out),
226
+ )
227
+ system_prompt = DEVIL_ADVOCATE_SYSTEM
228
+ raw = medgemma_client.generate_with_image(prompt, image, system_prompt=system_prompt)
229
+ parsed = _coerce_da_schema(parse_json_response(raw))
230
+
231
+ challenges = _normalize_challenges(parsed.get("challenges"))
232
+ must_not_miss = _normalize_must_not_miss(parsed.get("must_not_miss"))
233
+ workup_raw = parsed.get("recommended_workup", [])
234
+ normalized_workup: list[str] = []
235
+ if isinstance(workup_raw, str):
236
+ # Split a single workup string into bullet-like entries.
237
+ workup_raw = [x.strip(" -\t") for x in workup_raw.replace(";", "\n").splitlines()]
238
+ if isinstance(workup_raw, Mapping):
239
+ workup_raw = [dict(workup_raw)]
240
+ if isinstance(workup_raw, list):
241
+ for item in workup_raw:
242
+ if item is None:
243
+ continue
244
+ if isinstance(item, str):
245
+ s = item.strip()
246
+ elif isinstance(item, dict):
247
+ s = str(
248
+ item.get("test")
249
+ or item.get("name")
250
+ or item.get("action")
251
+ or item.get("workup")
252
+ or ""
253
+ ).strip()
254
+ if not s:
255
+ s = json.dumps(item, ensure_ascii=False)
256
+ else:
257
+ s = str(item).strip()
258
+ if s:
259
+ normalized_workup.append(s)
260
+ # Deduplicate while preserving order.
261
+ normalized_workup = list(dict.fromkeys(normalized_workup))
262
+
263
+ # If the model returned an empty schema, retry once with a stricter instruction.
264
+ if not (challenges or must_not_miss or normalized_workup):
265
+ logger.warning("Devil's Advocate produced empty structured output; retrying once.")
266
+ strict_system = (
267
+ DEVIL_ADVOCATE_SYSTEM
268
+ + "\n\nIMPORTANT: Do not return empty arrays. Provide at least 1 item in each list, "
269
+ + "even if you must express uncertainty and suggest rule-out testing."
270
+ )
271
+ raw_retry = medgemma_client.generate_with_image(prompt, image, system_prompt=strict_system)
272
+ parsed_retry = _coerce_da_schema(parse_json_response(raw_retry))
273
+ challenges = _normalize_challenges(parsed_retry.get("challenges"))
274
+ must_not_miss = _normalize_must_not_miss(parsed_retry.get("must_not_miss"))
275
+ workup_retry = parsed_retry.get("recommended_workup", [])
276
+ normalized_workup = []
277
+ if isinstance(workup_retry, str):
278
+ workup_retry = [x.strip(" -\t") for x in workup_retry.replace(";", "\n").splitlines()]
279
+ if isinstance(workup_retry, Mapping):
280
+ workup_retry = [dict(workup_retry)]
281
+ if isinstance(workup_retry, list):
282
+ for item in workup_retry:
283
+ if item is None:
284
+ continue
285
+ if isinstance(item, str):
286
+ s = item.strip()
287
+ elif isinstance(item, dict):
288
+ s = str(
289
+ item.get("test")
290
+ or item.get("name")
291
+ or item.get("action")
292
+ or item.get("workup")
293
+ or ""
294
+ ).strip()
295
+ if not s:
296
+ s = json.dumps(item, ensure_ascii=False)
297
+ else:
298
+ s = str(item).strip()
299
+ if s:
300
+ normalized_workup.append(s)
301
+ normalized_workup = list(dict.fromkeys(normalized_workup))
302
+
303
+ state["devils_advocate_output"] = {
304
+ "challenges": challenges,
305
+ "must_not_miss": must_not_miss,
306
+ "recommended_workup": normalized_workup,
307
+ }
308
+
309
+ except Exception as e:
310
+ logger.exception("Devil's Advocate agent failed")
311
+ state["error"] = f"Devil's Advocate error: {e}"
312
+
313
+ return state
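The synonym mapping in `_coerce_da_schema` only fills in a schema key when it is absent, so a well-formed response is never clobbered. The core move, sketched in isolation with a hypothetical synonym table (the real module defines `_DA_SYNONYMS` and `_DA_SCHEMA_KEYS`, which are not shown in this diff):

```python
# Illustrative only — not the module's actual synonym table.
SYNONYMS = {"workup": "recommended_workup", "red_flags": "must_not_miss"}


def coerce(parsed: dict) -> dict:
    """Map synonym keys onto the expected schema without clobbering existing keys."""
    out = dict(parsed)
    for src, dst in SYNONYMS.items():
        if src in out and dst not in out:
            out[dst] = out[src]
    return out


result = coerce({"workup": ["CT angiogram"], "recommended_workup": ["D-dimer"]})
print(result["recommended_workup"])  # ['D-dimer'] — the existing schema key wins
```

Because the guard checks `dst not in out`, a response that already uses the canonical key passes through untouched.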
agents/diagnostician.py ADDED
@@ -0,0 +1,91 @@
"""
Diagnostician agent: independent image analysis WITHOUT seeing the doctor's diagnosis.
Uses MedGemma 4B (multimodal) for detailed radiological analysis.
Outputs structured JSON.
"""

import logging
from agents.state import PipelineState
from agents.prompts import DIAGNOSTICIAN_SYSTEM, DIAGNOSTICIAN_USER
from agents.output_parser import parse_json_response
from models import medgemma_client

logger = logging.getLogger(__name__)


def run(state: PipelineState) -> PipelineState:
    """Run the Diagnostician agent."""
    state["current_step"] = "diagnostician"
    clinical = state["clinical_input"]
    image = clinical.get("image")

    if image is None:
        state["error"] = "No image provided."
        return state

    try:
        prompt = DIAGNOSTICIAN_USER.format(clinical_context=clinical["clinical_context"])
        raw = medgemma_client.generate_with_image(prompt, image, system_prompt=DIAGNOSTICIAN_SYSTEM)
        parsed = parse_json_response(raw)
        findings = parsed.get("findings", [])
        differentials = parsed.get("differential_diagnoses", [])
        if not isinstance(findings, list):
            findings = [findings] if findings else []
        if not isinstance(differentials, list):
            differentials = [differentials] if differentials else []

        findings_lines: list[str] = []
        for f in findings:
            if isinstance(f, dict):
                name = str(f.get("finding", "")).strip()
                desc = str(f.get("description", "")).strip()
                source = str(f.get("source", "")).strip()
                source_tag = f" [{source}]" if source else ""
                if name and desc:
                    findings_lines.append(f"- {name}{source_tag}: {desc}")
                elif name:
                    findings_lines.append(f"- {name}{source_tag}")
                elif desc:
                    findings_lines.append(f"- {desc}")
            else:
                s = str(f).strip()
                if s:
                    findings_lines.append(f"- {s}")

        differential_lines: list[str] = []
        for d in differentials:
            if isinstance(d, dict):
                name = str(d.get("diagnosis", "")).strip()
                reasoning = str(d.get("reasoning", "")).strip()
                if name and reasoning:
                    differential_lines.append(f"- {name}: {reasoning}")
                elif name:
                    differential_lines.append(f"- {name}")
                elif reasoning:
                    differential_lines.append(f"- {reasoning}")
            else:
                s = str(d).strip()
                if s:
                    differential_lines.append(f"- {s}")

        findings_text = "\n".join(findings_lines)
        differentials_text = "\n".join(differential_lines)
        analysis_parts: list[str] = []
        if findings_text:
            analysis_parts.append("Findings:\n" + findings_text)
        if differentials_text:
            analysis_parts.append("Differential diagnoses:\n" + differentials_text)
        analysis_text = "\n\n".join(analysis_parts).strip()
        state["diagnostician_output"] = {
            "analysis": analysis_text,
            "findings": findings_text,
            "findings_list": findings,
            "differential_diagnoses": differentials,
            "differentials_text": differentials_text,
        }

    except Exception as e:
        logger.exception("Diagnostician agent failed")
        state["error"] = f"Diagnostician error: {e}"

    return state
agents/graph.py ADDED
@@ -0,0 +1,145 @@
"""
LangGraph pipeline: linear flow through 4 diagnostic agents.

START → diagnostician → bias_detector → devil_advocate → consultant → END
"""

import logging
import threading

from agents.state import PipelineState
from agents import diagnostician, bias_detector, devil_advocate, consultant

logger = logging.getLogger(__name__)

try:
    from langgraph.graph import StateGraph, START, END

    _LANGGRAPH_AVAILABLE = True
except ModuleNotFoundError:
    StateGraph = None  # type: ignore[assignment]
    START = END = None
    _LANGGRAPH_AVAILABLE = False


def _check_error(state: PipelineState) -> str:
    """Route to END if an error occurred, otherwise continue."""
    if state.get("error"):
        return "end"
    return "continue"


class _FallbackGraph:
    def invoke(self, initial_state: PipelineState) -> PipelineState:
        state = initial_state
        for fn in (diagnostician.run, bias_detector.run, devil_advocate.run, consultant.run):
            state = fn(state)
            if state.get("error"):
                break
        return state

    def stream(self, initial_state: PipelineState, stream_mode: str = "updates"):
        state = initial_state
        for name, fn in (
            ("diagnostician", diagnostician.run),
            ("bias_detector", bias_detector.run),
            ("devil_advocate", devil_advocate.run),
            ("consultant", consultant.run),
        ):
            state = fn(state)
            yield {name: dict(state)}
            if state.get("error"):
                break


def build_graph():
    """Build and compile the diagnostic debiasing pipeline."""
    if not _LANGGRAPH_AVAILABLE:
        logger.warning("langgraph is not installed; falling back to a simple sequential pipeline.")
        return _FallbackGraph()

    graph = StateGraph(PipelineState)

    # Add nodes
    graph.add_node("diagnostician", diagnostician.run)
    graph.add_node("bias_detector", bias_detector.run)
    graph.add_node("devil_advocate", devil_advocate.run)
    graph.add_node("consultant", consultant.run)

    # Linear flow with error checking
    graph.add_edge(START, "diagnostician")
    graph.add_conditional_edges("diagnostician", _check_error, {"continue": "bias_detector", "end": END})
    graph.add_conditional_edges("bias_detector", _check_error, {"continue": "devil_advocate", "end": END})
    graph.add_conditional_edges("devil_advocate", _check_error, {"continue": "consultant", "end": END})
    graph.add_edge("consultant", END)

    return graph.compile()


# Singleton compiled graph
_compiled_graph = None
_compiled_graph_lock = threading.Lock()


def get_graph():
    """Get or create the compiled pipeline graph."""
    global _compiled_graph
    if _compiled_graph is not None:
        return _compiled_graph
    with _compiled_graph_lock:
        if _compiled_graph is None:
            _compiled_graph = build_graph()
        return _compiled_graph


def _make_initial_state(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    return {
        "clinical_input": {
            "image": image,
            "doctor_diagnosis": doctor_diagnosis,
            "clinical_context": clinical_context,
            "modality": modality or "CXR",
        },
        "diagnostician_output": None,
        "bias_detector_output": None,
        "devils_advocate_output": None,
        "consultant_output": None,
        "current_step": "start",
        "error": None,
    }


def run_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
) -> PipelineState:
    """Run the full debiasing pipeline (blocking)."""
    graph = get_graph()
    initial_state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)
    return graph.invoke(initial_state)


def stream_pipeline(
    image,
    doctor_diagnosis: str,
    clinical_context: str,
    modality: str | None = None,
):
    """
    Stream the pipeline, yielding (node_name, state) after each agent completes.
    Use this for progressive UI updates.
    """
    graph = get_graph()
    initial_state = _make_initial_state(image, doctor_diagnosis, clinical_context, modality=modality)

    for event in graph.stream(initial_state, stream_mode="updates"):
        # event is {node_name: state_update}
        for node_name, state_update in event.items():
            yield node_name, state_update
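When langgraph is unavailable, `_FallbackGraph` degrades the pipeline to a plain sequential loop that threads one mutable state dict through the agents and halts at the first error. A minimal self-contained sketch of that pattern, with stub agents standing in for the real model-backed ones:

```python
from typing import Callable

State = dict


def diagnostician(state: State) -> State:
    state["diagnostician_output"] = "findings"
    return state


def bias_detector(state: State) -> State:
    state["error"] = "model unavailable"  # simulate a mid-pipeline failure
    return state


def devil_advocate(state: State) -> State:
    state["devils_advocate_output"] = "challenges"
    return state


def run_sequential(state: State, agents: list[Callable[[State], State]]) -> State:
    # Mirrors _FallbackGraph.invoke: run agents in order, stop on the first error.
    for agent in agents:
        state = agent(state)
        if state.get("error"):
            break
    return state


final = run_sequential({"error": None}, [diagnostician, bias_detector, devil_advocate])
print(final.get("diagnostician_output"))        # findings
print("devils_advocate_output" in final)        # False — pipeline stopped at the error
```

The error check after each step is what lets the conditional edges in the real LangGraph build and this fallback behave identically from the caller's point of view.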
agents/output_parser.py ADDED
@@ -0,0 +1,124 @@
"""
JSON output parser for LLM responses.
Uses json_repair to handle malformed JSON (missing commas, truncation, extra text, etc.).
"""

import logging
from collections.abc import Mapping
from json_repair import repair_json

logger = logging.getLogger(__name__)

_TOP_LEVEL_KEYS = {
    # Diagnostician
    "findings",
    "differential_diagnoses",
    # Bias detector
    "discrepancy_summary",
    "identified_biases",
    "missed_findings",
    "agreement_points",
    # Devil's advocate
    "challenges",
    "must_not_miss",
    "recommended_workup",
    # Consultant
    "consultation_note",
    "alternative_diagnoses",
    "immediate_actions",
    "confidence_note",
}


def parse_json_response(text: str) -> dict:
    """
    Extract and repair JSON from an LLM response.
    Handles: raw JSON, ```json blocks, missing commas, truncated output, etc.
    Returns parsed dict. Raises ValueError if repair fails completely.
    """
    result = repair_json(text, return_objects=True)

    # Typical (desired) case: top-level object.
    if isinstance(result, Mapping):
        return dict(result)

    # Some model outputs come back as a top-level array. Coerce to a dict so
    # downstream code can continue, while preserving the payload for callers to
    # interpret (via 'items') when schema keys are missing.
    if isinstance(result, list):
        return _coerce_list_root(result)

    raise ValueError(
        f"Could not parse JSON from LLM output (got {type(result).__name__}, length={len(text)})"
    )


def _coerce_list_root(items: list) -> dict:
    if not items:
        return {"items": []}

    mapping_items = [x for x in items if isinstance(x, Mapping)]
    if not mapping_items:
        return {"items": items}

    merged: dict = {}
    contains_top_level_key = False
    for m in mapping_items:
        d = dict(m)
        contains_top_level_key = contains_top_level_key or bool(_TOP_LEVEL_KEYS.intersection(d.keys()))
        merged.update(d)

    # If the extracted objects already contain known top-level schema keys, it's
    # likely a wrapped/duplicated object (or multiple partial objects). Merge.
    if contains_top_level_key:
        return merged

    all_mappings = len(mapping_items) == len(items)
    if all_mappings:
        # Distinguish between (a) a true list of repeated schema items, vs (b)
        # multiple standalone JSON objects extracted from a noisy response.
        key_sets = [set(dict(m).keys()) for m in mapping_items[:10]]
        union = set().union(*key_sets)
        intersection = set(key_sets[0]).intersection(*key_sets[1:]) if len(key_sets) > 1 else set(key_sets[0])
        overlap_ratio = (len(intersection) / len(union)) if union else 0.0

        if len(items) == 1 or overlap_ratio >= 0.35:
            inferred_key = _infer_list_container_key(mapping_items)
            if inferred_key:
                return {inferred_key: [dict(m) for m in mapping_items]}
            return {"items": [dict(m) for m in mapping_items]}

        # Low overlap between objects: treat as multiple extracted JSON objects.
        return merged

    # Mixed list: preserve non-mapping items, but coerce mappings to dict.
    coerced = [dict(x) if isinstance(x, Mapping) else x for x in items]
    return {"items": coerced}


def _infer_list_container_key(items: list[Mapping]) -> str | None:
    keys: set[str] = set()
    for item in items[:5]:
        keys.update(str(k) for k in item.keys())

    # Diagnostician
    if {"finding", "description"} & keys:
        return "findings"
    if "reasoning" in keys:
        return "differential_diagnoses"

    # Bias detector
    if {"type", "severity"} <= keys:
        return "identified_biases"

    # Devil's advocate
    if "claim" in keys or "counter_evidence" in keys:
        return "challenges"
    if {"why_dangerous", "rule_out_test", "supporting_signs"} & keys:
        return "must_not_miss"

    # Consultant
    if {"urgency", "next_step"} & keys:
        return "alternative_diagnoses"

    return None
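`_coerce_list_root` distinguishes a homogeneous list of schema items from several unrelated objects by the ratio of shared keys to all keys, with 0.35 as the cut-off. The heuristic in isolation (stdlib only, no json_repair needed):

```python
def key_overlap_ratio(dicts: list[dict]) -> float:
    """Intersection-over-union of the key sets of up to ten dicts."""
    key_sets = [set(d.keys()) for d in dicts[:10]]
    union = set().union(*key_sets)
    inter = set(key_sets[0]).intersection(*key_sets[1:]) if len(key_sets) > 1 else set(key_sets[0])
    return len(inter) / len(union) if union else 0.0


# Repeated schema items share keys → high ratio → keep as a list.
homogeneous = [
    {"claim": "no consolidation", "counter_evidence": "retrocardiac opacity"},
    {"claim": "stable vitals", "counter_evidence": "tachycardia at triage"},
]
# Unrelated partial objects share nothing → low ratio → merge into one dict.
mixed = [{"claim": "no consolidation"}, {"urgency": "high", "next_step": "CT"}]

print(key_overlap_ratio(homogeneous))  # 1.0
print(key_overlap_ratio(mixed))        # 0.0
```

A single malformed response that splits one logical object across two fragments scores low and gets merged; a genuine array of challenge items scores high and survives as a list.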
agents/prompts.py ADDED
@@ -0,0 +1,142 @@
"""
Prompt templates for each agent in the debiasing pipeline.
All downstream agents (Bias Detector, Devil's Advocate, Consultant) use JSON output format.
"""

# ---------------------------------------------------------------------------
# Diagnostician: independent image analysis (MUST NOT see doctor's diagnosis)
# ---------------------------------------------------------------------------
DIAGNOSTICIAN_SYSTEM = """\
You are a radiologist performing an independent case review. Analyze BOTH the medical image AND the clinical context (history, vitals, labs, exam findings). Do not assume any prior diagnosis.
Some dangerous conditions may show subtle or no imaging signs but have obvious clinical red flags — you must catch these.
Respond with valid JSON only — no markdown, no text outside the JSON.
Top-level JSON must be a single object (not an array)."""

DIAGNOSTICIAN_USER = """\
Patient clinical context: {clinical_context}

Analyze this medical image together with the clinical context above. Report ALL findings — both imaging findings and clinical red flags from the context (abnormal vitals, labs, risk factors). Respond with JSON:

{{
  "findings": [
    {{
      "finding": "name of finding",
      "source": "imaging | clinical | both",
      "description": "location/appearance for imaging findings, or value/significance for clinical findings"
    }}
  ],
  "differential_diagnoses": [
    {{
      "diagnosis": "diagnosis name",
      "reasoning": "combined evidence from imaging AND clinical context"
    }}
  ]
}}"""

# ---------------------------------------------------------------------------
# Bias Detector: compare doctor's diagnosis with independent analysis
# Output: structured JSON
# ---------------------------------------------------------------------------
BIAS_DETECTOR_SYSTEM = """\
You are a clinical reasoning expert specializing in cognitive bias detection. You have direct access to the medical image AND the full clinical context (history, vitals, labs, exam findings).
You are given two independent assessments of the same case: the treating physician's diagnosis and an AI-generated analysis. Neither is assumed to be correct — both may contain errors or omissions.
Examine the image yourself AND carefully review the clinical context. Compare both assessments against what you see in the image AND what the clinical data shows. Some dangerous conditions have subtle imaging but obvious clinical red flags — flag these if either assessment ignored them.
Respond with valid JSON only — no markdown, no text outside the JSON.
Top-level JSON must be a single object (not an array)."""

BIAS_DETECTOR_USER = """\
Doctor's diagnosis: "{doctor_diagnosis}"
Clinical context: {clinical_context}
AI independent analysis (blinded, may also contain errors): {diagnostician_findings}
Image–diagnosis consistency (MedSigLIP verification): {consistency_check}

Compare both assessments objectively. Neither is assumed correct. Respond with JSON:

{{
  "discrepancy_summary": "how the two assessments differ — note which points are uncertain",
  "identified_biases": [
    {{
      "source": "doctor | AI | both",
      "type": "bias type",
      "evidence": "why you suspect this bias",
      "severity": "choose from LOW | MEDIUM | HIGH"
    }}
  ],
  "missed_findings": ["finding not accounted for by either assessment"],
  "agreement_points": ["findings where both agree"]
}}"""

# ---------------------------------------------------------------------------
# Devil's Advocate: adversarial challenge (deliberately contrarian)
# Output: structured JSON
# ---------------------------------------------------------------------------
DEVIL_ADVOCATE_SYSTEM = """\
You are a Devil's Advocate in a clinical case review. You have direct access to the medical image AND the full clinical context.
Your sole purpose is to challenge the working diagnosis — especially for dangerous must-not-miss diagnoses.
Examine the image yourself AND scrutinize the clinical data (vitals, labs, risk factors). Many must-not-miss diagnoses have subtle imaging but glaring clinical signs — use both sources of evidence.
Do not simply repeat earlier findings — look for anything that may have been overlooked.
Respond with valid JSON only — no markdown, no text outside the JSON.
Top-level JSON must be a single object (not an array)."""

DEVIL_ADVOCATE_USER = """\
Working diagnosis: "{doctor_diagnosis}"
Clinical context: {clinical_context}
Prior independent analysis (for reference only — form your own opinion from the image and clinical data): {diagnostician_findings}
Detected biases: {bias_summary}

Examine the attached medical image AND the clinical context. Challenge the working diagnosis using evidence from both imaging and clinical data.
IMPORTANT: Do NOT return empty lists — provide at least 1 item in each list. If evidence is weak, state uncertainty and suggest a rule-out test.
Respond with JSON:

{{
  "challenges": [
    {{
      "claim": "aspect being challenged",
      "counter_evidence": "why it may be wrong"
    }}
  ],
  "must_not_miss": [
    {{
      "diagnosis": "dangerous alternative",
      "why_dangerous": "consequence if missed",
      "supporting_signs": "evidence from this case",
      "rule_out_test": "best test to confirm or exclude"
    }}
  ],
  "recommended_workup": ["test 1", "test 2"]
}}"""

# ---------------------------------------------------------------------------
# Consultant: synthesize debiasing report
# Output: structured JSON
# ---------------------------------------------------------------------------
CONSULTANT_SYSTEM = """\
You are a senior clinician writing a consultation note. Address the reader as "you" and the sick person as "the patient".
Tone: collegial, direct — "Have you considered..." style.
Never mention cognitive bias names. Never use brackets or placeholders.
Respond with valid JSON only — no markdown, no text outside the JSON.
Top-level JSON must be a single object (not an array)."""

CONSULTANT_USER = """\
Original diagnosis: "{doctor_diagnosis}"
Clinical context: {clinical_context}
Independent analysis: {diagnostician_findings}
Bias analysis: {bias_report}
Devil's advocate challenges: {devil_advocate_report}
Similar cases: {similar_cases}

Write a 2-4 paragraph consultation note. Call the reader "you" and the sick person "the patient". Start the note directly with clinical content (e.g., "I reviewed the imaging and..."). Respond with JSON:

{{
  "consultation_note": "2-4 paragraphs. Address the reader as you. Call the sick person the patient. Start directly with clinical content.",
  "alternative_diagnoses": [
    {{
      "diagnosis": "name",
      "urgency": "MUST be one of: critical, high, moderate",
      "evidence": "supporting evidence from this case",
      "next_step": "specific action to confirm or rule out"
    }}
  ],
  "immediate_actions": ["concrete next step 1", "step 2"],
  "confidence_note": "confidence level and limitations"
}}"""
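The JSON skeletons inside these templates double every brace (`{{` / `}}`) so that `str.format()` treats them as literal braces rather than replacement fields. A tiny hypothetical template demonstrating the escaping (the real templates live above in `agents/prompts.py`):

```python
# Hypothetical mini-template, not one of the module's actual prompts.
TEMPLATE = """\
Patient clinical context: {clinical_context}

Respond with JSON:

{{
  "findings": []
}}"""

prompt = TEMPLATE.format(clinical_context="58M, acute chest pain")
print(prompt)
```

After formatting, `{clinical_context}` is substituted while each `{{`/`}}` pair collapses to a single literal brace, so the model receives a valid JSON skeleton. An unescaped `{` in a template would instead raise `KeyError` (or `ValueError` for malformed fields) at `.format()` time.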
agents/state.py ADDED
@@ -0,0 +1,80 @@
"""
LangGraph state definition for the Diagnostic Devil's Advocate pipeline.
"""

from typing import Any, Optional
from typing_extensions import NotRequired, TypedDict
from PIL import Image


class ClinicalInput(TypedDict):
    """Raw input from the user."""
    image: Optional[Image.Image]
    doctor_diagnosis: str
    clinical_context: str  # age, sex, symptoms, history, etc.
    modality: NotRequired[str]  # "CXR" | "CT" | "Other"


class Finding(TypedDict, total=False):
    finding: str
    description: str


class DifferentialDiagnosis(TypedDict, total=False):
    diagnosis: str
    reasoning: str


class DiagnosticianOutput(TypedDict):
    """Independent analysis from the Diagnostician agent (does NOT see doctor's diagnosis)."""
    analysis: str  # formatted text for downstream agents
    findings: str  # findings-only text
    findings_list: list[Finding]  # structured findings
    differential_diagnoses: list[DifferentialDiagnosis]  # top differentials
    differentials_text: NotRequired[str]


class BiasDetectorOutput(TypedDict):
    """Bias analysis comparing doctor's diagnosis vs independent analysis."""
    identified_biases: list[dict[str, Any]]  # [{"type": str, "evidence": str, "severity": str}]
    discrepancy_summary: str
    missed_findings: list[str]
    consistency_check: list[dict[str, Any]]  # MedSigLIP sign verification results


class DevilsAdvocateOutput(TypedDict):
    """Adversarial challenge to the working diagnosis."""
    challenges: list[dict[str, Any]]  # [{"claim": str, "counter_evidence": str}]
    must_not_miss: list[dict[str, Any]]  # [{"diagnosis": str, "why_dangerous": str, "supporting_signs": str}]
    recommended_workup: list[str]


class AlternativeDiagnosis(TypedDict, total=False):
    diagnosis: str
    urgency: str  # "critical" | "high" | "moderate"
    evidence: str
    next_step: str


class ConsultantOutput(TypedDict):
    """Final synthesized consultation note."""
    consultation_note: str
    alternative_diagnoses: list[AlternativeDiagnosis]
    immediate_actions: list[str]
    confidence_note: str


class PipelineState(TypedDict):
    """Full state passed through the LangGraph pipeline."""
    # Input
    clinical_input: ClinicalInput

    # Agent outputs (populated as pipeline progresses)
    diagnostician_output: Optional[DiagnosticianOutput]
    bias_detector_output: Optional[BiasDetectorOutput]
    devils_advocate_output: Optional[DevilsAdvocateOutput]
    consultant_output: Optional[ConsultantOutput]

    # Metadata
    current_step: str
    error: Optional[str]
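These TypedDicts are static-only contracts: at runtime the state is a plain dict, which is why the agents can mutate it in place and LangGraph can merge updates freely. A trimmed sketch of that behavior (using stdlib `typing.TypedDict` rather than the `typing_extensions` import above, and only two of the fields):

```python
from typing import Optional, TypedDict


class ClinicalInput(TypedDict):
    doctor_diagnosis: str
    clinical_context: str


class PipelineState(TypedDict):
    clinical_input: ClinicalInput
    error: Optional[str]


state: PipelineState = {
    "clinical_input": {
        "doctor_diagnosis": "community-acquired pneumonia",
        "clinical_context": "58M, fever, productive cough",
    },
    "error": None,
}

# TypedDict is erased at runtime: `state` is just a dict, so agents can mutate
# it in place while static checkers still enforce the declared schema.
print(type(state) is dict)  # True
```

The trade-off is that schema violations (a missing key, a wrong value type) only surface under a type checker such as mypy or pyright, never at runtime.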
app.py ADDED
@@ -0,0 +1,41 @@
"""
Diagnostic Devil's Advocate — Main entry point.
A multi-agent AI system that challenges clinical diagnoses to prevent cognitive bias errors.
"""

import logging
import sys
import os

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(__file__))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)

import gradio as gr  # noqa: E402

from config import ENABLE_MEDASR  # noqa: E402
from ui.components import build_ui  # noqa: E402
from ui.callbacks import analyze_streaming, load_demo, transcribe_audio  # noqa: E402
from ui.css import CUSTOM_CSS  # noqa: E402


def main():
    demo = build_ui(
        analyze_fn=analyze_streaming,
        load_demo_fn=load_demo,
        transcribe_fn=transcribe_audio if ENABLE_MEDASR else None,
    )
    # NOTE: Blocks.launch() does not accept `css` or `theme`; those are
    # gr.Blocks() constructor options and belong inside build_ui
    # (e.g. gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft())).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,73 @@
+ """
+ Configuration for Diagnostic Devil's Advocate.
+ Controls model loading, quantization, and environment-specific settings.
+
+ Model loading priority:
+ 1. Local path (MODEL_LOCAL_DIR env var) — fully offline
+ 2. HF cache (auto-downloaded via huggingface-cli download) — offline after first download
+ 3. HF Hub (requires HF_TOKEN for gated models) — online fallback
+ """
+
+ import os
+ from huggingface_hub import try_to_load_from_cache
+
+ # --- Model Selection ---
+ USE_27B = os.environ.get("USE_27B", "false").lower() == "true"
+ QUANTIZE_4B = os.environ.get("QUANTIZE_4B", "true").lower() == "true"
+ ENABLE_MEDASR = os.environ.get("ENABLE_MEDASR", "true").lower() == "true"
+
+ # --- Prompt Repetition (arXiv:2512.14982) ---
+ # Repeating the user prompt improves non-reasoning LLM performance (47 wins, 0 losses
+ # across 70 benchmark-model combos). Only increases prefill tokens, no extra generation.
+ ENABLE_PROMPT_REPETITION = os.environ.get("ENABLE_PROMPT_REPETITION", "true").lower() == "true"
+
+ # --- HF Token (for gated models) ---
+ # Loaded from: env var > huggingface-cli login stored token (auto)
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # --- Model IDs (HF Hub) ---
+ _MEDGEMMA_4B_HUB_ID = "google/medgemma-1.5-4b-it"
+ _MEDGEMMA_27B_HUB_ID = "google/medgemma-27b-text-it"
+ _MEDSIGLIP_HUB_ID = "google/medsiglip-448"
+ _MEDASR_HUB_ID = "google/medasr"
+
+ # --- Optional local model directories (override HF Hub) ---
+ # Set these env vars to point to a local directory containing model weights.
+ # If not set, models load from HF cache (downloaded via `huggingface-cli download`).
+ MODEL_LOCAL_DIR = os.environ.get("MODEL_LOCAL_DIR", None)
+
+ def _resolve_model_path(hub_id: str, local_subdir: str | None = None) -> str:
+     """Resolve model path: local dir > HF cache > HF Hub ID."""
+     # 1. Explicit local directory
+     if MODEL_LOCAL_DIR:
+         local_path = os.path.join(MODEL_LOCAL_DIR, local_subdir or hub_id.split("/")[-1])
+         if os.path.isdir(local_path):
+             return local_path
+     # 2. HF cache (already downloaded via huggingface-cli download)
+     try:
+         cached = try_to_load_from_cache(hub_id, "config.json")
+     except Exception:
+         cached = None
+     if cached is not None and isinstance(cached, str):
+         # Return the repo snapshot directory (parent of config.json)
+         return os.path.dirname(cached)
+     # 3. Fallback to Hub ID (will download on first use)
+     return hub_id
+
+ MEDGEMMA_4B_MODEL_ID = _resolve_model_path(_MEDGEMMA_4B_HUB_ID, "medgemma-4b")
+ MEDGEMMA_27B_MODEL_ID = _resolve_model_path(_MEDGEMMA_27B_HUB_ID, "medgemma-27b")
+ MEDSIGLIP_MODEL_ID = _resolve_model_path(_MEDSIGLIP_HUB_ID, "medsiglip-448")
+ MEDASR_MODEL_ID = _resolve_model_path(_MEDASR_HUB_ID, "medasr")
+
+ # --- Generation Parameters ---
+ MAX_NEW_TOKENS_4B = 4096
+ MAX_NEW_TOKENS_27B = 6000
+ TEMPERATURE = 0.0
+ REPETITION_PENALTY = 1.2  # Prevent greedy decoding repetition loops
+
+ # --- Device ---
+ DEVICE = os.environ.get("DEVICE", "cuda")
+
+ # --- Demo Cases Directory ---
+ DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
+ DEMO_CASES_DIR = os.path.join(DATA_DIR, "demo_cases")
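The three-step priority implemented by `_resolve_model_path` can be illustrated with a minimal standalone sketch. The `resolve_model_path` helper below is hypothetical: the cache lookup is injected as a callable so no `huggingface_hub` install is needed, but the fall-through order matches the function above.

```python
import os
import tempfile

def resolve_model_path(hub_id, model_local_dir, cache_lookup):
    """Sketch of config.py's priority: explicit local dir > HF cache > Hub ID."""
    if model_local_dir:
        local_path = os.path.join(model_local_dir, hub_id.split("/")[-1])
        if os.path.isdir(local_path):
            return local_path
    cached = cache_lookup(hub_id)  # stand-in for try_to_load_from_cache(hub_id, "config.json")
    if isinstance(cached, str):
        return os.path.dirname(cached)  # snapshot dir containing config.json
    return hub_id

# No local dir, cache miss: falls through to the Hub ID (downloads on first use)
assert resolve_model_path("google/medsiglip-448", None, lambda _: None) == "google/medsiglip-448"

# Cache hit: returns the snapshot directory holding config.json
assert resolve_model_path(
    "google/medsiglip-448", None, lambda _: "/cache/snap/abc/config.json"
) == "/cache/snap/abc"

# An existing local directory wins over the cache
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "medsiglip-448"))
    assert resolve_model_path(
        "google/medsiglip-448", root, lambda _: "/cache/snap/abc/config.json"
    ) == os.path.join(root, "medsiglip-448")
```

Note that the real function also swallows cache-lookup exceptions, which this sketch omits.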
data/demo_cases/SOURCES.md ADDED
@@ -0,0 +1,105 @@
+ # Demo Case Clinical Data Sources
+
+ The clinical scenarios in this project are **composite cases** constructed from published medical literature on diagnostic errors. They are not direct copies of any single patient case. Each scenario synthesizes common presentation patterns, vitals, labs, and misdiagnosis trajectories documented across multiple peer-reviewed sources.
+
+ ---
+
+ ## Case 1: Missed Pneumothorax (32M, Motorcycle Collision)
+
+ **Misdiagnosis pattern**: Traumatic pneumothorax missed on supine AP chest X-ray, discharged as rib contusion.
+
+ ### Key References
+
+ - **Ball CG, Kirkpatrick AW, Laupland KB, et al.** "Incidence, risk factors, and outcomes for occult pneumothoraces in victims of major trauma." *J Trauma.* 2005;59(4):917-924.
+   - Documents occult pneumothorax rates of 29-72% in trauma patients; supine CXR misses a significant proportion.
+
+ - **Soldati G, Testa A, Sher S, et al.** "Occult traumatic pneumothorax: diagnostic accuracy of lung ultrasonography in the emergency department." *Chest.* 2008;133(1):204-211.
+
+ - **Omar HR, Abdelmalak H, Mangar D, Rashad R.** "Occult pneumothorax, revisited." *J Trauma Manag Outcomes.* 2010;4:12.
+   - PMC2984474 — Reviews occult pneumothorax prevalence (3.7% to 64%), risk factors (subcutaneous emphysema OR 5.47, rib fractures OR 2.65).
+
+ - **Defined A, et al.** "Anteroposterior chest radiograph vs. chest CT scan in early detection of pneumothorax in trauma patients." *J Cardiothorac Surg.* 2011;6:74.
+   - PMC3195099 — Case series including 42M and 24M MVA patients with CXR-negative, CT-positive pneumothorax.
+
+ - **Del Cura JL, et al.** "Commonly Missed Findings on Chest Radiographs: Causes and Consequences." *Chest.* 2023;163(3):650-661.
+   - PMC10154905 — Systematic review of perceptual errors in CXR interpretation.
+
+ ### Clinical Data Basis
+ - Vitals (HR 104, SpO2 96%, BP 132/84) reflect typical blunt chest trauma presentation from trauma registry data.
+ - Labs (WBC 11.2, Lactate 1.8) are within ranges reported for minor trauma without hemorrhagic shock.
+ - Supine AP film reading pattern based on documented false-negative scenarios in the cited studies.
+
+ ---
+
+ ## Case 2: Aortic Dissection Misdiagnosed as GERD (58M, Hypertensive)
+
+ **Misdiagnosis pattern**: Acute aortic dissection attributed to acid reflux/esophageal spasm, sent home with antacids.
+
+ ### Key References
+
+ - **Defined A, et al.** "Acute aortic dissection: a missed diagnosis." *BMJ Case Rep.* 2018;2018:bcr2018226586.
+   - PMC6203039 — 60M with untreated hypertension, sudden chest pain radiating to back, initially misdiagnosed as indigestion. CT angiography revealed Stanford type B dissection.
+
+ - **Hansen MS, Nogareda GJ, Hutchison SJ.** "Frequency of and inappropriate treatment of misdiagnosis of acute aortic dissection." *Am J Cardiol.* 2007;99(6):852-856.
+   - Overall misdiagnosis rate of 33.8% for aortic dissection.
+
+ - **Defined A, et al.** "Misdiagnosis of aortic dissection: experience of 361 patients." *J Clin Hypertens.* 2012;14(4):256-260.
+   - PubMed 22458748 — Large series documenting misdiagnosis factors including GI-like symptoms.
+
+ - **Defined A, et al.** "Acute aortic dissection: be aware of misdiagnosis." *BMC Res Notes.* 2009;2:25.
+   - Vitals: BP 210/135, HR 126, RR 40, SpO2 95% on O2.
+
+ - **MLMIC Insurance Company.** "Case Study: Failure to Diagnose Dissection of Ascending Thoracic Aorta Results in Settlement."
+   - Real malpractice case: patient prescribed Prilosec for presumed GERD, died same evening from undiagnosed ascending aortic dissection with cardiac tamponade.
+
+ - **CBS News / Mayo Clinic.** "He thought he had severe acid reflux. Doctors found a much different problem."
+   - Patient with prolonged GERD misdiagnosis, eventually found to have 7cm aortic aneurysm with bicuspid aortic valve.
+
+ ### Clinical Data Basis
+ - Blood pressure asymmetry (178/102 R arm vs 146/88 L arm) is a classic dissection sign documented in IRAD registry data.
+ - D-dimer 4,850 ng/mL reflects typical elevation in acute dissection (sensitivity >95% per meta-analyses).
+ - Serial negative troponins ruling out ACS before GERD attribution matches the documented diagnostic pathway in the cited cases.
+
+ ---
+
+ ## Case 3: Postpartum Pulmonary Embolism Misdiagnosed as Anxiety (29F, Post C-section)
+
+ **Misdiagnosis pattern**: Postpartum PE symptoms attributed to anxiety/hyperventilation, psychiatric consult ordered instead of CTPA.
+
+ ### Key References
+
+ - **Defined A, et al.** "Pulmonary embolism masked by symptoms of mental disorders." *Psychiatr Pol.* 2023;57(5):1121-1136.
+   - PMC10683049 — 21F postpartum patient on duloxetine, repeated "panic attacks" with tachycardia (123 bpm) and hyperventilation (RR 20-24), symptoms attributed to anxiety. Died from PE. Autopsy confirmed pulmonary embolism as cause of death.
+
+ - **Defined A, et al.** "Pulmonary Embolism in the Setting of Panic Attacks." In: *Pulmonary Embolism.* Springer, 2017.
+   - Discusses overlap between PE symptoms (dyspnea, tachycardia, chest pain) and panic attacks; concept of "diagnostic overshadowing."
+
+ - **Defined A.** "My Symptoms Were Misdiagnosed as Anxiety: Tamara's Story." *StopTheClot.org / National Blood Clot Alliance.*
+   - Patient narrative of PE misdiagnosed as anxiety.
+
+ - **Defined A.** "'Organic Anxiety' in a Middle-aged Man Presenting with Dyspnoea: a Case Report." *East Asian Arch Psychiatry.* 2019;29(3):97.
+   - PE presenting as anxiety disorder, eventually diagnosed after high index of suspicion.
+
+ - **Royal College of Obstetricians and Gynaecologists.** "Thromboembolic Disease in Pregnancy and the Puerperium: Acute Management." Green-top Guideline No. 37b.
+   - Half of pregnancy-related VTE occurs postpartum; PE is a leading cause of maternal death.
+
+ - **Defined A, et al.** "Postpartum Pulmonary Embolism in a Grand Multiparous: A Case Report." *Cureus.* 2023;15(6):e40777.
+   - PMC10291952 — Broad differential including anxiety and PE in postpartum dyspnea.
+
+ ### Clinical Data Basis
+ - Vitals (HR 118, SpO2 91%, RR 28) reflect typical submassive PE presentation from PIOPED II data.
+ - ABG (pH 7.48, pO2 68, pCO2 29) shows respiratory alkalosis with hypoxemia, classic PE pattern.
+ - D-dimer 3,200 ng/mL is elevated but often dismissed postpartum due to physiologically raised baseline.
+ - Right calf tenderness as DVT source matches the documented PE-DVT association (>90% of PE from lower extremity DVT).
+
+ ---
+
+ ## Medical Images
+
+ The chest X-ray images used in the demo cases are sourced from the **University of Saskatchewan Teaching Collection** (CC-BY-NC-SA 4.0 license) and are representative radiographs, not from the specific patients described in the composite clinical scenarios above.
+
+ ---
+
+ ## Disclaimer
+
+ These demo cases are **educational composites** designed to illustrate common diagnostic error patterns. They do not represent any individual patient. This tool is a research prototype for the MedGemma Impact Challenge and is **not intended for clinical decision-making**.
data/demo_cases/case1_pneumothorax.png ADDED
data/demo_cases/case2_aortic_dissection.png ADDED
data/demo_cases/case3_pulmonary_embolism.png ADDED
models/__init__.py ADDED
File without changes
models/medasr_client.py ADDED
@@ -0,0 +1,111 @@
+ """
+ MedASR client: medical speech-to-text transcription.
+ Uses CTC decoding with proper blank-token collapse.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import threading
+ import warnings
+
+ from config import MEDASR_MODEL_ID, HF_TOKEN, DEVICE, ENABLE_MEDASR
+
+ logger = logging.getLogger(__name__)
+
+ _model = None
+ _processor = None
+ _load_lock = threading.Lock()
+
+
+ def _token_arg() -> dict:
+     if os.path.isdir(MEDASR_MODEL_ID):
+         return {}
+     return {"token": HF_TOKEN}
+
+
+ def load():
+     """Load MedASR model and processor."""
+     global _model, _processor
+     if _model is not None:
+         return _model, _processor
+     if not ENABLE_MEDASR:
+         raise RuntimeError("MedASR is disabled via ENABLE_MEDASR=false")
+
+     with _load_lock:
+         if _model is not None:
+             return _model, _processor
+
+         import torch
+         from transformers import AutoModelForCTC, AutoProcessor
+
+         logger.info("Loading MedASR from %s...", "local" if os.path.isdir(MEDASR_MODEL_ID) else "HF Hub")
+         _processor = AutoProcessor.from_pretrained(MEDASR_MODEL_ID, **_token_arg())
+         _model = AutoModelForCTC.from_pretrained(
+             MEDASR_MODEL_ID, **_token_arg(), dtype=torch.float32,
+         ).to(DEVICE)
+         _model.eval()
+         logger.info("MedASR loaded.")
+         return _model, _processor
+
+
+ def _ctc_greedy_decode(logits, processor) -> str:
+     """
+     Proper CTC greedy decode:
+     1. argmax to get predicted token IDs
+     2. Collapse consecutive duplicate IDs
+     3. Remove blank token IDs
+     4. Decode remaining IDs to text
+     """
+     import torch
+
+     predicted_ids = torch.argmax(logits, dim=-1)[0]  # (seq_len,)
+
+     # Determine blank token ID
+     blank_id = getattr(processor.tokenizer, "pad_token_id", None)
+     if blank_id is None:
+         blank_id = 0  # CTC blank is typically ID 0
+
+     # Collapse consecutive duplicates, then remove blanks
+     collapsed = []
+     prev_id = -1
+     for token_id in predicted_ids.tolist():
+         if token_id != prev_id:
+             if token_id != blank_id:
+                 collapsed.append(token_id)
+         prev_id = token_id
+
+     if not collapsed:
+         return ""
+
+     # Decode token IDs to text
+     text = processor.tokenizer.decode(collapsed, skip_special_tokens=True)
+     return text.strip()
+
+
+ def transcribe(audio_array, sampling_rate: int = 16000) -> str:
+     """
+     Transcribe audio to text using CTC greedy decoding.
+
+     Args:
+         audio_array: numpy array of audio samples (mono, float32).
+         sampling_rate: audio sample rate (MedASR expects 16kHz).
+
+     Returns:
+         Transcribed text string.
+     """
+     model, processor = load()
+     import torch
+
+     inputs = processor(
+         audio_array, sampling_rate=sampling_rate, return_tensors="pt",
+     ).to(model.device)
+
+     with torch.inference_mode():
+         # Suppress the harmless padding='same' convolution warning
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", message=".*padding='same'.*")
+             logits = model(**inputs).logits
+
+     return _ctc_greedy_decode(logits, processor)
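The collapse-then-drop-blanks order in `_ctc_greedy_decode` matters: a blank between two identical IDs keeps both occurrences, whereas dropping blanks first would merge them. A tiny pure-Python sketch of the same rule (hypothetical `ctc_collapse`, operating on a plain ID list instead of logits):

```python
def ctc_collapse(predicted_ids, blank_id=0):
    """Collapse consecutive duplicate IDs, then drop blanks (CTC greedy rule)."""
    out, prev = [], None
    for tid in predicted_ids:
        # A new ID (different from the previous frame) survives unless it is blank
        if tid != prev and tid != blank_id:
            out.append(tid)
        prev = tid
    return out

# The blank (0) between the two 7s separates a genuine repeat, so both survive;
# the consecutive 7s and 4s collapse to one each.
assert ctc_collapse([7, 7, 0, 7, 4, 4, 0, 0, 5]) == [7, 7, 4, 5]
```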
models/medgemma_client.py ADDED
@@ -0,0 +1,222 @@
+ """
+ MedGemma client: unified interface for 4B (multimodal) and 27B (text-only) models.
+ Loads locally via transformers with optional 4-bit quantization.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import threading
+ from PIL import Image
+
+ from config import (
+     USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
+     MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
+     MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
+ )
+ from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
+
+ logger = logging.getLogger(__name__)
+
+ _model_4b = None
+ _processor_4b = None
+ _model_27b = None
+ _tokenizer_27b = None
+ _load_4b_lock = threading.Lock()
+ _load_27b_lock = threading.Lock()
+
+
+ def _is_local_path(model_id: str) -> bool:
+     """Check if model_id is a local directory path."""
+     return os.path.isdir(model_id)
+
+
+ def _token_arg(model_id: str) -> dict:
+     """Return token kwarg only when loading from HF Hub (not local path)."""
+     if _is_local_path(model_id):
+         return {}
+     return {"token": HF_TOKEN}
+
+
+ def _get_quantization_config():
+     """Return BitsAndBytesConfig for 4-bit quantization."""
+     import torch
+     from transformers import BitsAndBytesConfig
+     return BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_quant_type="nf4",
+     )
+
+
+ def load_4b():
+     """Load MedGemma 4B-IT (multimodal) model and processor."""
+     global _model_4b, _processor_4b
+     if _model_4b is not None:
+         return _model_4b, _processor_4b
+
+     with _load_4b_lock:
+         if _model_4b is not None:
+             return _model_4b, _processor_4b
+
+         import torch
+         from transformers import AutoModelForImageTextToText, AutoProcessor
+
+         is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
+         logger.info(
+             "Loading MedGemma 4B-IT (%s) from %s...",
+             "4-bit" if QUANTIZE_4B else "bf16",
+             "local" if is_local else "HF Hub",
+         )
+
+         # BitsAndBytes quantization requires device_map="auto", not "cuda"
+         device_map = "auto" if QUANTIZE_4B else DEVICE
+         kwargs = {**_token_arg(MEDGEMMA_4B_MODEL_ID), "device_map": device_map}
+         if QUANTIZE_4B:
+             kwargs["quantization_config"] = _get_quantization_config()
+         else:
+             kwargs["dtype"] = torch.bfloat16
+
+         _processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID))
+         _model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
+         _model_4b.eval()
+         logger.info("MedGemma 4B loaded.")
+         return _model_4b, _processor_4b
+
+
+ def load_27b():
+     """Load MedGemma 27B Text-IT model and tokenizer (A100 only)."""
+     global _model_27b, _tokenizer_27b
+     if _model_27b is not None:
+         return _model_27b, _tokenizer_27b
+
+     with _load_27b_lock:
+         if _model_27b is not None:
+             return _model_27b, _tokenizer_27b
+
+         import torch
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+
+         is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
+         logger.info(
+             "Loading MedGemma 27B Text-IT (bf16) from %s...",
+             "local" if is_local else "HF Hub",
+         )
+
+         _tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID))
+         _model_27b = AutoModelForCausalLM.from_pretrained(
+             MEDGEMMA_27B_MODEL_ID,
+             **_token_arg(MEDGEMMA_27B_MODEL_ID),
+             dtype=torch.bfloat16,
+             device_map="auto",
+         )
+         _model_27b.eval()
+         logger.info("MedGemma 27B loaded.")
+         return _model_27b, _tokenizer_27b
+
+
+ def generate_with_image(prompt: str, image: Image.Image, system_prompt: str = "") -> str:
+     """Generate text from image + text prompt using MedGemma 4B."""
+     model, processor = load_4b()
+     image = resize_for_medgemma(image)
+     prompt = apply_prompt_repetition(prompt)
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+     messages.append({
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": prompt},
+         ],
+     })
+
+     inputs = processor.apply_chat_template(
+         messages, add_generation_prompt=True, tokenize=True,
+         return_dict=True, return_tensors="pt",
+     ).to(model.device)
+
+     import torch
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             **inputs,
+             max_new_tokens=MAX_NEW_TOKENS_4B,
+             do_sample=TEMPERATURE > 0,
+             repetition_penalty=REPETITION_PENALTY,
+             **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}),
+         )
+
+     # Decode only the new tokens
+     new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
+     text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
+     return strip_thinking_tokens(text)
+
+
+ def generate_text(prompt: str, system_prompt: str = "") -> str:
+     """Generate text from a text-only prompt. Uses the 27B model when USE_27B is set, else the 4B."""
+     if USE_27B:
+         return _generate_text_27b(prompt, system_prompt)
+     return _generate_text_4b(prompt, system_prompt)
+
+
+ def _generate_text_4b(prompt: str, system_prompt: str = "") -> str:
+     """Text-only generation with 4B model."""
+     model, processor = load_4b()
+     prompt = apply_prompt_repetition(prompt)
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+     messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
+
+     inputs = processor.apply_chat_template(
+         messages, add_generation_prompt=True, tokenize=True,
+         return_dict=True, return_tensors="pt",
+     ).to(model.device)
+
+     import torch
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             **inputs,
+             max_new_tokens=MAX_NEW_TOKENS_4B,
+             do_sample=TEMPERATURE > 0,
+             repetition_penalty=REPETITION_PENALTY,
+             **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}),
+         )
+
+     new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
+     text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
+     return strip_thinking_tokens(text)
+
+
+ def _generate_text_27b(prompt: str, system_prompt: str = "") -> str:
+     """Text-only generation with 27B model (thinking mode)."""
+     model, tokenizer = load_27b()
+     prompt = apply_prompt_repetition(prompt)
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.append({"role": "user", "content": prompt})
+
+     input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+     import torch
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             **inputs,
+             max_new_tokens=MAX_NEW_TOKENS_27B,
+             do_sample=TEMPERATURE > 0,
+             repetition_penalty=REPETITION_PENALTY,
+             **({"temperature": TEMPERATURE} if TEMPERATURE > 0 else {}),
+         )
+
+     new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
+     text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+     return strip_thinking_tokens(text)
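All three generate paths build their `generate()` kwargs the same way: greedy decoding when `TEMPERATURE == 0` (no `temperature` kwarg passed, since it is unused with `do_sample=False`), sampling otherwise. A small sketch of that pattern (hypothetical `generation_kwargs` helper, not part of the module):

```python
def generation_kwargs(temperature, max_new_tokens, repetition_penalty=1.2):
    """Build generate() kwargs: greedy when temperature == 0, sampled otherwise."""
    kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        # Only meaningful when sampling; passing it alongside do_sample=False
        # triggers a transformers warning, so it is omitted for greedy decoding.
        kwargs["temperature"] = temperature
    return kwargs

greedy = generation_kwargs(0.0, 4096)
assert greedy["do_sample"] is False and "temperature" not in greedy

sampled = generation_kwargs(0.7, 4096)
assert sampled["do_sample"] is True and sampled["temperature"] == 0.7
```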
models/medsiglip_client.py ADDED
@@ -0,0 +1,164 @@
+ """
+ MedSigLIP client: zero-shot medical image classification and embedding extraction.
+ Uses AutoProcessor following the official Google-Health/medsiglip notebook.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import threading
+
+ from PIL import Image
+
+ from config import MEDSIGLIP_MODEL_ID, HF_TOKEN, DEVICE
+
+ logger = logging.getLogger(__name__)
+
+ _model = None
+ _processor = None
+ _load_lock = threading.Lock()
+
+
+ def _token_arg() -> dict:
+     if os.path.isdir(MEDSIGLIP_MODEL_ID):
+         return {}
+     return {"token": HF_TOKEN}
+
+
+ def load():
+     """Load MedSigLIP model and processor."""
+     global _model, _processor
+     if _model is not None:
+         return _model, _processor
+
+     with _load_lock:
+         if _model is not None:
+             return _model, _processor
+
+         import torch
+         from transformers import AutoModel, AutoProcessor
+
+         logger.info("Loading MedSigLIP from %s...", "local" if os.path.isdir(MEDSIGLIP_MODEL_ID) else "HF Hub")
+         _processor = AutoProcessor.from_pretrained(MEDSIGLIP_MODEL_ID, **_token_arg())
+         _model = AutoModel.from_pretrained(
+             MEDSIGLIP_MODEL_ID, **_token_arg(), dtype=torch.float32,
+         ).to(DEVICE)
+         _model.eval()
+         logger.info("MedSigLIP loaded.")
+         return _model, _processor
+
+
+ def classify(image: Image.Image, candidate_labels: list) -> list[dict]:
+     """
+     Zero-shot classification of a medical image.
+
+     Args:
+         candidate_labels: list of str OR list of (short_label, descriptive_prompt) tuples.
+
+     Returns list of {"label": str, "score": float} sorted by descending score.
+     Scores are raw logits (not sigmoid/softmax) — higher = better match.
+     """
+     if candidate_labels and isinstance(candidate_labels[0], (list, tuple)):
+         display_labels = [c[0] for c in candidate_labels]
+         text_prompts = [c[1] for c in candidate_labels]
+     else:
+         display_labels = candidate_labels
+         text_prompts = candidate_labels
+
+     model, processor = load()
+     # Official usage: single processor call with padding="max_length"
+     inputs = processor(
+         text=text_prompts, images=image,
+         padding="max_length", return_tensors="pt",
+     ).to(model.device)
+
+     import torch
+
+     with torch.inference_mode():
+         outputs = model(**inputs)
+
+     # Use raw logits — official notebook uses argmax on logits_per_image directly
+     logits = outputs.logits_per_image[0].cpu().tolist()
+
+     results = [{"label": label, "score": score} for label, score in zip(display_labels, logits)]
+     results.sort(key=lambda x: x["score"], reverse=True)
+     return results
+
+
+ def _normalize_modality(modality: str | None) -> str:
+     m = (modality or "").strip().lower()
+     if m in {"cxr", "x-ray", "xray", "chest x-ray", "chest xray", "chest radiograph", "radiograph"}:
+         return "cxr"
+     if m in {"ct", "ct scan", "computed tomography"}:
+         return "ct"
+     return "other"
+
+
+ def _verification_prompts(sign: str, modality: str | None) -> tuple[str, str]:
+     sign_l = sign.lower()
+     m = _normalize_modality(modality)
+     if m == "ct":
+         positive = f"a CT scan showing {sign_l}"
+         negative = f"a CT scan showing no evidence of {sign_l}"
+     elif m == "other":
+         positive = f"a medical image showing {sign_l}"
+         negative = f"a medical image showing no evidence of {sign_l}"
+     else:
+         positive = f"a chest radiograph showing {sign_l}"
+         negative = f"a normal chest radiograph with no {sign_l}"
+     return positive, negative
+
+
+ def verify_sign(image: Image.Image, sign: str, modality: str | None = None) -> dict:
+     """
+     Binary verification: does the image show this finding/sign?
+     Compares "showing X" vs "no X" — matches official MedSigLIP usage pattern.
+
+     Returns confidence level based on logit difference:
+       diff > 2  → "likely present"
+       diff > 0  → "possibly present"
+       diff > -2 → "inconclusive"
+       else      → "likely absent"
+     """
+     positive, negative = _verification_prompts(sign, modality)
+
+     results = classify(image, [
+         ("positive", positive),
+         ("negative", negative),
+     ])
+
+     pos = next(r for r in results if r["label"] == "positive")
+     neg = next(r for r in results if r["label"] == "negative")
+     diff = pos["score"] - neg["score"]
+
+     if diff > 2:
+         confidence = "likely present"
+     elif diff > 0:
+         confidence = "possibly present"
+     elif diff > -2:
+         confidence = "inconclusive"
+     else:
+         confidence = "likely absent"
+
+     return {
+         "sign": sign,
+         "modality": _normalize_modality(modality),
+         "positive_logit": pos["score"],
+         "negative_logit": neg["score"],
+         "diff": diff,
+         "confidence": confidence,
+     }
+
+
+ def verify_findings(
+     image: Image.Image,
+     signs: list[str],
+     modality: str | None = None,
+ ) -> list[dict]:
+     """
+     Verify a list of imaging signs against the image.
+     Returns one result per sign, including inconclusive ones; callers can filter on "confidence".
+     """
+     results = [verify_sign(image, sign, modality=modality) for sign in signs]
+     return results
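The logit-difference banding used by `verify_sign` can be exercised standalone. The `band_confidence` helper below is a hypothetical sketch mirroring the thresholds above, not part of the module:

```python
def band_confidence(pos_logit, neg_logit):
    """Map the positive-vs-negative logit gap to the four confidence bands."""
    diff = pos_logit - neg_logit
    if diff > 2:
        return "likely present"
    if diff > 0:
        return "possibly present"
    if diff > -2:
        return "inconclusive"
    return "likely absent"

assert band_confidence(5.1, 2.0) == "likely present"     # gap 3.1
assert band_confidence(1.0, 0.5) == "possibly present"   # gap 0.5
assert band_confidence(-0.5, 0.5) == "inconclusive"      # gap -1.0
assert band_confidence(-3.0, 0.0) == "likely absent"     # gap -3.0
```

Because the bands compare a difference of raw (unnormalized) logits, they are a heuristic calibrated for this prompt pair rather than a probability.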
models/utils.py ADDED
@@ -0,0 +1,49 @@
+ """
+ Utility functions for model outputs: thinking token stripping, image encoding,
+ prompt repetition, etc.
+ """
+
+ import re
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+ from config import ENABLE_PROMPT_REPETITION
+
+ # MedGemma wraps internal reasoning in <unused94>...<unused95> tags
+ THINKING_PATTERN = re.compile(r"<unused94>.*?<unused95>", re.DOTALL)
+
+
+ def strip_thinking_tokens(text: str) -> str:
+     """Remove MedGemma's internal thinking tokens from output."""
+     return THINKING_PATTERN.sub("", text).strip()
+
+
+ def image_to_base64(image: Image.Image, fmt: str = "PNG") -> str:
+     """Convert PIL Image to base64 data URL string."""
+     buffer = BytesIO()
+     image.save(buffer, format=fmt)
+     encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
+     return f"data:image/{fmt.lower()};base64,{encoded}"
+
+
+ def apply_prompt_repetition(prompt: str) -> str:
+     """Repeat the user prompt to improve LLM output quality.
+
+     Based on "Prompt Repetition Improves Non-Reasoning LLMs"
+     (arXiv:2512.14982, Google Research 2025): repeating the input prompt
+     wins 47/70 benchmark-model combos with 0 losses. Uses the verbose
+     variant with a transition phrase for clarity.
+     """
+     if not ENABLE_PROMPT_REPETITION:
+         return prompt
+     return f"{prompt}\n\nLet me repeat the request:\n\n{prompt}"
+
+
+ def resize_for_medgemma(image: Image.Image, max_size: int = 896) -> Image.Image:
+     """Downscale so the longer side fits MedGemma's expected input resolution (896)."""
+     if max(image.size) <= max_size:
+         return image
+     ratio = max_size / max(image.size)
+     new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
+     return image.resize(new_size, Image.LANCZOS)
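Two of these utilities are easy to sanity-check in isolation. The sketch below mirrors the thinking-token regex and the aspect-preserving resize arithmetic with hypothetical helper names (`strip_thinking`, `resized_dims`), operating on plain tuples so no PIL image is needed:

```python
import re

# Same pattern as models/utils.py: drop <unused94>...<unused95> reasoning spans
THINKING = re.compile(r"<unused94>.*?<unused95>", re.DOTALL)

def strip_thinking(text: str) -> str:
    return THINKING.sub("", text).strip()

def resized_dims(w: int, h: int, max_size: int = 896) -> tuple:
    """Target dimensions: untouched if it fits, else longer side scaled to max_size."""
    if max(w, h) <= max_size:
        return (w, h)
    ratio = max_size / max(w, h)
    return (int(w * ratio), int(h * ratio))

assert strip_thinking("<unused94>internal reasoning<unused95>Final answer.") == "Final answer."
assert resized_dims(800, 600) == (800, 600)    # already fits, returned as-is
assert resized_dims(1792, 1024) == (896, 512)  # longer side clamped to 896
```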
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.1.0
+ transformers>=4.50.0
+ accelerate>=0.26.0
+ bitsandbytes>=0.42.0
+ langgraph>=0.2.0
+ gradio==5.12.0
+ Pillow>=10.0.0
+ numpy>=1.24.0
+ scipy>=1.10.0
+ json-repair>=0.30.0
tests/__init__.py ADDED
File without changes
tests/test_output_parser.py ADDED
@@ -0,0 +1,24 @@
+ import sys
+ import os
+
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+ from agents.output_parser import parse_json_response
+
+
+ def test_parse_json_response_returns_dict():
+     parsed = parse_json_response('{"challenges":[{"claim":"x","counter_evidence":"y"}]}')
+     assert parsed["challenges"][0]["claim"] == "x"
+
+
+ def test_parse_json_response_coerces_top_level_list_of_strings():
+     parsed = parse_json_response('["CT angiogram","D-dimer"]')
+     assert parsed["items"] == ["CT angiogram", "D-dimer"]
+
+
+ def test_parse_json_response_infers_container_key_for_da_items():
+     parsed = parse_json_response(
+         '[{"diagnosis":"Aortic dissection","why_dangerous":"High mortality","supporting_signs":"Pain radiating to back","rule_out_test":"CTA chest"}]'
+     )
+     assert parsed["must_not_miss"][0]["diagnosis"] == "Aortic dissection"
+
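The coercion behavior these tests pin down could look roughly like the sketch below. The `coerce_to_dict` helper and its key-inference rule are hypothetical reconstructions from the assertions above; the real `parse_json_response` also repairs malformed JSON (via `json-repair`), which is omitted here:

```python
from json import loads

def coerce_to_dict(raw):
    """Wrap a top-level JSON list in a dict, inferring a container key from its shape."""
    data = loads(raw)
    if isinstance(data, dict):
        return data
    if isinstance(data, list):
        # A list of devil's-advocate entries is recognizable by its "diagnosis" key
        if data and isinstance(data[0], dict) and "diagnosis" in data[0]:
            return {"must_not_miss": data}
        return {"items": data}  # generic fallback for lists of strings
    return {"value": data}

assert coerce_to_dict('["CT angiogram","D-dimer"]') == {"items": ["CT angiogram", "D-dimer"]}
assert "must_not_miss" in coerce_to_dict('[{"diagnosis":"Aortic dissection"}]')
```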
tests/test_pipeline_mock.py ADDED
@@ -0,0 +1,124 @@
1
+ """End-to-end pipeline test with mocked model calls (no GPU required)."""
2
+
3
+ import os
4
+ import sys
5
+ from unittest.mock import patch
6
+
7
+ from PIL import Image
8
+
9
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
10
+
11
+
12
+ def test_pipeline_end_to_end_with_mocks():
13
+ from agents.graph import run_pipeline
14
+ from agents.prompts import (
15
+ DIAGNOSTICIAN_SYSTEM,
16
+ BIAS_DETECTOR_SYSTEM,
17
+ DEVIL_ADVOCATE_SYSTEM,
18
+ CONSULTANT_SYSTEM,
19
+ )
20
+
21
+ dummy_image = Image.new("RGB", (512, 512), color="gray")
22
+
23
+ diagnostician_json = """
24
+ {
25
+ "findings": [
26
+ {"finding": "Pneumothorax", "source": "imaging", "description": "Left apical pleural line with absent peripheral markings."},
27
+ {"finding": "Rib fracture", "source": "imaging", "description": "Possible fracture of the left 5th rib."},
28
+ {"finding": "Tachycardia", "source": "clinical", "description": "HR 104, consistent with pain or hemodynamic compromise."}
29
+ ],
30
+ "differential_diagnoses": [
31
+ {"diagnosis": "Pneumothorax", "reasoning": "Visible pleural line on imaging combined with tachycardia and dyspnea from clinical context."}
32
+ ]
33
+ }
34
+ """.strip()
35
+
36
+ bias_detector_json = """
37
+ {
38
+ "discrepancy_summary": "Doctor focused on rib pain; image suggests pneumothorax.",
39
+ "identified_biases": [
40
+ {"type": "Anchoring", "evidence": "Trauma mechanism overweighted", "severity": "HIGH"}
41
+ ],
42
+ "missed_findings": ["Pneumothorax"],
43
+ "agreement_points": ["Rib pain consistent with trauma"]
44
+ }
45
+ """.strip()
46
+
47
+ devil_advocate_json = """
48
+ {
49
+ "challenges": [
50
+ {"claim": "Rib contusion explains symptoms", "counter_evidence": "Dyspnea can reflect pneumothorax severity."}
51
+ ],
52
+ "must_not_miss": [
53
+ {
54
+ "diagnosis": "Tension pneumothorax",
55
+ "why_dangerous": "Rapid hemodynamic collapse if untreated",
56
+ "supporting_signs": "Worsening dyspnea and pleural line",
57
+ "rule_out_test": "Bedside ultrasound or repeat upright CXR"
58
+ }
59
+ ],
60
+ "recommended_workup": ["Repeat upright chest radiograph", "Point-of-care ultrasound"]
61
+ }
62
+ """.strip()
63
+
64
+ consultant_json = """
65
+ {
66
+ "consultation_note": "Have you considered pneumothorax given the pleural line?\\n\\nI would re-image upright and consider bedside ultrasound.",
67
+ "alternative_diagnoses": [
68
+ {
69
+ "diagnosis": "Pneumothorax",
70
+ "urgency": "high",
71
+ "evidence": "Pleural line and absent peripheral markings",
72
+ "next_step": "Repeat upright CXR or POCUS"
73
+ }
74
+ ],
75
+ "immediate_actions": ["Repeat upright CXR", "POCUS"],
76
+ "confidence_note": "Based on a single image; clinical correlation required."
77
+ }
78
+ """.strip()
79
+
80
+ def fake_generate_with_image(_prompt: str, _image, system_prompt: str = "") -> str:
81
+ if system_prompt == DIAGNOSTICIAN_SYSTEM:
82
+ return diagnostician_json
83
+ if system_prompt == BIAS_DETECTOR_SYSTEM:
84
+ return bias_detector_json
85
+ if system_prompt.startswith(DEVIL_ADVOCATE_SYSTEM):
86
+ return devil_advocate_json
87
+ raise AssertionError(f"Unexpected system_prompt (with image): {system_prompt!r}")
88
+
89
+ def fake_generate_text(_prompt: str, system_prompt: str = "") -> str:
90
+ if system_prompt == CONSULTANT_SYSTEM:
91
+ return consultant_json
92
+ raise AssertionError(f"Unexpected system_prompt: {system_prompt!r}")
93
+
94
+ with patch("models.medgemma_client.generate_with_image", side_effect=fake_generate_with_image), patch(
95
+ "models.medgemma_client.generate_text",
96
+ side_effect=fake_generate_text,
97
+ ), patch(
98
+ "models.medsiglip_client.verify_findings",
99
+ return_value=[{"sign": "pneumothorax", "confidence": "likely present"}],
100
+ ):
101
+ result = run_pipeline(
102
+ image=dummy_image,
103
+ doctor_diagnosis="Rib contusion",
104
+ clinical_context="32M, trauma, left chest pain, mild dyspnea.",
105
+ modality="CXR",
106
+ )
107
+
108
+ assert result.get("error") is None
109
+
110
+ diag = result.get("diagnostician_output") or {}
111
+ assert diag.get("findings_list"), "Diagnostician findings_list missing"
112
+ assert diag.get("analysis"), "Diagnostician analysis missing"
113
+
114
+ bias = result.get("bias_detector_output") or {}
115
+ assert bias.get("discrepancy_summary")
116
+ assert bias.get("identified_biases"), "Bias detector identified_biases missing"
117
+
118
+ da = result.get("devils_advocate_output") or {}
119
+ assert da.get("must_not_miss"), "Devil's advocate must_not_miss missing"
120
+ assert all(isinstance(x, str) for x in da.get("recommended_workup", []))
121
+
122
+ ref = result.get("consultant_output") or {}
123
+ assert ref.get("consultation_note")
124
+ assert isinstance(ref.get("alternative_diagnoses"), list)
tests/test_smoke.py ADDED
@@ -0,0 +1,35 @@
+ """Quick smoke tests: imports, graph build, demo loading, utils."""
+
+ import os
+ import sys
+
+ from PIL import Image
+
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+
+ def test_smoke_graph_builds():
+     from agents.graph import build_graph
+
+     graph = build_graph()
+     assert graph is not None
+
+
+ def test_smoke_demo_loader_returns_expected_shape():
+     from ui.callbacks import load_demo
+
+     img, diagnosis, context, modality = load_demo("Case 1: Missed Pneumothorax")
+     assert isinstance(diagnosis, str) and diagnosis
+     assert isinstance(context, str) and context
+     assert modality in {"CXR", "CT", "Other"}
+     assert img is None or hasattr(img, "size")
+
+
+ def test_utils_strip_and_resize():
+     from models.utils import resize_for_medgemma, strip_thinking_tokens
+
+     assert strip_thinking_tokens("<unused94>t<unused95>Real answer") == "Real answer"
+     big = Image.new("RGB", (2000, 2000), color="gray")
+     resized = resize_for_medgemma(big)
+     assert max(resized.size) <= 896
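A plausible implementation consistent with the `strip_thinking_tokens` smoke test would drop the span between the `<unused94>`/`<unused95>` markers; this is a sketch inferred from the test, not necessarily the code in `models/utils.py`:

```python
import re


def strip_thinking_tokens(text: str) -> str:
    # Drop any <unused94>...<unused95> "thinking" span, keep only the visible answer.
    # (Sketch based on the smoke test; the repo's implementation may differ.)
    return re.sub(r"<unused94>.*?<unused95>", "", text, flags=re.DOTALL).strip()


print(strip_thinking_tokens("<unused94>t<unused95>Real answer"))  # Real answer
```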
ui/__init__.py ADDED
File without changes
ui/callbacks.py ADDED
@@ -0,0 +1,459 @@
+ """Callbacks connecting UI events to the LangGraph pipeline."""
+
+ import json
+ import logging
+ import os
+
+ import numpy as np
+ from PIL import Image
+
+ from agents.graph import stream_pipeline
+ from config import DEMO_CASES_DIR, ENABLE_MEDASR
+
+ logger = logging.getLogger(__name__)
+
+ # Agent display info
+ AGENT_INFO = {
+     "diagnostician": ("Diagnostician", "Analyzing image independently..."),
+     "bias_detector": ("Bias Detector", "Scanning for cognitive biases..."),
+     "devil_advocate": ("Devil's Advocate", "Challenging the diagnosis..."),
+     "consultant": ("Consultant", "Synthesizing consultation report..."),
+ }
+
+ # Demo case definitions — based on published case reports and clinical literature.
+ # References:
+ #   Case 1: PMC3195099 (AP CXR vs CT in trauma pneumothorax detection)
+ #   Case 2: PMC6203039 (Acute aortic dissection: a missed diagnosis)
+ #   Case 3: PMC10683049 (PE masked by symptoms of mental disorders)
+ DEMO_CASES = {
+     "Case 1: Missed Pneumothorax": {
+         "diagnosis": "Left rib contusion with musculoskeletal chest wall pain",
+         "context": (
+             "32-year-old male, presented to ED after a motorcycle collision at ~40 mph. "
+             "Helmet worn, no LOC. Chief complaint: left-sided chest pain worse with deep "
+             "inspiration.\n\n"
+             "Vitals: HR 104 bpm, BP 132/84 mmHg, RR 22/min, SpO2 96% on room air, "
+             "Temp 37.1 C.\n\n"
+             "Exam: Tenderness over left 4th-6th ribs, no crepitus, no subcutaneous "
+             "emphysema palpated. Breath sounds reportedly equal bilaterally (noisy ED). "
+             "Mild dyspnea attributed to pain.\n\n"
+             "Labs: WBC 11.2, Hgb 14.1, Lactate 1.8 mmol/L.\n\n"
+             "ED physician ordered AP chest X-ray (supine) — read as 'no acute "
+             "cardiopulmonary abnormality, possible left rib fracture.' Patient was given "
+             "ibuprofen and discharged with rib fracture precautions."
+         ),
+         "image_file": "case1_pneumothorax.png",
+         "modality": "CXR",
+     },
+     "Case 2: Aortic Dissection": {
+         "diagnosis": "Acute gastroesophageal reflux / esophageal spasm",
+         "context": (
+             "58-year-old male with a 15-year history of hypertension (poorly controlled, "
+             "non-compliant with amlodipine). Presented to ED with sudden-onset severe "
+             "retrosternal chest pain radiating to the interscapular back region, starting "
+             "30 minutes ago.\n\n"
+             "Vitals: BP 178/102 mmHg (right arm), 146/88 mmHg (left arm), HR 92 bpm, "
+             "RR 20/min, SpO2 97%, Temp 37.0 C.\n\n"
+             "Exam: Diaphoretic, visibly distressed. Abdomen soft, mild epigastric "
+             "tenderness. Heart sounds normal, no murmur. Peripheral pulses intact but "
+             "radial pulse asymmetry noted.\n\n"
+             "Labs: Troponin I <0.01 (negative x2 at 0h and 3h), D-dimer 4,850 ng/mL "
+             "(markedly elevated), WBC 13.4, Creatinine 1.3.\n\n"
+             "ECG: Sinus tachycardia, nonspecific ST changes. Initial CXR ordered. "
+             "ED physician considered ACS (ruled out by troponin), then attributed symptoms "
+             "to acid reflux; prescribed IV pantoprazole and GI cocktail. Pain not relieved."
+         ),
+         "image_file": "case2_aortic_dissection.png",
+         "modality": "CXR",
+     },
+     "Case 3: Pulmonary Embolism": {
+         "diagnosis": "Postpartum anxiety with hyperventilation syndrome",
+         "context": (
+             "29-year-old female, G2P2, day 5 after emergency cesarean section (prolonged "
+             "labor, general anesthesia). Presented with acute onset dyspnea and chest "
+             "tightness at rest. Reports feeling of 'impending doom' and inability to catch "
+             "breath.\n\n"
+             "Vitals: HR 118 bpm, BP 108/72 mmHg, RR 28/min, SpO2 91% on room air "
+             "(improved to 95% on 4L NC), Temp 37.3 C.\n\n"
+             "Exam: Anxious-appearing, tachypneic. Lungs clear to auscultation. Mild "
+             "right-sided pleuritic chest pain. Right calf tenderness and mild swelling "
+             "noted but attributed to post-surgical immobility. No Homans sign.\n\n"
+             "Labs: D-dimer 3,200 ng/mL (elevated, but 'expected postpartum'), "
+             "WBC 10.8, Hgb 10.2, ABG on RA: pH 7.48, pO2 68 mmHg, pCO2 29 mmHg.\n\n"
+             "OB team attributed symptoms to postpartum anxiety, prescribed lorazepam "
+             "0.5 mg PRN. Psychiatry consult requested. No CTPA ordered initially."
+         ),
+         "image_file": "case3_pulmonary_embolism.png",
+         "modality": "CXR",
+     },
+ }
+
+
+ def analyze_streaming(image: Image.Image | None, diagnosis: str, context: str, modality: str):
+     """
+     Generator: run pipeline and yield single HTML output after each agent step.
+     Each agent's output appears inline below its progress header.
+     """
+     if image is None:
+         yield '<div class="pipeline-error">Please upload a medical image.</div>'
+         return
+     if not diagnosis.strip():
+         yield '<div class="pipeline-error">Please enter the doctor\'s working diagnosis.</div>'
+         return
+     if not context.strip():
+         context = "No additional clinical context provided."
+     if not isinstance(modality, str) or not modality.strip():
+         modality = "CXR"
+
+     completed = {}
+     agent_outputs = {}
+     all_agents = ["diagnostician", "bias_detector", "devil_advocate", "consultant"]
+
+     try:
+         yield _build_pipeline(all_agents, completed, agent_outputs, active="diagnostician")
+
+         accumulated_state = {}
+         for node_name, state_update in stream_pipeline(image, diagnosis.strip(), context.strip(), modality.strip()):
+             completed[node_name] = True
+             accumulated_state.update(state_update)
+
+             if state_update.get("error"):
+                 agent_outputs[node_name] = f'<div class="pipeline-error">{_esc(state_update.get("error"))}</div>'
+                 yield _build_pipeline(all_agents, completed, agent_outputs, error=node_name)
+                 return
+
+             # Generate this agent's HTML output
+             agent_outputs[node_name] = _format_agent_output(node_name, accumulated_state)
+
+             idx = all_agents.index(node_name) if node_name in all_agents else -1
+             next_active = all_agents[idx + 1] if idx + 1 < len(all_agents) else None
+
+             yield _build_pipeline(all_agents, completed, agent_outputs, active=next_active)
+
+     except Exception as e:
+         logger.exception("Pipeline failed")
+         yield f'<div class="pipeline-error">Pipeline error: {_esc(e)}</div>'
+
+
+ def _build_pipeline(all_agents, completed, agent_outputs, active=None, error=None) -> str:
+     """Build combined progress + inline output HTML."""
+     from ui.components import _build_progress_html
+
+     return _build_progress_html(
+         completed=list(completed.keys()),
+         active=active,
+         error=error,
+         agent_outputs=agent_outputs,
+     )
+
+
+ def _format_agent_output(agent_id: str, state: dict) -> str:
+     """Generate HTML content for a specific agent's output."""
+     if agent_id == "diagnostician":
+         return _format_diagnostician(state)
+     if agent_id == "bias_detector":
+         return _format_bias_detector(state)
+     if agent_id == "devil_advocate":
+         return _format_devil_advocate(state)
+     if agent_id == "consultant":
+         return _format_consultant(state)
+     return ""
+
+
+ def _esc(text: object) -> str:
+     """Escape HTML special characters."""
+     return str(text).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
+
+
+ def _format_diagnostician(state: dict) -> str:
+     diag = state.get("diagnostician_output") or {}
+     parts = []
+
+     # Structured findings
+     findings_list = diag.get("findings_list", [])
+     if findings_list:
+         items = []
+         for f in findings_list:
+             if isinstance(f, dict):
+                 name = _esc(f.get("finding", ""))
+                 desc = _esc(f.get("description", ""))
+                 source = f.get("source", "").strip().lower()
+                 source_tag = ""
+                 if source in ("imaging", "clinical", "both"):
+                     source_tag = f' <span class="source-tag source-{source}">{_esc(source)}</span>'
+                 line = f"<li><strong>{name}</strong>{source_tag}: {desc}</li>" if desc else f"<li>{name}{source_tag}</li>"
+                 items.append(line)
+             else:
+                 items.append(f"<li>{_esc(str(f))}</li>")
+         parts.append(f'<div class="findings-section"><strong>Findings</strong><ul>{"".join(items)}</ul></div>')
+
+     # Differential diagnoses
+     differentials = diag.get("differential_diagnoses", [])
+     if differentials:
+         items = []
+         for d in differentials:
+             if isinstance(d, dict):
+                 name = _esc(d.get("diagnosis", ""))
+                 reason = _esc(d.get("reasoning", ""))
+                 items.append(f"<li><strong>{name}</strong>: {reason}</li>" if reason else f"<li>{name}</li>")
+             else:
+                 items.append(f"<li>{_esc(str(d))}</li>")
+         parts.append(f'<div class="differentials-section"><strong>Differential Diagnoses</strong><ol>{"".join(items)}</ol></div>')
+
+     # Fallback: raw text if no structured data
+     if not parts:
+         raw = diag.get("findings", "")
+         if raw:
+             parts.append(f'<div class="agent-text">{_esc(raw).replace(chr(10), "<br>")}</div>')
+
+     return "".join(parts)
+
+
+ def _format_bias_detector(state: dict) -> str:
+     bias_out = state.get("bias_detector_output") or {}
+     parts = []
+
+     # Discrepancy summary (always show if present)
+     disc = bias_out.get("discrepancy_summary", "")
+     if disc:
+         parts.append(f'<div class="discrepancy-summary">{_esc(disc)}</div>')
+
+     # Biases
+     biases = bias_out.get("identified_biases", [])
+     for b in biases:
+         severity = b.get("severity", "").strip().lower()
+         bias_type = _esc(b.get("type", "Unknown"))
+         evidence = _esc(b.get("evidence", ""))
+         source = b.get("source", "").strip().lower()
+         if severity in ("low", "medium", "high"):
+             sev_tag = f'<span class="severity-tag severity-{severity}">{severity.upper()}</span>'
+         else:
+             sev_tag = ""
+         if source in ("doctor", "ai", "both"):
+             src_tag = f'<span class="source-tag source-{source}">{source.upper()}</span>'
+         else:
+             src_tag = ""
+         parts.append(
+             f'<div class="bias-item">'
+             f'<div class="bias-title">{sev_tag} {src_tag} {bias_type}</div>'
+             f'<div class="bias-evidence">{evidence}</div>'
+             f'</div>'
+         )
+
+     # Missed findings
+     missed = bias_out.get("missed_findings", [])
+     if missed:
+         items = "".join(f"<li>{_esc(f)}</li>" for f in missed)
+         parts.append(f'<div class="missed-findings"><strong>Missed Findings</strong><ul>{items}</ul></div>')
+
+     # SigLIP sign verification
+     sign_results = bias_out.get("consistency_check", [])
+     if isinstance(sign_results, list) and sign_results:
+         meaningful = [r for r in sign_results if r.get("confidence") != "inconclusive"]
+         if meaningful:
+             items = []
+             for r in meaningful:
+                 conf = r.get("confidence", "?")
+                 sign = _esc(r.get("sign", "?"))
+                 css_cls = "sign-present" if "present" in conf else "sign-absent"
+                 items.append(f'<li class="{css_cls}"><strong>{sign}</strong> — {conf}</li>')
+             parts.append(
+                 f'<div class="siglip-section">'
+                 f'<strong>Image Verification (MedSigLIP)</strong>'
+                 f'<ul>{"".join(items)}</ul>'
+                 f'</div>'
+             )
+
+     return "".join(parts)
+
+
+ def _format_devil_advocate(state: dict) -> str:
+     da_out = state.get("devils_advocate_output") or {}
+     parts = []
+
+     # Must-not-miss
+     mnm = da_out.get("must_not_miss", [])
+     for m in mnm:
+         dx = _esc(m.get("diagnosis", "?"))
+         why = _esc(m.get("why_dangerous", ""))
+         signs = _esc(m.get("supporting_signs", ""))
+         test = _esc(m.get("rule_out_test", ""))
+         details = ""
+         if why:
+             details += f"<li><strong>Why dangerous:</strong> {why}</li>"
+         if signs:
+             details += f"<li><strong>Supporting signs:</strong> {signs}</li>"
+         if test:
+             details += f"<li><strong>Rule-out test:</strong> {test}</li>"
+         parts.append(
+             f'<div class="mnm-item">'
+             f'<div class="mnm-title">{dx}</div>'
+             f'<ul>{details}</ul>'
+             f'</div>'
+         )
+
+     # Challenges
+     challenges = da_out.get("challenges", [])
+     if challenges:
+         for c in challenges:
+             claim = _esc(c.get("claim", ""))
+             counter = _esc(c.get("counter_evidence", ""))
+             parts.append(
+                 f'<div class="challenge-item">'
+                 f'<div class="challenge-claim">{claim}</div>'
+                 f'<div class="challenge-counter">{counter}</div>'
+                 f'</div>'
+             )
+
+     # Recommended workup
+     workup = da_out.get("recommended_workup", [])
+     if workup:
+         items = "".join(f"<li>{_esc(str(w))}</li>" for w in workup)
+         parts.append(f'<div class="workup-section"><strong>Recommended Workup</strong><ul>{items}</ul></div>')
+
+     # Fallback: ensure non-empty so the collapsible block can expand
+     if not parts:
+         parts.append('<div class="agent-text">No structured challenges parsed.</div>')
+
+     return "".join(parts)
+
+
+ def _format_consultant(state: dict) -> str:
+     ref = state.get("consultant_output") or {}
+     da_out = state.get("devils_advocate_output") or {}
+     parts = []
+
+     # Consultation note — the main human-readable report
+     note = ref.get("consultation_note", "")
+     if note:
+         paragraphs = _esc(note).split("\n")
+         formatted = "".join(f"<p>{p.strip()}</p>" for p in paragraphs if p.strip())
+         parts.append(f'<div class="consultation-note">{formatted}</div>')
+
+     # Alternative diagnoses to consider
+     alt_raw = ref.get("alternative_diagnoses", "")
+     if alt_raw:
+         try:
+             alts = json.loads(alt_raw) if isinstance(alt_raw, str) else alt_raw
+             if not isinstance(alts, list):
+                 alts = []
+             if alts:
+                 items = []
+                 for a in alts:
+                     urgency_raw = str(a.get("urgency", "")).strip().lower()
+                     urgency = urgency_raw if urgency_raw in {"critical", "high", "moderate"} else "moderate"
+                     urgency_label = urgency.upper()
+                     dx = _esc(a.get("diagnosis", "?"))
+                     ev = _esc(a.get("evidence", ""))
+                     ns = _esc(a.get("next_step", ""))
+                     detail = f" — {ev}" if ev else ""
+                     step = f"<br><em>Next step: {ns}</em>" if ns else ""
+                     items.append(
+                         f'<li><span class="urgency-tag urgency-{urgency}">{urgency_label}</span> '
+                         f"<strong>{dx}</strong>{detail}{step}</li>"
+                     )
+                 parts.append(f'<div class="alt-diagnoses"><strong>Consider</strong><ul>{"".join(items)}</ul></div>')
+         except (json.JSONDecodeError, TypeError):
+             pass
+
+     # Immediate actions (merged from Devil's Advocate + Consultant)
+     workup = da_out.get("recommended_workup", []) if isinstance(da_out, dict) else []
+     actions = ref.get("immediate_actions", [])
+     safe_workup = [str(x).strip() for x in workup if str(x).strip()]
+     safe_actions = [str(x).strip() for x in actions if str(x).strip()]
+     all_items = list(dict.fromkeys(safe_workup + safe_actions))
+     if all_items:
+         items = "".join(f"<li>{_esc(item)}</li>" for item in all_items)
+         parts.append(f'<div class="next-steps"><strong>Recommended Actions</strong><ul>{items}</ul></div>')
+
+     # Confidence note
+     if ref.get("confidence_note"):
+         parts.append(f'<div class="confidence-note"><em>{_esc(ref["confidence_note"])}</em></div>')
+
+     return "".join(parts)
+
+
+ def transcribe_audio(audio, existing_context: str = ""):
+     """
+     Transcribe audio input using MedASR.
+
+     Generator that yields (context_text, status_html) for streaming UI feedback.
+     Appends transcribed text to any existing context.
+     """
+     def _status_html(cls: str, text: str) -> str:
+         return f'<div class="voice-status {cls}">{text}</div>'
+
+     if audio is None:
+         yield existing_context, _status_html("voice-idle", "No audio recorded. Click the microphone to start.")
+         return
+
+     if not ENABLE_MEDASR:
+         yield existing_context, _status_html("voice-error", "MedASR is disabled (set ENABLE_MEDASR=true)")
+         return
+
+     # Step 1: Show processing state
+     sr, audio_data = audio
+     duration = len(audio_data) / sr if sr > 0 else 0
+     yield existing_context, _status_html(
+         "voice-processing",
+         f'<span class="pulse-dot"></span> Transcribing {duration:.1f}s of audio with MedASR...'
+     )
+
+     try:
+         from models import medasr_client
+
+         # Convert to float32 mono
+         if audio_data.dtype != np.float32:
+             if np.issubdtype(audio_data.dtype, np.integer):
+                 audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
+             else:
+                 audio_data = audio_data.astype(np.float32)
+         if audio_data.ndim > 1:
+             audio_data = audio_data.mean(axis=1)
+
+         # Resample to 16 kHz if needed (MedASR expects 16000 Hz)
+         target_sr = 16000
+         if sr != target_sr:
+             from scipy.signal import resample
+
+             num_samples = int(len(audio_data) * target_sr / sr)
+             audio_data = resample(audio_data, num_samples).astype(np.float32)
+             sr = target_sr
+
+         # Step 2: Run transcription
+         text = medasr_client.transcribe(audio_data, sampling_rate=sr)
+
+         if not text.strip():
+             yield existing_context, _status_html("voice-error", "No speech detected. Please try again.")
+             return
+
+         # Step 3: Append to existing context
+         if existing_context.strip():
+             new_context = existing_context.rstrip() + "\n\n" + text
+         else:
+             new_context = text
+
+         word_count = len(text.split())
+         yield new_context, _status_html(
+             "voice-success",
+             f'✓ Transcribed {word_count} words ({duration:.1f}s) — text added to context above'
+         )
+
+     except Exception as e:
+         logger.exception("MedASR transcription failed")
+         yield existing_context, _status_html("voice-error", f"Transcription failed: {e}")
+
+
+ def load_demo(demo_name: str | None):
+     """Load a demo case into the UI inputs."""
+     if demo_name is None or demo_name not in DEMO_CASES:
+         return None, "", "", "CXR"
+
+     case = DEMO_CASES[demo_name]
+     image_path = os.path.join(DEMO_CASES_DIR, case["image_file"])
+
+     image = None
+     if os.path.exists(image_path):
+         image = Image.open(image_path)
+     else:
+         logger.warning("Demo image not found: %s", image_path)
+
+     modality = case.get("modality") or "CXR"
+     return image, case["diagnosis"], case["context"], modality
ui/components.py ADDED
@@ -0,0 +1,317 @@
+ """Gradio UI layout for Diagnostic Devil's Advocate."""
+
+ import gradio as gr
+
+ from config import ENABLE_MEDASR
+
+
+ def build_ui(analyze_fn, load_demo_fn, transcribe_fn=None):
+     """
+     Build the Gradio Blocks UI.
+
+     Args:
+         analyze_fn: generator(image, diagnosis, context, modality) -> yields HTML
+         load_demo_fn: callback(demo_name) -> (image, diagnosis, context, modality)
+         transcribe_fn: callback(audio, existing_context) -> yields (context, status_html) (optional)
+     """
+     with gr.Blocks(title="Diagnostic Devil's Advocate") as demo:
+
+         # ── Hero Banner ──
+         gr.HTML(
+             """
+             <div class="hero-banner">
+                 <div class="hero-badge">MedGemma Impact Challenge</div>
+                 <h1>Diagnostic Devil's Advocate</h1>
+                 <p class="hero-sub">AI-Powered Cognitive Debiasing for Medical Image Interpretation</p>
+                 <p class="hero-desc">Upload a medical image with the working diagnosis.
+                 Four AI agents will independently analyze it, detect cognitive biases,
+                 challenge the diagnosis, and synthesize a debiasing report.</p>
+                 <div class="hero-models">
+                     <span class="model-chip">MedGemma 4B</span>
+                     <span class="model-chip">MedSigLIP</span>
+                     <span class="model-chip">LangGraph</span>
+                     <span class="model-chip">MedASR</span>
+                 </div>
+             </div>
+             """
+         )
+
+         # ── Demo Cases Row (3 clickable cards) ──
+         gr.HTML('<div class="section-label">SELECT A DEMO CASE</div>')
+
+         with gr.Row(elem_classes=["case-row"]):
+             demo_btn_1 = gr.Button(
+                 value="",
+                 elem_id="case-btn-1",
+                 elem_classes=["case-card-btn"],
+             )
+             demo_btn_2 = gr.Button(
+                 value="",
+                 elem_id="case-btn-2",
+                 elem_classes=["case-card-btn"],
+             )
+             demo_btn_3 = gr.Button(
+                 value="",
+                 elem_id="case-btn-3",
+                 elem_classes=["case-card-btn"],
+             )
+
+         # Overlay HTML on top of buttons for card visuals
+         gr.HTML("""
+             <div class="case-cards-overlay">
+                 <div class="case-card case-card-pneumo" onclick="document.querySelector('#case-btn-1').click()">
+                     <div class="card-top">
+                         <span class="case-icon">🫁</span>
+                         <span class="case-tag tag-blue">TRAUMA</span>
+                     </div>
+                     <div class="case-title">Missed Pneumothorax</div>
+                     <div class="case-meta">32-year-old Male</div>
+                     <div class="case-desc">Motorcycle collision · Left chest pain · HR 104 · SpO₂ 96%</div>
+                     <div class="case-misdiag">
+                         <span class="misdiag-label">Initial Dx:</span>
+                         <span class="misdiag-value">Rib contusion</span>
+                     </div>
+                 </div>
+                 <div class="case-card case-card-aorta" onclick="document.querySelector('#case-btn-2').click()">
+                     <div class="card-top">
+                         <span class="case-icon">🫀</span>
+                         <span class="case-tag tag-red">VASCULAR</span>
+                     </div>
+                     <div class="case-title">Aortic Dissection</div>
+                     <div class="case-meta">58-year-old Male</div>
+                     <div class="case-desc">Sudden chest→back pain · BP asymmetry 32mmHg · D-dimer 4850</div>
+                     <div class="case-misdiag">
+                         <span class="misdiag-label">Initial Dx:</span>
+                         <span class="misdiag-value">GERD / Reflux</span>
+                     </div>
+                 </div>
+                 <div class="case-card case-card-pe" onclick="document.querySelector('#case-btn-3').click()">
+                     <div class="card-top">
+                         <span class="case-icon">🩸</span>
+                         <span class="case-tag tag-purple">POSTPARTUM</span>
+                     </div>
+                     <div class="case-title">Pulmonary Embolism</div>
+                     <div class="case-meta">29-year-old Female</div>
+                     <div class="case-desc">5 days post C-section · HR 118 · SpO₂ 91% · pO₂ 68</div>
+                     <div class="case-misdiag">
+                         <span class="misdiag-label">Initial Dx:</span>
+                         <span class="misdiag-value">Postpartum anxiety</span>
+                     </div>
+                 </div>
+             </div>
+         """)
+
+         # ── Main Content: Input + Output ──
+         with gr.Row(equal_height=False):
+
+             # ═══════════ Left Column: Input ═══════════
+             with gr.Column(scale=4, min_width=340):
+                 gr.HTML('<div class="section-label">CLINICAL INPUT</div>')
+
+                 image_input = gr.Image(
+                     type="pil",
+                     label="Medical Image",
+                     height=240,
+                 )
+                 modality_input = gr.Radio(
+                     choices=["CXR", "CT", "Other"],
+                     value="CXR",
+                     label="Imaging Modality",
+                 )
+                 diagnosis_input = gr.Textbox(
+                     label="Doctor's Working Diagnosis",
+                     placeholder="e.g., Left rib contusion with musculoskeletal chest wall pain",
+                 )
+                 context_input = gr.Textbox(
+                     label="Clinical Context (history, vitals, labs, exam)",
+                     placeholder=(
+                         "e.g., 32M, motorcycle accident, left-sided chest pain, "
+                         "HR 104, SpO2 96%, WBC 11.2..."
+                     ),
+                     lines=5,
+                 )
+
+                 # ── Voice Input (MedASR) ──
+                 if ENABLE_MEDASR and transcribe_fn:
+                     gr.HTML("""
+                         <div class="voice-section">
+                             <div class="voice-header">
+                                 <span class="voice-icon">🎙️</span>
+                                 <span class="voice-title">Voice Input</span>
+                                 <span class="voice-badge">MedASR</span>
+                             </div>
+                             <div class="voice-hint">Record clinical context with your microphone.
+                             Text will be appended to the context field above.</div>
+                         </div>
+                     """)
+                     with gr.Row(elem_classes=["voice-row"]):
+                         audio_input = gr.Audio(
+                             sources=["microphone"],
+                             type="numpy",
+                             label="",
+                             show_label=False,
+                             elem_classes=["voice-audio"],
+                         )
+                         with gr.Column(scale=1, min_width=160):
+                             transcribe_btn = gr.Button(
+                                 "Transcribe",
+                                 size="sm",
+                                 elem_classes=["transcribe-btn"],
+                             )
+                             voice_status = gr.HTML(
+                                 value='<div class="voice-status voice-idle">Ready to record</div>',
+                             )
+                 else:
+                     gr.HTML(
+                         '<div class="voice-status voice-idle">Voice input disabled (MedASR)</div>'
+                     )
+
+                 analyze_btn = gr.Button(
+                     "Analyze & Challenge Diagnosis",
+                     variant="primary",
+                     size="lg",
+                     elem_classes=["analyze-btn"],
+                 )
+
+             # ═══════════ Right Column: Pipeline Output ═══════════
+             with gr.Column(scale=6, min_width=500):
+
+                 gr.HTML('<div class="section-label">PIPELINE OUTPUT</div>')
+                 pipeline_output = gr.HTML(
+                     value=_initial_progress_html(),
+                 )
+
+         # ── Footer ──
+         gr.HTML(
+             """
+             <div class="footer-text">
+                 <span>Built with</span>
+                 <span class="footer-chip">MedGemma</span>
+                 <span class="footer-chip">MedSigLIP</span>
+                 <span class="footer-chip">LangGraph</span>
+                 <span class="footer-chip">Gradio</span>
+                 <span class="footer-sep">|</span>
+                 <span>MedGemma Impact Challenge 2025</span>
+                 <span class="footer-sep">|</span>
+                 <span>Research & educational use only</span>
+             </div>
+             """
+         )
+
+         # ═══════════ Wire Callbacks ═══════════
+
+         analyze_btn.click(
+             fn=analyze_fn,
+             inputs=[image_input, diagnosis_input, context_input, modality_input],
+             outputs=[pipeline_output],
+         )
+
+         demo_btn_1.click(
+             fn=lambda: load_demo_fn("Case 1: Missed Pneumothorax"),
+             inputs=[],
+             outputs=[image_input, diagnosis_input, context_input, modality_input],
+         )
+         demo_btn_2.click(
+             fn=lambda: load_demo_fn("Case 2: Aortic Dissection"),
+             inputs=[],
+             outputs=[image_input, diagnosis_input, context_input, modality_input],
+         )
+         demo_btn_3.click(
+             fn=lambda: load_demo_fn("Case 3: Pulmonary Embolism"),
+             inputs=[],
+             outputs=[image_input, diagnosis_input, context_input, modality_input],
+         )
+
+         # Voice transcription — outputs to context field + status indicator
+         if ENABLE_MEDASR and transcribe_fn:
+             transcribe_btn.click(
+                 fn=transcribe_fn,
+                 inputs=[audio_input, context_input],
+                 outputs=[context_input, voice_status],
230
+ )
231
+
232
+ return demo
233
+
234
+
235
+ def _initial_progress_html() -> str:
236
+ """Static initial progress bar HTML."""
237
+ return _build_progress_html([], None, None, {})
238
+
239
+
240
+ def _build_progress_html(
241
+ completed: list[str],
242
+ active: str | None,
243
+ error: str | None,
244
+ agent_outputs: dict[str, str] | None = None,
245
+ ) -> str:
246
+ """Build pipeline output: progress bar + each agent's result inline.
247
+
248
+ Args:
249
+ agent_outputs: {agent_id: html_content} for completed agents.
250
+ """
251
+ if agent_outputs is None:
252
+ agent_outputs = {}
253
+
254
+ agents = [
255
+ ("diagnostician", "Diagnostician", "Independent image analysis"),
256
+ ("bias_detector", "Bias Detector", "Cognitive bias identification"),
257
+ ("devil_advocate", "Devil's Advocate", "Adversarial challenge"),
258
+ ("consultant", "Consultant", "Consultation synthesis"),
259
+ ]
260
+
261
+ n_done = len(completed)
262
+ pct = int(n_done / len(agents) * 100)
263
+
264
+ bar_color = "#ef4444" if error else "#3b82f6"
265
+ html = f"""
266
+ <div class="progress-container">
267
+ <div class="progress-bar-track">
268
+ <div class="progress-bar-fill" style="width:{pct}%;background:{bar_color};">
269
+ {pct}%
270
+ </div>
271
+ </div>
272
+ <div class="pipeline-agents">
273
+ """
274
+
275
+ for agent_id, name, desc in agents:
276
+ if agent_id == error:
277
+ cls = "step-error"
278
+ icon = "✗"
279
+ status = "Failed"
280
+ elif agent_id in completed:
281
+ cls = "step-done"
282
+ icon = "✓"
283
+ status = "Complete"
284
+ elif agent_id == active:
285
+ cls = "step-active"
286
+ icon = "⟳"
287
+ status = desc
288
+ else:
289
+ cls = "step-waiting"
290
+ icon = "○"
291
+ status = "Waiting"
292
+
293
+ content = agent_outputs.get(agent_id, "")
294
+ if content:
295
+ # Collapsible: <details open> with header as <summary>
296
+ html += f"""
297
+ <details class="agent-block {cls}" open>
298
+ <summary class="agent-header">
299
+ <span class="step-icon">{icon}</span>
300
+ <span class="step-name">{name}</span>
301
+ <span class="step-status">{status}</span>
302
+ </summary>
303
+ <div class="agent-output">{content}</div>
304
+ </details>"""
305
+ else:
306
+ # No output yet — just show the header (not collapsible)
307
+ html += f"""
308
+ <div class="agent-block {cls}">
309
+ <div class="agent-header">
310
+ <span class="step-icon">{icon}</span>
311
+ <span class="step-name">{name}</span>
312
+ <span class="step-status">{status}</span>
313
+ </div>
314
+ </div>"""
315
+
316
+ html += "</div></div>"
317
+ return html
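
The per-agent branching in `_build_progress_html` above (error wins, then completed, then active, else waiting) can be sketched as a standalone helper; `agent_state` is a hypothetical name for illustration, not part of the project:

```python
# Minimal sketch of the state classification used by _build_progress_html:
# a failed agent takes precedence, then completed, then the active one,
# and anything else is still waiting.
def agent_state(agent_id, completed, active, error):
    if agent_id == error:
        return "step-error"
    if agent_id in completed:
        return "step-done"
    if agent_id == active:
        return "step-active"
    return "step-waiting"

agents = ["diagnostician", "bias_detector", "devil_advocate", "consultant"]
states = [agent_state(a, ["diagnostician"], "bias_detector", None) for a in agents]
print(states)  # ['step-done', 'step-active', 'step-waiting', 'step-waiting']
```

Because the error check comes first, an agent listed in both `completed` and `error` still renders as failed, which matches the branch order in the function above.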
ui/css.py ADDED
@@ -0,0 +1,726 @@
+ """Custom CSS for the Diagnostic Devil's Advocate UI."""
+
+ CUSTOM_CSS = """
+ /* ===== Global ===== */
+ .gradio-container {
+   max-width: 1320px !important;
+   margin: 0 auto !important;
+   font-family: 'Inter', 'Segoe UI', system-ui, -apple-system, sans-serif !important;
+   width: 100% !important;
+   box-sizing: border-box !important;
+   overflow-x: hidden;
+ }
+
+ /* ===== Hero Banner ===== */
+ .hero-banner {
+   text-align: center;
+   padding: 36px 28px 24px;
+   border-radius: 16px;
+   background: linear-gradient(135deg, #0f172a 0%, #1e293b 40%, #0f3460 100%);
+   color: #fff;
+   margin-bottom: 22px;
+   box-shadow: 0 4px 28px rgba(0, 0, 0, 0.2);
+   position: relative;
+   overflow: hidden;
+ }
+ .hero-banner::after {
+   content: '';
+   position: absolute;
+   top: -50%;
+   right: -20%;
+   width: 400px;
+   height: 400px;
+   background: radial-gradient(circle, rgba(59,130,246,0.08) 0%, transparent 70%);
+   pointer-events: none;
+ }
+ .hero-badge {
+   display: inline-block;
+   padding: 4px 14px;
+   border-radius: 20px;
+   background: rgba(59, 130, 246, 0.2);
+   border: 1px solid rgba(59, 130, 246, 0.3);
+   color: #93c5fd;
+   font-size: 0.72rem;
+   font-weight: 600;
+   letter-spacing: 0.5px;
+   text-transform: uppercase;
+   margin-bottom: 12px;
+ }
+ .hero-banner h1 {
+   margin: 0 0 8px;
+   font-size: 2.2rem;
+   font-weight: 800;
+   letter-spacing: -0.8px;
+   color: #fff !important;
+   border: none !important;
+ }
+ .hero-sub {
+   margin: 0 0 6px;
+   font-size: 1.02rem;
+   color: #93c5fd;
+   font-weight: 500;
+ }
+ .hero-desc {
+   margin: 0 auto;
+   font-size: 0.86rem;
+   color: #94a3b8;
+   max-width: 660px;
+   line-height: 1.55;
+ }
+ .hero-models {
+   margin-top: 16px;
+   display: flex;
+   justify-content: center;
+   gap: 8px;
+   flex-wrap: wrap;
+ }
+ .model-chip {
+   display: inline-block;
+   padding: 3px 12px;
+   border-radius: 6px;
+   background: rgba(255,255,255,0.08);
+   border: 1px solid rgba(255,255,255,0.12);
+   color: #cbd5e1;
+   font-size: 0.72rem;
+   font-weight: 600;
+ }
+
+ /* ===== Section Label ===== */
+ .section-label {
+   font-size: 0.7rem;
+   font-weight: 700;
+   color: #64748b;
+   letter-spacing: 1.2px;
+   text-transform: uppercase;
+   margin: 16px 0 10px;
+   padding-left: 2px;
+ }
+
+ /* ===== Demo Case: Hidden buttons ===== */
+ .case-row {
+   height: 0 !important;
+   overflow: hidden !important;
+   margin: 0 !important;
+   padding: 0 !important;
+   gap: 0 !important;
+ }
+ .case-card-btn {
+   opacity: 0 !important;
+   position: absolute !important;
+   pointer-events: none !important;
+   height: 0 !important;
+   padding: 0 !important;
+   margin: 0 !important;
+ }
+
+ /* ===== Demo Case Cards (overlay, clickable) ===== */
+ .case-cards-overlay {
+   display: grid;
+   grid-template-columns: 1fr 1fr 1fr;
+   gap: 14px;
+   margin-bottom: 20px;
+ }
+ .case-card {
+   border-radius: 14px;
+   padding: 18px 16px 14px;
+   background: #fff;
+   border: 1.5px solid #e2e8f0;
+   cursor: pointer;
+   transition: all 0.22s ease;
+   position: relative;
+   overflow: hidden;
+ }
+ .case-card:hover {
+   border-color: #93c5fd;
+   box-shadow: 0 6px 24px rgba(59, 130, 246, 0.12);
+   transform: translateY(-3px);
+ }
+ .case-card:active {
+   transform: translateY(0);
+   box-shadow: 0 2px 8px rgba(59, 130, 246, 0.15);
+ }
+ .case-card::before {
+   content: '';
+   position: absolute;
+   top: 0;
+   left: 0;
+   right: 0;
+   height: 4px;
+ }
+ .case-card-pneumo::before { background: linear-gradient(90deg, #3b82f6, #60a5fa); }
+ .case-card-aorta::before { background: linear-gradient(90deg, #ef4444, #f87171); }
+ .case-card-pe::before { background: linear-gradient(90deg, #8b5cf6, #a78bfa); }
+
+ .card-top {
+   display: flex;
+   align-items: center;
+   justify-content: space-between;
+   margin-bottom: 10px;
+ }
+ .case-icon { font-size: 1.8rem; }
+ .case-tag {
+   font-size: 0.6rem;
+   font-weight: 700;
+   letter-spacing: 0.8px;
+   padding: 3px 8px;
+   border-radius: 5px;
+   text-transform: uppercase;
+ }
+ .tag-blue { background: #dbeafe; color: #1d4ed8; }
+ .tag-red { background: #fee2e2; color: #b91c1c; }
+ .tag-purple { background: #f3e8ff; color: #7c3aed; }
+
+ .case-title {
+   font-size: 1rem;
+   font-weight: 700;
+   color: #1e293b;
+   margin-bottom: 2px;
+ }
+ .case-meta {
+   font-size: 0.78rem;
+   color: #64748b;
+   margin-bottom: 6px;
+ }
+ .case-desc {
+   font-size: 0.74rem;
+   color: #94a3b8;
+   line-height: 1.4;
+   margin-bottom: 10px;
+ }
+ .case-misdiag {
+   padding-top: 8px;
+   border-top: 1px solid #f1f5f9;
+   font-size: 0.74rem;
+ }
+ .misdiag-label {
+   color: #94a3b8;
+ }
+ .misdiag-value {
+   color: #dc2626;
+   font-weight: 700;
+ }
+
+ /* ===== Voice Input Section ===== */
+ .voice-section {
+   margin-top: 12px;
+   padding: 12px 14px 8px;
+   background: linear-gradient(135deg, #fafbff 0%, #f0f4ff 100%);
+   border: 1.5px solid #dbeafe;
+   border-radius: 12px 12px 0 0;
+   border-bottom: none;
+ }
+ .voice-header {
+   display: flex;
+   align-items: center;
+   gap: 8px;
+   margin-bottom: 4px;
+ }
+ .voice-icon { font-size: 1.2rem; }
+ .voice-title {
+   font-size: 0.88rem;
+   font-weight: 700;
+   color: #1e293b;
+ }
+ .voice-badge {
+   font-size: 0.6rem;
+   font-weight: 700;
+   padding: 2px 8px;
+   border-radius: 4px;
+   background: #dbeafe;
+   color: #1d4ed8;
+   letter-spacing: 0.5px;
+   text-transform: uppercase;
+ }
+ .voice-hint {
+   font-size: 0.74rem;
+   color: #94a3b8;
+   line-height: 1.4;
+ }
+ .voice-row {
+   gap: 10px !important;
+   align-items: stretch !important;
+   margin-bottom: 6px !important;
+ }
+ .voice-audio {
+   border-radius: 0 0 0 12px !important;
+ }
+ .transcribe-btn {
+   border-radius: 8px !important;
+   font-weight: 600 !important;
+   border: 1.5px solid #3b82f6 !important;
+   background: #eff6ff !important;
+   color: #1d4ed8 !important;
+   transition: all 0.15s ease !important;
+ }
+ .transcribe-btn:hover {
+   background: #dbeafe !important;
+   box-shadow: 0 2px 8px rgba(59, 130, 246, 0.15) !important;
+ }
+
+ /* Voice status indicators */
+ .voice-status {
+   font-size: 0.76rem;
+   padding: 6px 10px;
+   border-radius: 8px;
+   margin-top: 6px;
+   text-align: center;
+   line-height: 1.4;
+ }
+ .voice-idle {
+   background: #f8fafc;
+   color: #94a3b8;
+   border: 1px solid #e2e8f0;
+ }
+ .voice-processing {
+   background: #eff6ff;
+   color: #1d4ed8;
+   border: 1px solid #bfdbfe;
+   display: flex;
+   align-items: center;
+   justify-content: center;
+   gap: 6px;
+ }
+ .voice-success {
+   background: #f0fdf4;
+   color: #166534;
+   border: 1px solid #bbf7d0;
+   font-weight: 600;
+ }
+ .voice-error {
+   background: #fef2f2;
+   color: #991b1b;
+   border: 1px solid #fecaca;
+ }
+
+ /* Pulsing dot for processing state */
+ .pulse-dot {
+   display: inline-block;
+   width: 8px;
+   height: 8px;
+   border-radius: 50%;
+   background: #3b82f6;
+   animation: pulse-dot-anim 1s ease-in-out infinite;
+ }
+ @keyframes pulse-dot-anim {
+   0%, 100% { opacity: 1; transform: scale(1); }
+   50% { opacity: 0.4; transform: scale(0.7); }
+ }
+
+ /* ===== Analyze Button ===== */
+ .analyze-btn {
+   width: 100% !important;
+   border-radius: 12px !important;
+   font-size: 1.05rem !important;
+   font-weight: 700 !important;
+   padding: 14px !important;
+   background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
+   box-shadow: 0 4px 14px rgba(37, 99, 235, 0.3) !important;
+   transition: all 0.2s ease !important;
+   margin-top: 10px !important;
+   letter-spacing: -0.2px !important;
+ }
+ .analyze-btn:hover {
+   box-shadow: 0 6px 24px rgba(37, 99, 235, 0.4) !important;
+   transform: translateY(-2px) !important;
+ }
+
+ /* ===== Progress Bar ===== */
+ .progress-container {
+   margin: 4px 0 8px;
+ }
+ .progress-bar-track {
+   height: 28px;
+   background: #f1f5f9;
+   border-radius: 14px;
+   overflow: hidden;
+   margin-bottom: 14px;
+   border: 1px solid #e2e8f0;
+ }
+ .progress-bar-fill {
+   height: 100%;
+   border-radius: 14px;
+   color: #fff;
+   font-size: 0.75rem;
+   font-weight: 700;
+   display: flex;
+   align-items: center;
+   justify-content: center;
+   transition: width 0.6s ease;
+   min-width: 0;
+ }
+
+ /* ===== Pipeline Agent Blocks ===== */
+ .pipeline-agents {
+   display: flex;
+   flex-direction: column;
+   gap: 8px;
+ }
+ details.agent-block,
+ div.agent-block {
+   border-radius: 12px;
+   border: 1px solid transparent;
+   box-sizing: border-box;
+   width: 100%;
+   min-width: 0;
+   overflow-x: hidden;
+ }
+ details.agent-block > summary {
+   list-style: none;
+   cursor: pointer;
+   user-select: none;
+ }
+ details.agent-block > summary::-webkit-details-marker {
+   display: none;
+ }
+ details.agent-block > summary::after {
+   content: '▾';
+   font-size: 0.7rem;
+   color: #94a3b8;
+   margin-left: 6px;
+   transition: transform 0.2s ease;
+ }
+ details.agent-block:not([open]) > summary::after {
+   transform: rotate(-90deg);
+ }
+ .agent-header {
+   display: flex;
+   align-items: center;
+   gap: 10px;
+   padding: 10px 14px;
+ }
+ div.agent-block .agent-header {
+   cursor: default;
+ }
+ .step-icon {
+   width: 24px;
+   height: 24px;
+   display: flex;
+   align-items: center;
+   justify-content: center;
+   border-radius: 50%;
+   font-size: 0.82rem;
+   font-weight: 700;
+   flex-shrink: 0;
+ }
+ .step-name {
+   font-size: 0.88rem;
+   font-weight: 700;
+ }
+ .step-status {
+   font-size: 0.72rem;
+   margin-left: auto;
+ }
+
+ /* Agent output area */
+ .agent-output {
+   padding: 4px 14px 14px;
+   font-size: 0.84rem;
+   line-height: 1.6;
+   color: #334155;
+   border-top: 1px solid rgba(0,0,0,0.06);
+   overflow: hidden;
+   overflow-wrap: break-word;
+   overflow-wrap: anywhere;
+   word-break: break-word;
+ }
+ .agent-output * {
+   max-width: 100%;
+   box-sizing: border-box;
+ }
+ .agent-output pre,
+ .agent-output code {
+   max-width: 100%;
+   white-space: pre-wrap;
+   word-break: break-word;
+ }
+ .agent-output ul {
+   margin: 6px 0;
+   padding-left: 18px;
+ }
+ .agent-output li {
+   margin-bottom: 4px;
+ }
+ .agent-text {
+   overflow-wrap: break-word;
+   word-break: break-word;
+ }
+ .agent-output details summary {
+   cursor: pointer;
+   color: #475569;
+ }
+ .findings-section,
+ .differentials-section {
+   margin-bottom: 6px;
+ }
+ .findings-section strong,
+ .differentials-section strong {
+   color: #1e293b;
+ }
+ .differentials-section ol {
+   margin: 6px 0;
+   padding-left: 22px;
+ }
+
+ /* Step states */
+ .step-done {
+   background: #f0fdf4;
+   border-color: #bbf7d0;
+ }
+ .step-done .step-icon {
+   background: #22c55e;
+   color: #fff;
+ }
+ .step-done .step-name { color: #166534; }
+ .step-done .step-status { color: #4ade80; }
+
+ .step-active {
+   background: #eff6ff;
+   border-color: #bfdbfe;
+   animation: pulse-border 2s ease-in-out infinite;
+ }
+ .step-active .step-icon {
+   background: #3b82f6;
+   color: #fff;
+   animation: spin 1.2s linear infinite;
+ }
+ .step-active .step-name { color: #1d4ed8; }
+ .step-active .step-status { color: #60a5fa; }
+
+ .step-waiting {
+   background: #f8fafc;
+   border-color: #f1f5f9;
+ }
+ .step-waiting .step-icon {
+   background: #e2e8f0;
+   color: #94a3b8;
+ }
+ .step-waiting .step-name { color: #94a3b8; }
+ .step-waiting .step-status { color: #cbd5e1; }
+
+ .step-error {
+   background: #fef2f2;
+   border-color: #fecaca;
+ }
+ .step-error .step-icon {
+   background: #ef4444;
+   color: #fff;
+ }
+ .step-error .step-name { color: #991b1b; }
+ .step-error .step-status { color: #f87171; }
+
+ @keyframes pulse-border {
+   0%, 100% { border-color: #bfdbfe; }
+   50% { border-color: #60a5fa; }
+ }
+ @keyframes spin {
+   from { transform: rotate(0deg); }
+   to { transform: rotate(360deg); }
+ }
+
+ /* ===== Agent Output Styling ===== */
+
+ /* Bias Detector */
+ .discrepancy-summary {
+   padding: 8px 12px;
+   margin-bottom: 10px;
+   background: #fff7ed;
+   border-radius: 8px;
+   border: 1px solid #fed7aa;
+   color: #9a3412;
+   font-size: 0.84rem;
+   line-height: 1.5;
+ }
+ .bias-item {
+   margin-bottom: 10px;
+   padding: 8px 12px;
+   background: #fffbeb;
+   border-left: 3px solid #f59e0b;
+   border-radius: 0 8px 8px 0;
+ }
+ .bias-title {
+   font-weight: 700;
+   color: #92400e;
+   margin-bottom: 4px;
+ }
+ .bias-evidence {
+   color: #78716c;
+   font-size: 0.82rem;
+ }
+ .severity-tag {
+   display: inline-block;
+   padding: 1px 6px;
+   border-radius: 4px;
+   font-size: 0.65rem;
+   font-weight: 800;
+   letter-spacing: 0.5px;
+   margin-right: 4px;
+   vertical-align: middle;
+ }
+ .severity-high { background: #fee2e2; color: #dc2626; }
+ .severity-medium { background: #fff7ed; color: #ea580c; }
+ .severity-low { background: #fefce8; color: #ca8a04; }
+
+ .source-tag {
+   display: inline-block;
+   padding: 1px 6px;
+   border-radius: 4px;
+   font-size: 0.65rem;
+   font-weight: 800;
+   letter-spacing: 0.5px;
+   margin-right: 4px;
+   vertical-align: middle;
+ }
+ .source-doctor { background: #dbeafe; color: #1d4ed8; }
+ .source-ai { background: #ede9fe; color: #7c3aed; }
+ .source-both { background: #e0e7ff; color: #4338ca; }
+ .source-imaging { background: #dbeafe; color: #1d4ed8; }
+ .source-clinical { background: #fef3c7; color: #b45309; }
+
+ .missed-findings {
+   margin-top: 8px;
+   padding: 8px 12px;
+   background: #fef2f2;
+   border-radius: 8px;
+ }
+ .missed-findings strong {
+   color: #991b1b;
+ }
+
+ /* SigLIP */
+ .siglip-section {
+   margin-top: 8px;
+   padding: 8px 12px;
+   background: #f0f9ff;
+   border-radius: 8px;
+ }
+ .siglip-section strong {
+   color: #0369a1;
+ }
+ .sign-present {
+   color: #166534;
+ }
+ .sign-absent {
+   color: #94a3b8;
+ }
+
+ /* Devil's Advocate */
+ .mnm-item {
+   margin-bottom: 10px;
+   padding: 8px 12px;
+   background: #fef2f2;
+   border-left: 3px solid #ef4444;
+   border-radius: 0 8px 8px 0;
+ }
+ .mnm-title {
+   font-weight: 700;
+   color: #991b1b;
+   font-size: 0.92rem;
+ }
+ .mnm-item ul {
+   margin-top: 4px;
+ }
+ .challenge-item {
+   margin-bottom: 8px;
+   padding: 8px 12px;
+   background: #faf5ff;
+   border-left: 3px solid #8b5cf6;
+   border-radius: 0 8px 8px 0;
+ }
+ .challenge-claim {
+   font-weight: 700;
+   color: #5b21b6;
+   margin-bottom: 2px;
+ }
+ .challenge-counter {
+   color: #6b7280;
+   font-size: 0.82rem;
+ }
+
+ /* Consultant */
+ .consultation-note {
+   padding: 12px 16px;
+   background: linear-gradient(135deg, #f8fafc, #f0f9ff);
+   border-radius: 10px;
+   border: 1px solid #e2e8f0;
+   color: #1e293b;
+   font-size: 0.86rem;
+   line-height: 1.65;
+   margin-bottom: 10px;
+ }
+ .consultation-note p {
+   margin: 0 0 8px;
+ }
+ .consultation-note p:last-child {
+   margin-bottom: 0;
+ }
+ .alt-diagnoses {
+   margin-bottom: 8px;
+ }
+ .urgency-tag {
+   display: inline-block;
+   padding: 1px 6px;
+   border-radius: 4px;
+   font-size: 0.65rem;
+   font-weight: 800;
+   letter-spacing: 0.5px;
+   vertical-align: middle;
+ }
+ .urgency-critical { background: #fee2e2; color: #dc2626; }
+ .urgency-high { background: #fff7ed; color: #ea580c; }
+ .urgency-moderate { background: #fefce8; color: #ca8a04; }
+ .urgency-unknown { background: #e2e8f0; color: #475569; }
+ .next-steps {
+   margin-bottom: 8px;
+ }
+ .confidence-note {
+   padding: 8px 12px;
+   background: #f8fafc;
+   border-radius: 8px;
+   color: #64748b;
+   font-size: 0.8rem;
+ }
+
+ /* Pipeline error */
+ .pipeline-error {
+   padding: 14px;
+   background: #fef2f2;
+   border: 1px solid #fecaca;
+   border-radius: 10px;
+   color: #991b1b;
+   font-weight: 600;
+ }
+
+ /* ===== Footer ===== */
+ .footer-text {
+   text-align: center;
+   padding: 18px;
+   margin-top: 24px;
+   font-size: 0.76rem;
+   color: #94a3b8;
+   border-top: 1px solid #e2e8f0;
+   display: flex;
+   align-items: center;
+   justify-content: center;
+   gap: 6px;
+   flex-wrap: wrap;
+ }
+ .footer-chip {
+   display: inline-block;
+   padding: 1px 8px;
+   border-radius: 4px;
+   background: #f1f5f9;
+   color: #475569;
+   font-weight: 600;
+   font-size: 0.72rem;
+ }
+ .footer-sep {
+   color: #cbd5e1;
+   margin: 0 2px;
+ }
+
+ /* ===== Responsive ===== */
+ @media (max-width: 768px) {
+   .gradio-container { max-width: 100% !important; }
+   .hero-banner h1 { font-size: 1.5rem; }
+   .case-cards-overlay { grid-template-columns: 1fr; }
+ }
+ """