"""Callbacks connecting UI events to the LangGraph pipeline."""

import html
import json
import logging
import os

import numpy as np
from PIL import Image

from agents.graph import stream_pipeline
from config import DEMO_CASES_DIR, ENABLE_MEDASR

# ZeroGPU support: use @spaces.GPU when running on HF Spaces, no-op locally
try:
    import spaces

    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False


def gpu_decorator(duration: int = 180):
    """Return a decorator that applies ``spaces.GPU`` on HF Spaces, no-op locally.

    Args:
        duration: GPU time budget (seconds) forwarded to ``spaces.GPU``.
    """

    def decorator(fn):
        if SPACES_AVAILABLE:
            return spaces.GPU(duration=duration)(fn)
        return fn

    return decorator


logger = logging.getLogger(__name__)

# Agent display info
AGENT_INFO = {
    "diagnostician": ("Diagnostician", "Analyzing image independently..."),
    "bias_detector": ("Bias Detector", "Scanning for cognitive biases..."),
    "devil_advocate": ("Devil's Advocate", "Challenging the diagnosis..."),
    "consultant": ("Consultant", "Synthesizing consultation report..."),
}

# Demo case definitions — based on published case reports and clinical literature.
# References:
#   Case 1: PMC3195099 (AP CXR vs CT in trauma pneumothorax detection)
#   Case 2: PMC6203039 (Acute aortic dissection: a missed diagnosis)
#   Case 3: PMC10683049 (PE masked by symptoms of mental disorders)
DEMO_CASES = {
    "Case 1: Missed Pneumothorax": {
        "diagnosis": "Left rib contusion with musculoskeletal chest wall pain",
        "context": (
            "32-year-old male, presented to ED after a motorcycle collision at ~40 mph. "
            "Helmet worn, no LOC. Chief complaint: left-sided chest pain worse with deep "
            "inspiration.\n\n"
            "Vitals: HR 104 bpm, BP 132/84 mmHg, RR 22/min, SpO2 96% on room air, "
            "Temp 37.1 C.\n\n"
            "Exam: Tenderness over left 4th-6th ribs, no crepitus, no subcutaneous "
            "emphysema palpated. Breath sounds reportedly equal bilaterally (noisy ED). "
            "Mild dyspnea attributed to pain.\n\n"
            "Labs: WBC 11.2, Hgb 14.1, Lactate 1.8 mmol/L.\n\n"
            "ED physician ordered AP chest X-ray (supine) — read as 'no acute "
            "cardiopulmonary abnormality, possible left rib fracture.' Patient was given "
            "ibuprofen and discharged with rib fracture precautions."
        ),
        "image_file": "case1_pneumothorax.png",
        "modality": "CXR",
    },
    "Case 2: Aortic Dissection": {
        "diagnosis": "Acute gastroesophageal reflux / esophageal spasm",
        "context": (
            "58-year-old male with a 15-year history of hypertension (poorly controlled, "
            "non-compliant with amlodipine). Presented to ED with sudden-onset severe "
            "retrosternal chest pain radiating to the interscapular back region, starting "
            "30 minutes ago.\n\n"
            "Vitals: BP 178/102 mmHg (right arm), 146/88 mmHg (left arm), HR 92 bpm, "
            "RR 20/min, SpO2 97%, Temp 37.0 C.\n\n"
            "Exam: Diaphoretic, visibly distressed. Abdomen soft, mild epigastric "
            "tenderness. Heart sounds normal, no murmur. Peripheral pulses intact but "
            "radial pulse asymmetry noted.\n\n"
            "Labs: Troponin I <0.01 (negative x2 at 0h and 3h), D-dimer 4,850 ng/mL "
            "(markedly elevated), WBC 13.4, Creatinine 1.3.\n\n"
            "ECG: Sinus tachycardia, nonspecific ST changes. Initial CXR ordered. "
            "ED physician considered ACS (ruled out by troponin), then attributed symptoms "
            "to acid reflux; prescribed IV pantoprazole and GI cocktail. Pain not relieved."
        ),
        "image_file": "case2_aortic_dissection.png",
        "modality": "CXR",
    },
    "Case 3: Pulmonary Embolism": {
        "diagnosis": "Postpartum anxiety with hyperventilation syndrome",
        "context": (
            "29-year-old female, G2P2, day 5 after emergency cesarean section (prolonged "
            "labor, general anesthesia). Presented with acute onset dyspnea and chest "
            "tightness at rest. Reports feeling of 'impending doom' and inability to catch "
            "breath.\n\n"
            "Vitals: HR 118 bpm, BP 108/72 mmHg, RR 28/min, SpO2 91% on room air "
            "(improved to 95% on 4L NC), Temp 37.3 C.\n\n"
            "Exam: Anxious-appearing, tachypneic. Lungs clear to auscultation. Mild "
            "right-sided pleuritic chest pain. Right calf tenderness and mild swelling "
            "noted but attributed to post-surgical immobility. No Homan sign.\n\n"
            "Labs: D-dimer 3,200 ng/mL (elevated, but 'expected postpartum'), "
            "WBC 10.8, Hgb 10.2, ABG on RA: pH 7.48, pO2 68 mmHg, pCO2 29 mmHg.\n\n"
            "OB team attributed symptoms to postpartum anxiety, prescribed lorazepam "
            "0.5 mg PRN. Psychiatry consult requested. No CTPA ordered initially."
        ),
        "image_file": "case3_pulmonary_embolism.png",
        "modality": "CXR",
    },
}


@gpu_decorator(duration=900)
def analyze_streaming(image: Image.Image | None, diagnosis: str, context: str, modality: str):
    """Run the agent pipeline and yield one combined HTML string after each step.

    Each agent's rendered output appears inline below its progress header.
    Uses @spaces.GPU on HF Spaces for ZeroGPU support.

    Args:
        image: Uploaded medical image; ``None`` yields a validation message.
        diagnosis: The doctor's working diagnosis (required).
        context: Free-text clinical context; a placeholder is used when blank.
        modality: Imaging modality label; defaults to ``"CXR"`` when blank.

    Yields:
        HTML strings for the UI output component.
    """
    if image is None:
        yield '<div class="error-msg">Please upload a medical image.</div>'
        return
    # Coerce to str so a None from the UI is treated as "empty" instead of crashing.
    diagnosis = (diagnosis or "").strip()
    if not diagnosis:
        yield '<div class="error-msg">Please enter the doctor\'s working diagnosis.</div>'
        return
    context = (context or "").strip() or "No additional clinical context provided."
    if not isinstance(modality, str) or not modality.strip():
        modality = "CXR"
    modality = modality.strip()

    completed: dict[str, bool] = {}
    agent_outputs: dict[str, str] = {}
    all_agents = ["diagnostician", "bias_detector", "devil_advocate", "consultant"]

    try:
        yield _build_pipeline(all_agents, completed, agent_outputs, active="diagnostician")

        accumulated_state: dict = {}
        for node_name, state_update in stream_pipeline(image, diagnosis, context, modality):
            # A node may emit no update; treat that as an empty dict.
            state_update = state_update or {}
            completed[node_name] = True
            accumulated_state.update(state_update)

            if state_update.get("error"):
                agent_outputs[node_name] = (
                    f'<div class="agent-error">{_esc(state_update.get("error"))}</div>'
                )
                yield _build_pipeline(all_agents, completed, agent_outputs, error=node_name)
                return

            # Generate this agent's HTML output
            agent_outputs[node_name] = _format_agent_output(node_name, accumulated_state)
            # The next active step is the first agent that has not completed yet. This
            # also behaves sensibly when the graph emits a node outside `all_agents`
            # (previously that reset the active marker to the first agent).
            next_active = next((a for a in all_agents if a not in completed), None)
            yield _build_pipeline(all_agents, completed, agent_outputs, active=next_active)
    except Exception as e:
        logger.exception("Pipeline failed")
        yield f'<div class="error-msg">Pipeline error: {_esc(e)}</div>'


def _build_pipeline(all_agents, completed, agent_outputs, active=None, error=None) -> str:
    """Build combined progress + inline output HTML.

    ``all_agents`` is accepted for call-site symmetry but is not forwarded.
    """
    # Local import, presumably to avoid an import cycle with the UI package.
    from ui.components import _build_progress_html

    return _build_progress_html(
        completed=list(completed.keys()),
        active=active,
        error=error,
        agent_outputs=agent_outputs,
    )


def _format_agent_output(agent_id: str, state: dict) -> str:
    """Render the HTML for one agent's output; unknown agent ids render as ""."""
    formatter = _AGENT_FORMATTERS.get(agent_id)
    return formatter(state) if formatter is not None else ""


def _esc(text: object) -> str:
    """Escape HTML special characters (including quotes) in ``str(text)``."""
    return html.escape(str(text), quote=True)


def _as_dict(value: object) -> dict:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _as_list(value: object) -> list:
    """Coerce a model-produced field into a list.

    ``None`` and blank strings become ``[]``; lists/tuples are copied; any other
    single value (e.g. a lone string) is wrapped so it is not iterated
    character-by-character.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _format_diagnostician(state: dict) -> str:
    """Render structured findings and differentials, falling back to raw text."""
    diag = _as_dict(state.get("diagnostician_output"))
    parts = []

    # Structured findings
    findings_list = _as_list(diag.get("findings_list"))
    if findings_list:
        items = []
        for f in findings_list:
            if isinstance(f, dict):
                name = _esc(f.get("finding", ""))
                desc = _esc(f.get("description", ""))
                source = str(f.get("source") or "").strip().lower()
                source_tag = ""
                if source in ("imaging", "clinical", "both"):
                    source_tag = f' <span class="source-tag source-{source}">{_esc(source)}</span>'
                if desc:
                    items.append(f"<li><strong>{name}</strong>{source_tag}: {desc}</li>")
                else:
                    items.append(f"<li><strong>{name}</strong>{source_tag}</li>")
            else:
                items.append(f"<li>{_esc(str(f))}</li>")
        parts.append(
            f'<div class="agent-section"><h4>Findings</h4><ul>{"".join(items)}</ul></div>'
        )

    # Differential diagnoses
    differentials = _as_list(diag.get("differential_diagnoses"))
    if differentials:
        items = []
        for d in differentials:
            if isinstance(d, dict):
                name = _esc(d.get("diagnosis", ""))
                reason = _esc(d.get("reasoning", ""))
                if reason:
                    items.append(f"<li><strong>{name}</strong>: {reason}</li>")
                else:
                    items.append(f"<li><strong>{name}</strong></li>")
            else:
                items.append(f"<li>{_esc(str(d))}</li>")
        parts.append(
            f'<div class="agent-section"><h4>Differential Diagnoses</h4>'
            f'<ul>{"".join(items)}</ul></div>'
        )

    # Fallback: raw text if no structured data
    if not parts:
        raw = diag.get("findings", "")
        if raw:
            parts.append(f'<div class="agent-raw">{_esc(raw).replace(chr(10), "<br>")}</div>')

    return "".join(parts)


def _format_bias_detector(state: dict) -> str:
    """Render discrepancy summary, biases, missed findings and sign verification."""
    bias_out = _as_dict(state.get("bias_detector_output"))
    parts = []

    # Discrepancy summary (always show if present)
    disc = bias_out.get("discrepancy_summary", "")
    if disc:
        parts.append(f'<div class="discrepancy-summary">{_esc(disc)}</div>')

    # Biases
    for b in _as_list(bias_out.get("identified_biases")):
        if not isinstance(b, dict):
            # A bare string is treated as the bias type.
            b = {"type": b}
        severity = str(b.get("severity") or "").strip().lower()
        bias_type = _esc(b.get("type", "Unknown"))
        evidence = _esc(b.get("evidence", ""))
        source = str(b.get("source") or "").strip().lower()

        # Only whitelisted values are used in class names, so they need no escaping.
        sev_tag = ""
        if severity in ("low", "medium", "high"):
            sev_tag = f'<span class="severity-tag severity-{severity}">{severity.upper()}</span>'
        src_tag = ""
        if source in ("doctor", "ai", "both"):
            src_tag = f'<span class="source-tag source-{source}">{source.upper()}</span>'

        parts.append(
            f'<div class="bias-item">'
            f'<div class="bias-header">{sev_tag} {src_tag} <strong>{bias_type}</strong></div>'
            f'<div class="bias-evidence">{evidence}</div>'
            f'</div>'
        )

    # Missed findings
    missed = _as_list(bias_out.get("missed_findings"))
    if missed:
        items = "".join(f"<li>{_esc(f)}</li>" for f in missed)
        parts.append(f'<div class="agent-section"><h4>Missed Findings</h4><ul>{items}</ul></div>')

    # SigLIP sign verification; "inconclusive" results are hidden.
    sign_results = bias_out.get("consistency_check", [])
    if isinstance(sign_results, list) and sign_results:
        meaningful = [
            r for r in sign_results
            if isinstance(r, dict) and r.get("confidence") != "inconclusive"
        ]
        if meaningful:
            items = []
            for r in meaningful:
                conf = str(r.get("confidence", "?"))
                sign = _esc(r.get("sign", "?"))
                css_cls = "sign-present" if "present" in conf else "sign-absent"
                items.append(f'<li class="{css_cls}">{sign} — {_esc(conf)}</li>')
            parts.append(
                f'<div class="agent-section sign-verification">'
                f'<h4>Image Verification (MedSigLIP)</h4>'
                f'<ul>{"".join(items)}</ul>'
                f'</div>'
            )

    return "".join(parts)


def _format_devil_advocate(state: dict) -> str:
    """Render must-not-miss diagnoses, challenges and recommended workup."""
    da_out = _as_dict(state.get("devils_advocate_output"))
    parts = []

    # Must-not-miss
    for m in _as_list(da_out.get("must_not_miss")):
        if not isinstance(m, dict):
            m = {"diagnosis": m}
        dx = _esc(m.get("diagnosis", "?"))
        why = _esc(m.get("why_dangerous", ""))
        signs = _esc(m.get("supporting_signs", ""))
        test = _esc(m.get("rule_out_test", ""))
        details = ""
        if why:
            details += f"<li><strong>Why dangerous:</strong> {why}</li>"
        if signs:
            details += f"<li><strong>Supporting signs:</strong> {signs}</li>"
        if test:
            details += f"<li><strong>Rule-out test:</strong> {test}</li>"
        parts.append(
            f'<div class="mnm-item">'
            f'<div class="mnm-dx">{dx}</div>'
            f'<ul>{details}</ul>'
            f'</div>'
        )

    # Challenges
    for c in _as_list(da_out.get("challenges")):
        if not isinstance(c, dict):
            c = {"claim": c}
        claim = _esc(c.get("claim", ""))
        counter = _esc(c.get("counter_evidence", ""))
        parts.append(
            f'<div class="challenge-item">'
            f'<div class="challenge-claim">{claim}</div>'
            f'<div class="challenge-counter">{counter}</div>'
            f'</div>'
        )

    # Recommended workup
    workup = _as_list(da_out.get("recommended_workup"))
    if workup:
        items = "".join(f"<li>{_esc(str(w))}</li>" for w in workup)
        parts.append(
            f'<div class="agent-section"><h4>Recommended Workup</h4><ul>{items}</ul></div>'
        )

    # Fallback: ensure non-empty so the collapsible block can expand
    if not parts:
        parts.append('<div class="agent-empty">No structured challenges parsed.</div>')

    return "".join(parts)


def _parse_alternatives(raw: object) -> list[dict]:
    """Parse ``alternative_diagnoses`` (JSON string or list) into a list of dicts.

    Invalid JSON, non-list payloads and non-dict entries are dropped.
    """
    if not raw:
        return []
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Could not parse alternative_diagnoses as JSON")
            return []
    if not isinstance(raw, list):
        return []
    return [a for a in raw if isinstance(a, dict)]


def _format_consultant(state: dict) -> str:
    """Render the consultation note, alternatives, merged actions and confidence note."""
    ref = _as_dict(state.get("consultant_output"))
    da_out = _as_dict(state.get("devils_advocate_output"))
    parts = []

    # Consultation note — the main human-readable report
    note = ref.get("consultation_note", "")
    if note:
        paragraphs = _esc(note).split("\n")
        formatted = "".join(f"<p>{p.strip()}</p>" for p in paragraphs if p.strip())
        parts.append(f'<div class="consultation-note">{formatted}</div>')

    # Alternative diagnoses to consider
    alts = _parse_alternatives(ref.get("alternative_diagnoses", ""))
    if alts:
        items = []
        for a in alts:
            urgency_raw = str(a.get("urgency", "")).strip().lower()
            # Unknown urgencies fall back to "moderate" (whitelisted for the class name).
            urgency = urgency_raw if urgency_raw in {"critical", "high", "moderate"} else "moderate"
            urgency_label = urgency.upper()
            dx = _esc(a.get("diagnosis", "?"))
            ev = _esc(a.get("evidence", ""))
            ns = _esc(a.get("next_step", ""))
            detail = f" — {ev}" if ev else ""
            step = f"<br><em>Next step: {ns}</em>" if ns else ""
            items.append(
                f'<li><span class="urgency-tag urgency-{urgency}">{urgency_label}</span> '
                f"<strong>{dx}</strong>{detail}{step}</li>"
            )
        parts.append(
            f'<div class="agent-section"><h4>Consider</h4><ul>{"".join(items)}</ul></div>'
        )

    # Immediate actions (merged from Devil's Advocate + Consultant), de-duplicated
    # while preserving first-seen order.
    workup = _as_list(da_out.get("recommended_workup"))
    actions = _as_list(ref.get("immediate_actions"))
    safe_workup = [str(x).strip() for x in workup if str(x).strip()]
    safe_actions = [str(x).strip() for x in actions if str(x).strip()]
    all_items = list(dict.fromkeys(safe_workup + safe_actions))
    if all_items:
        items = "".join(f"<li>{_esc(item)}</li>" for item in all_items)
        parts.append(
            f'<div class="agent-section"><h4>Recommended Actions</h4><ul>{items}</ul></div>'
        )

    # Confidence note
    if ref.get("confidence_note"):
        parts.append(f'<div class="confidence-note">{_esc(ref["confidence_note"])}</div>')

    return "".join(parts)


# Dispatch table used by _format_agent_output (keys match pipeline node names).
_AGENT_FORMATTERS = {
    "diagnostician": _format_diagnostician,
    "bias_detector": _format_bias_detector,
    "devil_advocate": _format_devil_advocate,
    "consultant": _format_consultant,
}


@gpu_decorator(duration=60)
def transcribe_audio(audio, existing_context: str = ""):
    """Transcribe audio input using MedASR.

    Generator that yields ``(context_text, status_html)`` for streaming UI
    feedback. Appends transcribed text to any existing context.
    Uses @spaces.GPU on HF Spaces for ZeroGPU support.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or ``None`` when nothing was recorded.
        existing_context: Current context text; the transcript is appended to it.
    """
    existing_context = existing_context or ""

    def _status_html(cls: str, text: str) -> str:
        return f'<div class="voice-status {cls}">{text}</div>'

    if audio is None:
        yield existing_context, _status_html("voice-idle", "No audio recorded. Click the microphone to start.")
        return

    if not ENABLE_MEDASR:
        yield existing_context, _status_html("voice-error", "MedASR is disabled (set ENABLE_MEDASR=true)")
        return

    sr, audio_data = audio
    audio_data = np.asarray(audio_data)
    if audio_data.size == 0 or sr <= 0:
        yield existing_context, _status_html("voice-idle", "No audio recorded. Click the microphone to start.")
        return

    # Step 1: Show processing state
    duration = len(audio_data) / sr
    yield existing_context, _status_html(
        "voice-processing",
        f'<span class="voice-spinner"></span> Transcribing {duration:.1f}s of audio with MedASR...',
    )

    try:
        from models import medasr_client

        # Convert to float32 mono in roughly [-1, 1].
        if audio_data.dtype != np.float32:
            if np.issubdtype(audio_data.dtype, np.unsignedinteger):
                # Unsigned PCM is offset-binary: centre it before scaling.
                mid = (np.iinfo(audio_data.dtype).max + 1) / 2
                audio_data = ((audio_data.astype(np.float64) - mid) / mid).astype(np.float32)
            elif np.issubdtype(audio_data.dtype, np.integer):
                audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
            else:
                audio_data = audio_data.astype(np.float32)
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Resample to 16kHz if needed (MedASR expects 16000Hz)
        target_sr = 16000
        if sr != target_sr:
            from scipy.signal import resample

            num_samples = int(len(audio_data) * target_sr / sr)
            audio_data = resample(audio_data, num_samples).astype(np.float32)
            sr = target_sr

        # Step 2: Run transcription
        text = (medasr_client.transcribe(audio_data, sampling_rate=sr) or "").strip()

        if not text:
            yield existing_context, _status_html("voice-error", "No speech detected. Please try again.")
            return

        # Step 3: Append to existing context
        if existing_context.strip():
            new_context = existing_context.rstrip() + "\n\n" + text
        else:
            new_context = text

        word_count = len(text.split())
        yield new_context, _status_html(
            "voice-success",
            f'✓ Transcribed {word_count} words ({duration:.1f}s) — text added to context above',
        )

    except Exception as e:
        logger.exception("MedASR transcription failed")
        yield existing_context, _status_html("voice-error", f"Transcription failed: {_esc(e)}")


def load_demo(demo_name: str | None):
    """Load a demo case into the UI inputs.

    Returns:
        ``(image, diagnosis, context, modality)``. ``image`` is ``None`` when the
        file is missing. Unknown names return ``(None, "", "", "CXR")``.
    """
    if demo_name is None or demo_name not in DEMO_CASES:
        return None, "", "", "CXR"

    case = DEMO_CASES[demo_name]
    image_path = os.path.join(DEMO_CASES_DIR, case["image_file"])
    image = None
    if os.path.exists(image_path):
        # Image.open is lazy and keeps the file open; copy() forces a full load so
        # the handle can be closed immediately.
        with Image.open(image_path) as img:
            image = img.copy()
    else:
        logger.warning("Demo image not found: %s", image_path)

    modality = case.get("modality") or "CXR"
    return image, case["diagnosis"], case["context"], modality