"""
RadTriage AI — HuggingFace Spaces Live Demo
============================================
Multimodal Radiology Triage & Report Drafting with MedGemma

Features:
- Live inference: Upload any medical image for real-time analysis
- Pre-computed cases: 3 NIH Chest X-ray cases for instant demo

Competition: MedGemma Impact Challenge (Kaggle)
Tracks: Main Track + Agentic Workflow Prize
"""

# ── Patch gradio_client schema bug (bool schema crash) ────
# A boolean JSON schema (bare true/false) crashes the original gradio_client
# helpers, so short-circuit booleans before delegating to them.
import gradio_client.utils as _gcu

_orig_get_type = _gcu.get_type


def _patched_get_type(schema):
    if isinstance(schema, bool):
        return "bool"
    return _orig_get_type(schema)


_gcu.get_type = _patched_get_type

_orig_json_schema = _gcu._json_schema_to_python_type


def _patched_json_schema(schema, defs=None):
    if isinstance(schema, bool):
        return "bool"
    return _orig_json_schema(schema, defs)


_gcu._json_schema_to_python_type = _patched_json_schema
# ── End patch ─────────────────────────────────────────────

import json
import os
import re
import time

import gradio as gr
import torch
from PIL import Image

# ── Load MedGemma ─────────────────────────────────────────
print("Loading MedGemma 4B-IT...")
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/medgemma-4b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device} | CUDA available: {torch.cuda.is_available()}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)} | VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

processor = AutoProcessor.from_pretrained(model_id)
if device == "cuda":
    model = AutoModelForImageTextToText.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
else:
    model = AutoModelForImageTextToText.from_pretrained(model_id)
    print("WARNING: Running on CPU — inference will be extremely slow (30+ min per image)")
print(f"MedGemma loaded on {model.device}")

# ── Pipeline Functions ────────────────────────────────────

# An abnormal finding whose description mentions any of these is escalated to CRITICAL.
CRITICAL_KEYWORDS = [
    "hemorrhage", "pneumothorax", "tension", "stroke", "embolism",
    "dissection", "perforation", "herniation", "midline shift",
    "mass effect", "acute infarct", "tamponade",
]
# Keywords that must match as whole words: a plain substring test for
# "tension" would also fire on "hypertension" / "hypotension".
_WHOLE_WORD_KEYWORDS = {"tension"}
_CRITICAL_PATTERN = re.compile("|".join(
    rf"\b{re.escape(kw)}\b" if kw in _WHOLE_WORD_KEYWORDS else re.escape(kw)
    for kw in CRITICAL_KEYWORDS
))

# Non-normal findings below this confidence are flagged for radiologist review.
LOW_CONFIDENCE_THRESHOLD = 0.6


def _finding_list(findings_json):
    """Return the finding dicts from a Pass 1 result, tolerating null/malformed entries."""
    findings = findings_json.get("findings") or []
    if not isinstance(findings, list):
        return []
    return [f for f in findings if isinstance(f, dict)]


def _severity(finding):
    """Return the finding's severity, stripped and lower-cased, or None if absent."""
    severity = finding.get("severity")
    if severity is None:
        return None
    return str(severity).strip().lower()


def _confidence(finding, default):
    """Return the finding's confidence as a float, or *default* if missing or non-numeric."""
    try:
        return float(finding.get("confidence", default))
    except (TypeError, ValueError):
        return default


def _low_confidence_findings(findings_json):
    """Return ``(confidence, description)`` for each non-normal finding below the threshold.

    A finding with no confidence is treated as fully confident (1.0).
    """
    flagged = []
    for f in _finding_list(findings_json):
        conf = _confidence(f, 1.0)
        if conf < LOW_CONFIDENCE_THRESHOLD and _severity(f) != "normal":
            flagged.append((conf, f.get("description", "")))
    return flagged


def classify_triage(findings_json):
    """Classify urgency — only checks abnormal findings for critical keywords.

    Args:
        findings_json: Parsed Pass 1 output (dict with a ``findings`` list).

    Returns:
        ``(level, time_sensitivity, warnings)`` where *level* is CRITICAL,
        URGENT or ROUTINE and *warnings* lists low-confidence non-normal findings.
    """
    findings = _finding_list(findings_json)
    severities = [_severity(f) for f in findings]
    # None means "no severity given"; anything other than "normal" counts as abnormal.
    abnormal_findings = [f for f, sev in zip(findings, severities) if sev not in ("normal", None)]

    has_severe = "severe" in severities
    has_moderate = "moderate" in severities
    category = str(findings_json.get("overall_impression_category") or "").strip().lower()
    is_critical = category == "critical"
    has_critical_keyword = any(
        _CRITICAL_PATTERN.search(str(f.get("description", "")).lower())
        for f in abnormal_findings
    )

    if is_critical or has_critical_keyword or has_severe:
        level, time_sens = "CRITICAL", "Immediate"
    elif has_moderate:
        level, time_sens = "URGENT", "Within 2-4 hours"
    else:
        level, time_sens = "ROUTINE", "Within 24 hours"

    warnings = [
        f"Low confidence ({conf:.0%}): {desc}"
        for conf, desc in _low_confidence_findings(findings_json)
    ]
    return level, time_sens, warnings


def _has_priors(priors):
    """Return True when *priors* describes actual prior studies (not blank / "None")."""
    return priors is not None and str(priors).strip().lower() not in ("", "none")


def _generate(messages, tag, max_new_tokens=1024):
    """Run greedy MedGemma generation on chat *messages*; return only the newly generated text."""
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]
    print(f"[{tag}] Input tokens: {input_len} — generating...")
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the echoed prompt tokens so only the model's reply is decoded.
    result = processor.decode(output_ids[0][input_len:], skip_special_tokens=True)
    print(f"[{tag}] Done — output length: {len(result)} chars")
    return result


def _parse_findings_json(text):
    """Parse the model's reply into a dict; return None if no JSON object can be recovered.

    Tolerates markdown code fences (with or without a language tag / newline)
    and prose before or after the object.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (e.g. "```json"); a fence with no newline
        # is left for the brace-extraction fallback below.
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned.rsplit("```", 1)[0]

    candidates = [cleaned.strip()]
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        candidates.append(cleaned[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Downstream code calls .get(), so only a JSON object is acceptable.
        if isinstance(parsed, dict):
            return parsed
    return None


def run_pass1(image, indication="Routine screening", priors="None"):
    """Pass 1: Image -> Structured Findings JSON.

    Returns ``(findings, raw)``: *findings* is the parsed dict, or None when
    the reply contains no JSON object; *raw* is the unmodified model reply.
    """
    print(f"[Pass 1] Starting — indication: {indication}")
    comparison = "true" if _has_priors(priors) else "false"
    prompt = f"""You are an expert radiologist assistant analyzing a Chest X-ray image.

Clinical indication: {indication}
Prior studies: {priors}

Specifically evaluate for: pneumothorax, pleural effusion, consolidation, cardiomegaly, mediastinal widening, rib fractures, pulmonary edema, masses, nodules, and lines/tubes.

Analyze the image systematically and respond ONLY with a valid JSON object (no markdown, no explanation) following this exact schema:
{{
  "modality_detected": "",
  "anatomical_region": "",
  "image_quality": "",
  "quality_issues": [],
  "comparison_available": {comparison},
  "findings": [
    {{"id": 1, "description": "", "location": "", "severity": "", "confidence": 0.0}}
  ],
  "overall_impression_category": "",
  "missing_context": ["Prior imaging for comparison", "Complete clinical history"]
}}

Allowed values: "severity" is one of "normal", "mild", "moderate", "severe"; "overall_impression_category" is one of "normal", "abnormal", "critical"; "confidence" is a number between 0 and 1."""

    messages = [{"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": prompt},
    ]}]
    result = _generate(messages, "Pass 1")

    findings = _parse_findings_json(result)
    if findings is None:
        print(f"[Pass 1] JSON parse failed. Raw output: {result[:200]}")
    return findings, result


def run_pass2(findings_json, indication="Routine screening", priors="None"):
    """Pass 2: Structured Findings -> Narrative Report (NO image).

    *priors* is optional (default "None") so existing two-argument callers keep
    working; it gives the model something real to put in the COMPARISON section.
    """
    print(f"[Pass 2] Starting — generating narrative report")
    prompt = f"""You are a radiologist drafting a formal report. Convert the following structured findings into a professional radiology report following ACR reporting guidelines.

STRUCTURED FINDINGS:
{json.dumps(findings_json, indent=2)}

CLINICAL INDICATION: {indication}
PRIOR STUDIES: {priors}

Generate a report with these exact sections:

EXAMINATION: [Modality and technique]
CLINICAL INDICATION: [From context provided]
COMPARISON: [Prior studies or "None available"]
FINDINGS: [Detailed narrative description using standard radiology terminology.]
IMPRESSION: [Numbered list of key findings, most critical first.]

RULES:
- ONLY describe findings from the structured data. Do NOT invent new findings.
- If confidence < 0.6, use "possible" or "cannot exclude".
- If image_quality is not "adequate", include a TECHNIQUE LIMITATION section.
- End with RECOMMENDATION if any finding needs follow-up."""

    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    return _generate(messages, "Pass 2")


def format_results(findings, report, triage, time_sens, warnings, p1_time, p2_time, indication, source_label):
    """Format pipeline results for display.

    Returns ``(triage_md, findings_md, report, warnings_md, json_output)``, which
    fill the five output components of each tab.

    *warnings* is accepted for interface compatibility but not used: the
    calibration panel is recomputed from *findings*, so pre-computed cases
    (which pass ``[]``) show the same warnings a live run would.
    """
    triage_colors = {"CRITICAL": "\U0001f534", "URGENT": "\U0001f7e0", "ROUTINE": "\U0001f7e2"}
    emoji = triage_colors.get(triage, "\u26aa")

    triage_md = f"""## {emoji} {triage}

**Time Sensitivity:** {time_sens}

**Clinical Indication:** {indication}

**Source:** {source_label}

**Pipeline Time:** Pass 1: {p1_time:.0f}s | Pass 2: {p2_time:.0f}s | Total: {p1_time + p2_time:.0f}s
"""

    finding_list = _finding_list(findings)
    modality = findings.get("modality_detected", "N/A")
    quality = findings.get("image_quality", "N/A")
    overall = str(findings.get("overall_impression_category") or "N/A").upper()
    findings_md = f"""**Modality:** {modality}

**Image Quality:** {quality}

**Overall Impression:** {overall}

### Findings ({len(finding_list)} detected)

"""
    severity_emojis = {"severe": "\U0001f534", "moderate": "\U0001f7e0", "mild": "\U0001f535", "normal": "\U0001f7e2"}
    for f in finding_list:
        sev_emoji = severity_emojis.get(_severity(f) or "", "\u26aa")
        conf = _confidence(f, 0.0) * 100
        conf_label = "High" if conf >= 80 else "Moderate" if conf >= 60 else "Low"
        findings_md += (
            f"{sev_emoji} **{f.get('description', '')}**\n"
            f"> Location: {f.get('location', '')} | Severity: {f.get('severity', '')} | "
            f"Confidence: {conf:.0f}% ({conf_label})\n\n"
        )

    missing_context = findings.get("missing_context") or []
    if isinstance(missing_context, str):
        # A bare string would otherwise be listed one character per line.
        missing_context = [missing_context]
    if missing_context:
        findings_md += "\n### Missing Context Checklist\n\n"
        for item in missing_context:
            findings_md += f"- \u2610 {item}\n"

    low_conf = _low_confidence_findings(findings)
    if low_conf:
        warnings_md = "".join(
            f"\u26a0\ufe0f **Low confidence ({conf:.0%}):** {desc}\n\n" for conf, desc in low_conf
        )
        warnings_md += "*These findings require radiologist verification.*"
    else:
        warnings_md = "\u2705 All abnormal findings have confidence \u2265 60%. No calibration warnings.\n\n*All findings should still be verified by a qualified radiologist.*"

    json_output = json.dumps(findings, indent=2)
    return triage_md, findings_md, report, warnings_md, json_output


# ── Live Inference Function ───────────────────────────────
def analyze_uploaded_image(image, indication, priors):
    """Run full pipeline on user-uploaded image.

    Returns the five display values for the live tab. On any failure the
    error message goes in the first slot and the remaining slots are empty.
    """
    if image is None:
        return "\u26a0\ufe0f Please upload a medical image.", "", "", "", ""
    # Guard against None as well as blank input.
    indication = (indication or "").strip() or "Routine screening"
    priors = (priors or "").strip() or "None"

    try:
        print(f"\n{'='*60}")
        print(f"[Pipeline] Starting analysis...")
        print(f"[Pipeline] Image size: {image.size}, mode: {image.mode}")
        print(f"[Pipeline] Indication: {indication}")
        print(f"{'='*60}")

        # Pass 1
        t1 = time.perf_counter()
        findings, raw = run_pass1(image, indication, priors)
        p1_time = time.perf_counter() - t1
        print(f"[Pipeline] Pass 1 completed in {p1_time:.0f}s")

        if findings is None:
            return f"\u26a0\ufe0f Pass 1 JSON parse failed.\n\nRaw model output:\n```\n{raw[:500]}\n```", "", "", "", ""

        # Pass 2
        t2 = time.perf_counter()
        report = run_pass2(findings, indication, priors)
        p2_time = time.perf_counter() - t2
        print(f"[Pipeline] Pass 2 completed in {p2_time:.0f}s")

        # Triage
        level, time_sens, warnings = classify_triage(findings)
        print(f"[Pipeline] Triage: {level} | Total time: {p1_time + p2_time:.0f}s")

        return format_results(findings, report, level, time_sens, warnings,
                              p1_time, p2_time, indication, "Live MedGemma Inference")

    except Exception as e:
        # UI boundary: log the full traceback and show a short message instead of crashing the Space.
        import traceback
        error_msg = traceback.format_exc()
        print(f"[Pipeline] ERROR: {error_msg}")
        return f"\u26a0\ufe0f Pipeline error:\n```\n{str(e)}\n```", "", "", "", ""


# ── Pre-computed Cases ────────────────────────────────────
CASES = {
    "Case 1: Shortness of Breath + Cough": {
        "indication": "Shortness of breath and productive cough x 3 days",
        "findings": {
            "modality_detected": "Chest X-ray",
            "anatomical_region": "Chest",
            "image_quality": "adequate",
            "quality_issues": [],
            "comparison_available": False,
            "findings": [
                {"id": 1, "description": "There is increased opacity in the right lower lung field, potentially representing consolidation or atelectasis.", "location": "Right lower lung field", "severity": "moderate", "confidence": 0.80},
                {"id": 2, "description": "The heart size appears mildly enlarged.", "location": "Heart", "severity": "mild", "confidence": 0.60},
                {"id": 3, "description": "There are lines and tubes present, likely representing medical devices.", "location": "Chest", "severity": "normal", "confidence": 0.90},
            ],
            "overall_impression_category": "abnormal",
            "missing_context": ["Prior imaging for comparison", "Complete clinical history", "Relevant lab values (WBC, procalcitonin)", "Oxygen saturation"],
        },
        "report": "EXAMINATION: Chest X-ray\n\nCLINICAL INDICATION: Shortness of breath and productive cough x 3 days\n\nCOMPARISON: None available\n\nFINDINGS:\n1. There is increased opacity in the right lower lung field, potentially representing consolidation or atelectasis.\n2. The heart size appears mildly enlarged.\n3. There are lines and tubes present, likely representing medical devices.\n\nIMPRESSION:\n1. Increased opacity in the right lower lung field, potentially representing consolidation or atelectasis.\n2. Mild cardiomegaly.\n3. Lines and tubes present, likely representing medical devices.\n\nRECOMMENDATION: Consider further evaluation with CT chest to better characterize the right lower lung opacity.",
        "triage": "URGENT",
        "time_sensitivity": "Within 2-4 hours",
        "pass1_time": 259,
        "pass2_time": 130,
    },
    "Case 2: Chest Pain \u2014 Rule Out ACS": {
        "indication": "Chest pain, rule out acute coronary syndrome",
        "findings": {
            "modality_detected": "Chest X-ray",
            "anatomical_region": "Chest",
            "image_quality": "adequate",
            "quality_issues": [],
            "comparison_available": False,
            "findings": [
                {"id": 1, "description": "The lungs appear clear bilaterally.", "location": "Lungs", "severity": "normal", "confidence": 0.95},
                {"id": 2, "description": "The heart size is within normal limits.", "location": "Heart", "severity": "normal", "confidence": 0.90},
                {"id": 3, "description": "The mediastinum appears normal in width.", "location": "Mediastinum", "severity": "normal", "confidence": 0.95},
                {"id": 4, "description": "No obvious rib fractures are identified.", "location": "Ribs", "severity": "normal", "confidence": 0.90},
                {"id": 5, "description": "No pleural effusion is evident.", "location": "Pleural space", "severity": "normal", "confidence": 0.95},
                {"id": 6, "description": "No pneumothorax is identified.", "location": "Pleural space", "severity": "normal", "confidence": 0.95},
                {"id": 7, "description": "No consolidation is seen.", "location": "Lungs", "severity": "normal", "confidence": 0.95},
                {"id": 8, "description": "Lines/tubes are present, likely a central line.", "location": "Chest", "severity": "normal", "confidence": 0.90},
            ],
            "overall_impression_category": "normal",
            "missing_context": ["Prior imaging for comparison", "Complete clinical history", "Troponin levels", "ECG findings"],
        },
        "report": "EXAMINATION: Chest X-ray, PA and lateral views.\n\nCLINICAL INDICATION: Chest pain, rule out acute coronary syndrome.\n\nCOMPARISON: None available.\n\nFINDINGS:\n1. The lungs are clear bilaterally.\n2. The heart size is within normal limits.\n3. The mediastinum is normal in width.\n4. No obvious rib fractures are identified.\n5. No pleural effusion is evident.\n6. No pneumothorax is identified.\n7. No consolidation is seen.\n8. A line/tube is present, likely a central line.\n\nIMPRESSION:\n1. No acute cardiopulmonary process identified.\n2. A line/tube is present, likely a central line.\n\nRECOMMENDATION: Consider follow-up with a cardiology consultation if chest pain persists.",
        "triage": "ROUTINE",
        "time_sensitivity": "Within 24 hours",
        "pass1_time": 466,
        "pass2_time": 160,
    },
    "Case 3: Pre-operative Clearance": {
        "indication": "Pre-operative clearance for hip replacement",
        "findings": {
            "modality_detected": "Chest X-ray",
            "anatomical_region": "Chest",
            "image_quality": "adequate",
            "quality_issues": [],
            "comparison_available": False,
            "findings": [
                {"id": 1, "description": "Lines/tubes present, including endotracheal tube, nasogastric tube, and two central venous catheters.", "location": "Chest", "severity": "normal", "confidence": 0.95},
                {"id": 2, "description": "The heart size appears mildly enlarged.", "location": "Heart", "severity": "mild", "confidence": 0.80},
                {"id": 3, "description": "There is some increased opacity in the right lung field, potentially representing mild pulmonary edema or atelectasis.", "location": "Right lung field", "severity": "mild", "confidence": 0.70},
                {"id": 4, "description": "No pneumothorax is identified.", "location": "Pleural space", "severity": "normal", "confidence": 0.90},
                {"id": 5, "description": "No pleural effusion is identified.", "location": "Pleural space", "severity": "normal", "confidence": 0.90},
                {"id": 6, "description": "No consolidation is identified.", "location": "Lungs", "severity": "normal", "confidence": 0.90},
                {"id": 7, "description": "No rib fractures are identified.", "location": "Ribs", "severity": "normal", "confidence": 0.90},
                {"id": 8, "description": "Mediastinal width appears within normal limits.", "location": "Mediastinum", "severity": "normal", "confidence": 0.80},
                {"id": 9, "description": "No masses or nodules are identified.", "location": "Lungs", "severity": "normal", "confidence": 0.90},
            ],
            "overall_impression_category": "abnormal",
            "missing_context": ["Prior imaging for comparison", "Complete clinical history", "Current medications", "Cardiac history"],
        },
        "report": "EXAMINATION: Chest X-ray\n\nCLINICAL INDICATION: Pre-operative clearance for hip replacement\n\nCOMPARISON: None available\n\nFINDINGS:\n1. Lines/tubes are present throughout the chest, including an endotracheal tube, nasogastric tube, and two central venous catheters.\n2. The heart size appears mildly enlarged.\n3. There is some increased opacity in the right lung field, potentially representing mild pulmonary edema or atelectasis.\n4. No pneumothorax is identified.\n5. No pleural effusion is identified.\n6. No consolidation is identified.\n7. No rib fractures are identified.\n8. Mediastinal width appears within normal limits.\n9. No masses or nodules are identified.\n\nIMPRESSION:\n1. Lines/tubes are present throughout the chest.\n2. Mild cardiomegaly.\n3. Possible mild pulmonary edema or atelectasis in the right lung field.\n4. No pneumothorax or pleural effusion.\n5. No consolidation.\n6. No rib fractures.\n7. Normal mediastinal width.\n8. No masses or nodules.\n\nRECOMMENDATION: Follow-up chest X-ray is recommended in 6 months to assess for resolution of the right lung field opacity.",
        "triage": "ROUTINE",
        "time_sensitivity": "Within 24 hours",
        "pass1_time": 530,
        "pass2_time": 253,
    },
}


def analyze_precomputed(case_name):
    """Display pre-computed results; returns five empty strings for an unknown case name."""
    if case_name not in CASES:
        return "", "", "", "", ""
    case = CASES[case_name]
    return format_results(
        case["findings"], case["report"], case["triage"], case["time_sensitivity"],
        [], case["pass1_time"], case["pass2_time"], case["indication"],
        "Pre-computed (NIH Chest X-ray)",
    )


# ── Gradio Interface ──────────────────────────────────────
HEADER = """# \U0001f3e5 RadTriage AI \u2014 Live Demo
### Multimodal Radiology Triage & Report Drafting with MedGemma

**Two-pass pipeline** for reduced hallucination in medical image interpretation:
1. **Pass 1** \u2192 Structured findings extraction with confidence scores (MedGemma Vision+Text)
2. **Pass 2** \u2192 Narrative report grounded in structured findings only (MedGemma Text)
3. **Triage** \u2192 CRITICAL / URGENT / ROUTINE classification

Built with [MedGemma 4B-IT](https://huggingface.co/google/medgemma-4b-it) | [Kaggle Notebook](https://www.kaggle.com/code/pramodmisra2020/radtriage-ai-radiology-triage-pipeline) | [GitHub](https://github.com/pramodmisra/radtriage-ai) | [Model Card](https://huggingface.co/pramodmisra/radtriage-ai-medgemma-4b)

> \u26a0\ufe0f **For demonstration purposes only. Not for clinical use.**
"""

ARCHITECTURE = """### Two-Pass Architecture

```
Medical Image \u2500\u2500\u2510
                \u251c\u2500\u2500\u25b6 [Pass 1: Structured Findings JSON] \u2500\u2500\u25b6 [Pass 2: Narrative Report]
Clinical Info \u2500\u2500\u2518                    \u2502
                                     \u251c\u2500\u2500\u25b6 Triage (CRITICAL / URGENT / ROUTINE)
                                     \u251c\u2500\u2500\u25b6 Confidence Calibration
                                     \u2514\u2500\u2500\u25b6 Missing Context Checklist
```

**Why two passes?** Pass 2 generates from structured findings only (no image) and is instructed to describe *only findings detected in Pass 1*. This limits hallucination, but the draft is still LLM output and should be checked against the structured findings.
"""

with gr.Blocks(
    title="RadTriage AI \u2014 Live Demo",
    theme=gr.themes.Base(primary_hue="blue", secondary_hue="green", neutral_hue="slate"),
) as demo:
    gr.Markdown(HEADER)
    with gr.Accordion("Architecture Details", open=False):
        gr.Markdown(ARCHITECTURE)
    gr.Markdown("---")

    with gr.Tabs():
        # ── TAB 1: Live Upload ──
        with gr.TabItem("\U0001f52c Upload & Analyze (Live Inference)"):
            gr.Markdown("### Upload a medical image for real-time MedGemma analysis")
            gr.Markdown("*\u23f3 Inference takes **3-8 minutes** on T4 GPU. The button will show a loading spinner — please wait for it to complete. Check Space logs for progress.*")
            with gr.Row():
                with gr.Column(scale=1):
                    upload_image = gr.Image(type="pil", label="Upload Medical Image")
                    upload_indication = gr.Textbox(
                        label="Clinical Indication",
                        placeholder="e.g., Shortness of breath and cough x 3 days",
                        value="Routine screening",
                    )
                    upload_priors = gr.Textbox(
                        label="Prior Studies",
                        placeholder="e.g., Chest X-ray 6 months ago showed clear lungs",
                        value="None",
                    )
                    upload_btn = gr.Button("\U0001f52c Analyze with MedGemma", variant="primary", size="lg")
                with gr.Tabs():
                    with gr.TabItem("\U0001f6a8 Triage"):
                        live_triage = gr.Markdown()
                    with gr.TabItem("\U0001f52c Structured Findings"):
                        live_findings = gr.Markdown()
                    with gr.TabItem("\U0001f4dd Narrative Report"):
                        live_report = gr.Textbox(label="Draft Report (editable)", lines=18, interactive=True)
                    with gr.TabItem("\u26a0\ufe0f Calibration"):
                        live_warnings = gr.Markdown()
                    with gr.TabItem("\U0001f4ca Raw JSON"):
                        live_json = gr.Textbox(label="Structured Findings JSON", lines=20, interactive=False)

            upload_btn.click(
                fn=analyze_uploaded_image,
                inputs=[upload_image, upload_indication, upload_priors],
                outputs=[live_triage, live_findings, live_report, live_warnings, live_json],
                api_name="analyze",
            )

        # ── TAB 2: Pre-computed Cases ──
        with gr.TabItem("\U0001f4cb Pre-computed Cases (Instant)"):
            gr.Markdown("### Pre-computed results from real NIH Chest X-ray cases")
            gr.Markdown("*These results were generated by MedGemma 4B-IT. Select a case to view instantly.*")
            case_selector = gr.Dropdown(
                choices=list(CASES.keys()),
                value=list(CASES.keys())[0],
                label="Select Case",
            )
            case_btn = gr.Button("View Analysis", variant="primary")
            with gr.Tabs():
                with gr.TabItem("\U0001f6a8 Triage"):
                    pre_triage = gr.Markdown()
                with gr.TabItem("\U0001f52c Structured Findings"):
                    pre_findings = gr.Markdown()
                with gr.TabItem("\U0001f4dd Narrative Report"):
                    pre_report = gr.Textbox(label="Draft Report (editable)", lines=18, interactive=True)
                with gr.TabItem("\u26a0\ufe0f Calibration"):
                    pre_warnings = gr.Markdown()
                with gr.TabItem("\U0001f4ca Raw JSON"):
                    pre_json = gr.Textbox(label="Structured Findings JSON", lines=20, interactive=False)

            case_btn.click(
                fn=analyze_precomputed,
                inputs=[case_selector],
                outputs=[pre_triage, pre_findings, pre_report, pre_warnings, pre_json],
            )

    # Populate the pre-computed tab with the first case on page load.
    demo.load(
        fn=lambda: analyze_precomputed(list(CASES.keys())[0]),
        outputs=[pre_triage, pre_findings, pre_report, pre_warnings, pre_json],
    )

    gr.Markdown("---")
    # Time column = pass1_time + pass2_time from CASES (same total the Triage tab shows).
    gr.Markdown("""### Results Summary (Pre-computed Cases)

| Case | Indication | Findings | Triage | Time |
|------|-----------|----------|--------|------|
| 1 | SOB + productive cough | 3 (2 abnormal) | \U0001f7e0 URGENT | 389s |
| 2 | Chest pain, r/o ACS | 8 (0 abnormal) | \U0001f7e2 ROUTINE | 626s |
| 3 | Pre-op clearance | 9 (2 abnormal) | \U0001f7e2 ROUTINE | 783s |
""")
    gr.Markdown("""---
*Built for the [MedGemma Impact Challenge](https://www.kaggle.com/competitions/med-gemma-impact-challenge) | For demonstration and research purposes only.*
""")

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1)
    demo.launch()