| """
|
| NEXUS Streamlit Demo Application
|
|
|
| Interactive demo for the NEXUS Maternal-Neonatal Care Platform.
|
| Built with Google HAI-DEF models for the MedGemma Impact Challenge.
|
|
|
| HAI-DEF Models Used:
|
| - MedSigLIP: Medical image analysis (anemia, jaundice detection)
|
| - HeAR: Health acoustic representations (cry analysis)
|
| - MedGemma: Clinical reasoning and synthesis
|
| """
|
|
|
| import streamlit as st
|
| from pathlib import Path
|
| import sys
|
| import os
|
| import tempfile
|
| import json
|
|
|
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
|
| st.set_page_config(
|
| page_title="NEXUS - Maternal-Neonatal Care",
|
| page_icon="👶",
|
| layout="wide",
|
| initial_sidebar_state="expanded",
|
| )
|
|
|
|
|
| st.markdown("""
|
| <style>
|
| .main-header {
|
| font-size: 2.5rem;
|
| font-weight: bold;
|
| color: #1f77b4;
|
| text-align: center;
|
| margin-bottom: 1rem;
|
| }
|
| .sub-header {
|
| font-size: 1.2rem;
|
| color: #666;
|
| text-align: center;
|
| margin-bottom: 2rem;
|
| }
|
| .risk-high {
|
| background-color: #ffcccc;
|
| border: 2px solid #ff0000;
|
| padding: 1rem;
|
| border-radius: 10px;
|
| }
|
| .risk-medium {
|
| background-color: #fff3cd;
|
| border: 2px solid #ffc107;
|
| padding: 1rem;
|
| border-radius: 10px;
|
| }
|
| .risk-low {
|
| background-color: #d4edda;
|
| border: 2px solid #28a745;
|
| padding: 1rem;
|
| border-radius: 10px;
|
| }
|
| .metric-card {
|
| background-color: #f8f9fa;
|
| padding: 1rem;
|
| border-radius: 10px;
|
| text-align: center;
|
| }
|
| .model-badge {
|
| display: inline-block;
|
| padding: 2px 10px;
|
| border-radius: 12px;
|
| font-size: 0.78rem;
|
| font-weight: 600;
|
| color: white;
|
| letter-spacing: 0.3px;
|
| }
|
| .stMetric > div {
|
| background-color: #f8f9fa;
|
| padding: 0.5rem;
|
| border-radius: 8px;
|
| }
|
| </style>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| @st.cache_resource
|
| def load_anemia_detector():
|
| """Load anemia detector model with error handling."""
|
| try:
|
| from nexus.anemia_detector import AnemiaDetector
|
| detector = AnemiaDetector()
|
| return detector, None
|
| except Exception as e:
|
| return None, str(e)
|
|
|
|
|
| @st.cache_resource
|
| def load_jaundice_detector():
|
| """Load jaundice detector model with error handling."""
|
| try:
|
| from nexus.jaundice_detector import JaundiceDetector
|
| detector = JaundiceDetector()
|
| return detector, None
|
| except Exception as e:
|
| return None, str(e)
|
|
|
|
|
| @st.cache_resource
|
| def load_cry_analyzer():
|
| """Load cry analyzer with error handling."""
|
| try:
|
| from nexus.cry_analyzer import CryAnalyzer
|
| analyzer = CryAnalyzer()
|
| return analyzer, None
|
| except Exception as e:
|
| return None, str(e)
|
|
|
|
|
| @st.cache_resource
|
| def load_clinical_synthesizer():
|
| """Load clinical synthesizer (MedGemma) with error handling."""
|
| try:
|
| import os
|
| from nexus.clinical_synthesizer import ClinicalSynthesizer
|
| use_medgemma = os.environ.get("NEXUS_USE_MEDGEMMA", "true").lower() != "false"
|
| synthesizer = ClinicalSynthesizer(use_medgemma=use_medgemma)
|
| return synthesizer, None
|
| except Exception as e:
|
| return None, str(e)
|
|
|
|
|
| def get_hai_def_info():
|
| """Get HAI-DEF models information with validated accuracy numbers."""
|
| return {
|
| "MedSigLIP": {
|
| "name": "MedSigLIP (google/medsiglip-448)",
|
| "use": "Image analysis for anemia and jaundice detection + bilirubin regression",
|
| "method": "Zero-shot classification (max-similarity, 8 prompts/class) + trained SVM/LR classifiers on embeddings",
|
| "accuracy": "Anemia: trained classifier on augmented data, Jaundice: trained classifier on 2,235 images, Bilirubin: MAE 2.67 mg/dL (r=0.77)",
|
| "badge": "Vision",
|
| "badge_color": "#388e3c",
|
| },
|
| "HeAR": {
|
| "name": "HeAR (google/hear-pytorch)",
|
| "use": "Infant cry analysis for asphyxia and cry type classification",
|
| "method": "512-dim health acoustic embeddings + trained linear classifier on donate-a-cry dataset (5-class: hungry, belly_pain, burping, discomfort, tired)",
|
| "accuracy": "Trained cry type classifier with asphyxia risk derivation from distress patterns",
|
| "badge": "Audio",
|
| "badge_color": "#f57c00",
|
| },
|
| "MedGemma": {
|
| "name": "MedGemma 1.5 4B (google/medgemma-1.5-4b-it)",
|
| "use": "Clinical reasoning and recommendation synthesis",
|
| "method": "4-bit NF4 quantized inference with WHO IMNCI protocol-aligned synthesis and 6-agent reasoning traces",
|
| "accuracy": "Protocol-aligned clinical recommendations with structured reasoning chains",
|
| "badge": "Language",
|
| "badge_color": "#1976d2",
|
| },
|
| }
|
|
|
|
|
| def main():
|
| """Main application."""
|
|
|
|
|
| st.markdown('<div class="main-header">NEXUS</div>', unsafe_allow_html=True)
|
| st.markdown(
|
| '<div class="sub-header">AI-Powered Maternal-Neonatal Care Platform</div>',
|
| unsafe_allow_html=True
|
| )
|
|
|
|
|
| with st.sidebar:
|
| st.markdown("## 🏥 NEXUS")
|
| st.markdown("---")
|
|
|
| assessment_type = st.radio(
|
| "Select Assessment Type",
|
| [
|
| "Maternal Anemia Screening",
|
| "Neonatal Jaundice Detection",
|
| "Cry Analysis",
|
| "Combined Assessment",
|
| "Agentic Workflow",
|
| "HAI-DEF Models Info"
|
| ],
|
| index=0,
|
| )
|
|
|
| st.markdown("---")
|
| st.markdown("### About NEXUS")
|
| st.markdown("""
|
| NEXUS uses AI to provide non-invasive screening for:
|
| - **Maternal Anemia** via conjunctiva imaging
|
| - **Neonatal Jaundice** via skin color analysis
|
| - **Birth Asphyxia** via cry pattern analysis
|
|
|
| Built with **Google HAI-DEF models** for the MedGemma Impact Challenge 2026.
|
| """)
|
|
|
| st.markdown("---")
|
| st.markdown("### Edge AI Mode")
|
| edge_mode = st.toggle("Enable Edge AI Mode", value=False, key="edge_mode")
|
| if edge_mode:
|
| st.success("Edge AI: INT8 quantized models + offline inference")
|
| else:
|
| st.info("Cloud mode: Full-precision HAI-DEF models")
|
|
|
| st.markdown("---")
|
| st.markdown("### HAI-DEF Models")
|
| st.markdown("""
|
| - **MedSigLIP**: Vision (trained classifiers)
|
| - **HeAR**: Audio (trained cry classifier)
|
| - **MedGemma 1.5**: Clinical AI (4-bit NF4)
|
| """)
|
|
|
|
|
| if edge_mode:
|
| render_edge_ai_banner()
|
|
|
|
|
| if assessment_type == "Maternal Anemia Screening":
|
| render_anemia_screening()
|
| elif assessment_type == "Neonatal Jaundice Detection":
|
| render_jaundice_detection()
|
| elif assessment_type == "Cry Analysis":
|
| render_cry_analysis()
|
| elif assessment_type == "Combined Assessment":
|
| render_combined_assessment()
|
| elif assessment_type == "Agentic Workflow":
|
| render_agentic_workflow()
|
| else:
|
| render_hai_def_info()
|
|
|
|
|
| def render_edge_ai_banner():
|
| """Show Edge AI mode status and model metrics."""
|
| st.markdown("""
|
| <div style="background: linear-gradient(135deg, #1a237e 0%, #0d47a1 100%);
|
| color: white; padding: 1rem 1.5rem; border-radius: 10px; margin-bottom: 1rem;">
|
| <h4 style="margin:0; color: white;">Edge AI Mode Active</h4>
|
| <p style="margin: 0.3rem 0 0 0; opacity: 0.9; font-size: 0.9rem;">
|
| Running INT8 quantized models for offline-capable inference on low-resource devices.
|
| </p>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
| col1, col2, col3, col4 = st.columns(4)
|
| with col1:
|
| st.metric("MedSigLIP INT8", "111.2 MB", "-86% memory")
|
| with col2:
|
| st.metric("Acoustic Model", "0.6 MB", "INT8 quantized")
|
| with col3:
|
| st.metric("Text Embeddings", "12 KB", "Pre-computed")
|
| with col4:
|
| st.metric("Total Edge Size", "~289 MB", "Offline-ready")
|
|
|
| with st.expander("Edge AI Details"):
|
| st.markdown("""
|
| **Quantization**: Dynamic INT8 (PyTorch `quantize_dynamic`, qnnpack backend)
|
|
|
| | Component | Cloud (FP32) | Edge (INT8) | Compression |
|
| |-----------|-------------|-------------|-------------|
|
| | MedSigLIP Vision | 812.6 MB | 111.2 MB | **7.31x** |
|
| | Acoustic Model | 0.665 MB | 0.599 MB | 1.11x |
|
| | CPU Latency | 97.7 ms | ~65 ms (ARM est.) | ~1.5x faster |
|
|
|
| **Target Devices**: Android 8.0+, ARM Cortex-A53, 2GB RAM
|
|
|
| **Offline Capabilities**:
|
| - Image analysis via INT8 MedSigLIP + pre-computed binary text embeddings
|
| - Audio analysis via INT8 acoustic feature extractor
|
| - Clinical reasoning via rule-based WHO IMNCI protocols (no MedGemma required)
|
| """)
|
|
|
|
|
| def _cleanup_temp(path: str) -> None:
|
| """Safely remove a temporary file."""
|
| try:
|
| if path and os.path.exists(path):
|
| os.unlink(path)
|
| except OSError:
|
| pass
|
|
|
|
|
| def _save_upload_to_temp(uploaded_file, suffix: str) -> str:
|
| """Save an uploaded file to a temporary path and return the path."""
|
| tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
| try:
|
| tmp.write(uploaded_file.getvalue())
|
| tmp.close()
|
| return tmp.name
|
| except Exception:
|
| tmp.close()
|
| _cleanup_temp(tmp.name)
|
| raise
|
|
|
|
|
| def _model_badge(name: str, color: str) -> str:
|
| """Return an HTML badge for displaying which HAI-DEF model is active."""
|
| return (
|
| f'<span style="background:{color}; color:white; padding:2px 10px; '
|
| f'border-radius:12px; font-size:0.78rem; font-weight:600; '
|
| f'letter-spacing:0.3px;">{name}</span>'
|
| )
|
|
|
|
|
| def render_anemia_screening():
|
| """Render anemia screening interface."""
|
| st.header("Maternal Anemia Screening")
|
| st.markdown(
|
| f"Upload a clear image of the inner eyelid (conjunctiva) for anemia screening. "
|
| f'{_model_badge("MedSigLIP", "#388e3c")}',
|
| unsafe_allow_html=True,
|
| )
|
|
|
| col1, col2 = st.columns([1, 1])
|
|
|
| with col1:
|
| st.subheader("Upload Image")
|
| uploaded_file = st.file_uploader(
|
| "Choose a conjunctiva image",
|
| type=["jpg", "jpeg", "png"],
|
| key="anemia_upload"
|
| )
|
|
|
| if uploaded_file:
|
| image_bytes = uploaded_file.getvalue()
|
| st.image(image_bytes, caption="Uploaded Image", use_container_width=True)
|
|
|
| with col2:
|
| st.subheader("Analysis Results")
|
|
|
| if uploaded_file:
|
| with st.spinner("Analyzing image..."):
|
| tmp_path = None
|
| try:
|
| detector, load_err = load_anemia_detector()
|
| if detector is None:
|
| st.error(f"Could not load model: {load_err}")
|
| return
|
|
|
| tmp_path = _save_upload_to_temp(uploaded_file, ".jpg")
|
|
|
| result = detector.detect(tmp_path)
|
| color_info = detector.analyze_color_features(tmp_path)
|
|
|
|
|
| risk_class = f"risk-{result['risk_level']}"
|
| st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
|
|
|
| if result["is_anemic"]:
|
| st.error("⚠️ ANEMIA DETECTED")
|
| else:
|
| st.success("✅ No Anemia Detected")
|
|
|
| st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
|
| col_a, col_b, col_c = st.columns(3)
|
| with col_a:
|
| st.metric("Confidence", f"{result['confidence']:.1%}")
|
| with col_b:
|
| st.metric("Risk Level", result['risk_level'].upper())
|
| with col_c:
|
| st.metric("Est. Hemoglobin", f"{color_info['estimated_hemoglobin']} g/dL")
|
|
|
|
|
| st.markdown("### Recommendation")
|
| st.info(result["recommendation"])
|
|
|
|
|
| with st.expander("Technical Details"):
|
| st.json({
|
| "anemia_score": round(result["anemia_score"], 3),
|
| "healthy_score": round(result["healthy_score"], 3),
|
| "red_ratio": round(color_info["red_ratio"], 3),
|
| "pallor_index": round(color_info["pallor_index"], 3),
|
| })
|
|
|
| except Exception as e:
|
| st.error(f"Error analyzing image: {e}")
|
| finally:
|
| _cleanup_temp(tmp_path)
|
| else:
|
| st.info("👆 Upload an image to begin analysis")
|
|
|
|
|
| def render_jaundice_detection():
|
| """Render jaundice detection interface."""
|
| st.header("Neonatal Jaundice Detection")
|
| st.markdown(
|
| f"Upload an image of the newborn's skin or sclera for jaundice assessment. "
|
| f'{_model_badge("MedSigLIP", "#388e3c")}',
|
| unsafe_allow_html=True,
|
| )
|
|
|
| col1, col2 = st.columns([1, 1])
|
|
|
| with col1:
|
| st.subheader("Upload Image")
|
| uploaded_file = st.file_uploader(
|
| "Choose a neonatal image",
|
| type=["jpg", "jpeg", "png"],
|
| key="jaundice_upload"
|
| )
|
|
|
| if uploaded_file:
|
| image_bytes = uploaded_file.getvalue()
|
| st.image(image_bytes, caption="Uploaded Image", use_container_width=True)
|
|
|
|
|
| st.subheader("Patient Information (Optional)")
|
| age_days = st.number_input("Age (days)", min_value=0, max_value=28, value=3)
|
| birth_weight = st.number_input("Birth weight (grams)", min_value=500, max_value=5000, value=3000)
|
|
|
| with col2:
|
| st.subheader("Analysis Results")
|
|
|
| if uploaded_file:
|
| with st.spinner("Analyzing image..."):
|
| tmp_path = None
|
| try:
|
| detector, load_err = load_jaundice_detector()
|
| if detector is None:
|
| st.error(f"Could not load model: {load_err}")
|
| return
|
|
|
| tmp_path = _save_upload_to_temp(uploaded_file, ".jpg")
|
|
|
| result = detector.detect(tmp_path)
|
| zone_info = detector.analyze_kramer_zones(tmp_path)
|
|
|
|
|
| risk_class = "risk-high" if result["needs_phototherapy"] else (
|
| "risk-medium" if result["severity"] in ["moderate", "mild"] else "risk-low"
|
| )
|
| st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
|
|
|
| if result["has_jaundice"]:
|
| st.warning(f"⚠️ JAUNDICE DETECTED - {result['severity'].upper()}")
|
| else:
|
| st.success("✅ No Significant Jaundice")
|
|
|
| st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
|
| col_a, col_b, col_c = st.columns(3)
|
| with col_a:
|
| bili_value = result.get('estimated_bilirubin_ml', result.get('estimated_bilirubin', 0))
|
| bili_method = result.get('bilirubin_method', 'Color Analysis')
|
| st.metric("Est. Bilirubin", f"{bili_value} mg/dL")
|
| st.caption(f"Method: {bili_method}")
|
| with col_b:
|
| st.metric("Severity", result['severity'].upper())
|
| with col_c:
|
| st.metric("Kramer Zone", zone_info['kramer_zone'])
|
|
|
|
|
| if result["needs_phototherapy"]:
|
| st.error("🔆 PHOTOTHERAPY RECOMMENDED")
|
|
|
|
|
| st.markdown("### Recommendation")
|
| st.info(result["recommendation"])
|
|
|
|
|
| with st.expander("Kramer Zone Analysis"):
|
| st.write(f"**Zone**: {zone_info['kramer_zone']} - {zone_info['zone_description']}")
|
| st.write(f"**Yellow Index**: {zone_info['yellow_index']}")
|
| st.progress(min(zone_info['yellow_index'] * 2, 1.0))
|
|
|
|
|
| with st.expander("Technical Details"):
|
| details = {
|
| "jaundice_score": round(result["jaundice_score"], 3),
|
| "confidence": round(result["confidence"], 3),
|
| "model": result.get("model", "unknown"),
|
| "model_type": result.get("model_type", "unknown"),
|
| "bilirubin_method": result.get("bilirubin_method", "Color Analysis"),
|
| }
|
| if result.get("estimated_bilirubin_ml") is not None:
|
| details["bilirubin_ml"] = result["estimated_bilirubin_ml"]
|
| details["bilirubin_color"] = result["estimated_bilirubin"]
|
| st.json(details)
|
|
|
| except Exception as e:
|
| st.error(f"Error analyzing image: {e}")
|
| finally:
|
| _cleanup_temp(tmp_path)
|
| else:
|
| st.info("👆 Upload an image to begin analysis")
|
|
|
|
|
| def render_cry_analysis():
|
| """Render cry analysis interface."""
|
| st.header("Infant Cry Analysis")
|
| st.markdown(
|
| f"Upload an audio recording of the infant's cry for analysis. "
|
| f'{_model_badge("HeAR", "#f57c00")}',
|
| unsafe_allow_html=True,
|
| )
|
|
|
| col1, col2 = st.columns([1, 1])
|
|
|
| with col1:
|
| st.subheader("Upload Audio")
|
| uploaded_file = st.file_uploader(
|
| "Choose a cry audio file",
|
| type=["wav", "mp3", "ogg"],
|
| key="cry_upload"
|
| )
|
|
|
| if uploaded_file:
|
| st.audio(uploaded_file)
|
|
|
| with col2:
|
| st.subheader("Analysis Results")
|
|
|
| if uploaded_file:
|
| with st.spinner("Analyzing cry..."):
|
| tmp_path = None
|
| try:
|
| analyzer, load_err = load_cry_analyzer()
|
| if analyzer is None:
|
| st.error(f"Could not load model: {load_err}")
|
| return
|
|
|
| tmp_path = _save_upload_to_temp(uploaded_file, ".wav")
|
|
|
| result = analyzer.analyze(tmp_path)
|
|
|
|
|
| risk_class = f"risk-{result['risk_level']}"
|
| st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
|
|
|
| if result["is_abnormal"]:
|
| st.error("⚠️ ABNORMAL CRY PATTERN DETECTED")
|
| else:
|
| st.success("✅ Normal Cry Pattern")
|
|
|
| st.markdown("</div>", unsafe_allow_html=True)
|
|
|
|
|
| col_a, col_b, col_c = st.columns(3)
|
| with col_a:
|
| st.metric("Asphyxia Risk", f"{result['asphyxia_risk']:.1%}")
|
| with col_b:
|
| st.metric("Cry Type", result['cry_type'].title())
|
| with col_c:
|
| st.metric("F0 (Pitch)", f"{result['features']['f0_mean']:.0f} Hz")
|
|
|
|
|
| st.markdown("### Recommendation")
|
| st.info(result["recommendation"])
|
|
|
|
|
| with st.expander("Acoustic Features"):
|
| st.json(result["features"])
|
|
|
| except Exception as e:
|
| st.error(f"Error analyzing audio: {e}")
|
| finally:
|
| _cleanup_temp(tmp_path)
|
| else:
|
| st.info("👆 Upload an audio file to begin analysis")
|
|
|
|
|
| def render_combined_assessment():
|
| """Render combined assessment interface using Clinical Synthesizer."""
|
| st.header("Combined Clinical Assessment")
|
| st.markdown(
|
| f"Upload multiple inputs for a comprehensive assessment using **MedGemma Clinical Synthesizer**. "
|
| f"This combines findings from all HAI-DEF models to provide integrated clinical recommendations. "
|
| f'{_model_badge("MedSigLIP", "#388e3c")} '
|
| f'{_model_badge("HeAR", "#f57c00")} '
|
| f'{_model_badge("MedGemma", "#1976d2")}',
|
| unsafe_allow_html=True,
|
| )
|
|
|
|
|
|
|
| st.session_state.findings = {
|
| "anemia": None,
|
| "jaundice": None,
|
| "cry": None
|
| }
|
|
|
| col1, col2, col3 = st.columns(3)
|
|
|
| with col1:
|
| st.subheader("🩸 Anemia Screening")
|
| anemia_file = st.file_uploader(
|
| "Conjunctiva image",
|
| type=["jpg", "jpeg", "png"],
|
| key="combined_anemia"
|
| )
|
| if anemia_file:
|
| st.image(anemia_file.getvalue(), use_container_width=True)
|
| with st.spinner("Analyzing..."):
|
| try:
|
| detector, load_err = load_anemia_detector()
|
| if detector is None:
|
| st.error(f"Model error: {load_err}")
|
| else:
|
| with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
|
| tmp.write(anemia_file.getvalue())
|
| result = detector.detect(tmp.name)
|
| st.session_state.findings["anemia"] = result
|
| if result["is_anemic"]:
|
| st.error(f"Anemia: {result['risk_level'].upper()}")
|
| else:
|
| st.success("No Anemia")
|
| except Exception as e:
|
| st.error(f"Error: {e}")
|
|
|
| with col2:
|
| st.subheader("👶 Jaundice Detection")
|
| jaundice_file = st.file_uploader(
|
| "Neonatal skin image",
|
| type=["jpg", "jpeg", "png"],
|
| key="combined_jaundice"
|
| )
|
| if jaundice_file:
|
| st.image(jaundice_file.getvalue(), use_container_width=True)
|
| with st.spinner("Analyzing..."):
|
| try:
|
| detector, load_err = load_jaundice_detector()
|
| if detector is None:
|
| st.error(f"Model error: {load_err}")
|
| else:
|
| with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
|
| tmp.write(jaundice_file.getvalue())
|
| result = detector.detect(tmp.name)
|
| st.session_state.findings["jaundice"] = result
|
| if result["has_jaundice"]:
|
| st.warning(f"Jaundice: {result['severity'].upper()}")
|
| else:
|
| st.success("No Jaundice")
|
| except Exception as e:
|
| st.error(f"Error: {e}")
|
|
|
| with col3:
|
| st.subheader("🔊 Cry Analysis")
|
| cry_file = st.file_uploader(
|
| "Cry audio",
|
| type=["wav", "mp3", "ogg"],
|
| key="combined_cry"
|
| )
|
| if cry_file:
|
| st.audio(cry_file)
|
| with st.spinner("Analyzing..."):
|
| try:
|
| analyzer, load_err = load_cry_analyzer()
|
| if analyzer is None:
|
| st.error(f"Model error: {load_err}")
|
| raise RuntimeError(load_err)
|
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
| tmp.write(cry_file.getvalue())
|
| result = analyzer.analyze(tmp.name)
|
| st.session_state.findings["cry"] = result
|
| if result["is_abnormal"]:
|
| st.error(f"Abnormal Cry: {result['risk_level'].upper()}")
|
| else:
|
| st.success("Normal Cry")
|
| except Exception as e:
|
| st.error(f"Error: {e}")
|
|
|
|
|
| st.markdown("---")
|
| st.subheader("🏥 Clinical Synthesis (MedGemma)")
|
|
|
|
|
| has_findings = any(v is not None for v in st.session_state.findings.values())
|
|
|
| if has_findings:
|
| if st.button("Generate Clinical Synthesis", type="primary"):
|
| with st.spinner("Synthesizing findings with MedGemma..."):
|
| try:
|
| synthesizer, load_err = load_clinical_synthesizer()
|
| if synthesizer is None:
|
| st.error(f"Could not load synthesizer: {load_err}")
|
| return
|
|
|
|
|
| findings = {}
|
| if st.session_state.findings["anemia"]:
|
| findings["anemia"] = st.session_state.findings["anemia"]
|
| if st.session_state.findings["jaundice"]:
|
| findings["jaundice"] = st.session_state.findings["jaundice"]
|
| if st.session_state.findings["cry"]:
|
| findings["cry"] = st.session_state.findings["cry"]
|
|
|
| synthesis = synthesizer.synthesize(findings)
|
|
|
|
|
| severity_level = synthesis.get("severity_level", "GREEN")
|
| severity_colors = {
|
| "GREEN": ("🟢", "#d4edda", "#155724"),
|
| "YELLOW": ("🟡", "#fff3cd", "#856404"),
|
| "RED": ("🔴", "#f8d7da", "#721c24")
|
| }
|
| emoji, bg_color, text_color = severity_colors.get(severity_level, ("⚪", "#f8f9fa", "#000"))
|
|
|
| st.markdown(f"""
|
| <div style="background-color: {bg_color}; padding: 1.5rem; border-radius: 10px; margin: 1rem 0;">
|
| <h3 style="color: {text_color}; margin: 0;">{emoji} Severity: {severity_level}</h3>
|
| <p style="color: {text_color}; font-size: 1.1rem; margin-top: 0.5rem;">{synthesis.get('severity_description', '')}</p>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| st.markdown("### Summary")
|
| st.info(synthesis.get("summary", "No summary available"))
|
|
|
|
|
| if synthesis.get("immediate_actions"):
|
| st.markdown("### Immediate Actions")
|
| for action in synthesis["immediate_actions"]:
|
| st.markdown(f"- {action}")
|
|
|
|
|
| col_a, col_b = st.columns(2)
|
| with col_a:
|
| st.markdown("### Referral Status")
|
| if synthesis.get("referral_needed"):
|
| st.error(f"⚠️ REFERRAL NEEDED: {synthesis.get('referral_urgency', 'standard').upper()}")
|
| else:
|
| st.success("✅ No referral needed")
|
|
|
| with col_b:
|
| st.markdown("### Follow-up")
|
| st.info(synthesis.get("follow_up", "Schedule routine follow-up"))
|
|
|
|
|
| with st.expander("Technical Details"):
|
| model_name = synthesis.get("model", "unknown")
|
| st.json({
|
| "model": model_name,
|
| "model_id": synthesis.get("model_id", ""),
|
| "generated_at": synthesis.get("generated_at"),
|
| "urgent_conditions": synthesis.get("urgent_conditions", []),
|
| })
|
| if model_name and "Fallback" not in str(model_name):
|
| st.success(f"Synthesis powered by {model_name}")
|
| elif "Fallback" in str(model_name):
|
| st.warning("Using rule-based fallback (MedGemma unavailable)")
|
|
|
| except Exception as e:
|
| st.error(f"Error generating synthesis: {e}")
|
| else:
|
| st.info("👆 Upload at least one input (image or audio) to generate clinical synthesis")
|
|
|
|
|
| def render_hai_def_info():
|
| """Render HAI-DEF models information."""
|
| st.header("Google HAI-DEF Models")
|
| st.markdown("""
|
| NEXUS is built using **Google Health AI Developer Foundations (HAI-DEF)** models,
|
| designed specifically for healthcare applications in resource-limited settings.
|
| """)
|
|
|
| hai_def = get_hai_def_info()
|
|
|
|
|
| st.markdown("---")
|
| col1, col2 = st.columns([1, 2])
|
| with col1:
|
| st.markdown("### 🖼️ MedSigLIP")
|
| st.info("google/medsiglip-448\n\nHAI-DEF Vision Model")
|
| with col2:
|
| info = hai_def["MedSigLIP"]
|
| st.markdown(f"**Model**: {info['name']}")
|
| st.markdown(f"**Use Case**: {info['use']}")
|
| st.markdown(f"**Method**: {info['method']}")
|
| st.markdown(f"**Validated Performance**: {info['accuracy']}")
|
| st.markdown("""
|
| MedSigLIP enables zero-shot medical image classification using
|
| text prompts. NEXUS extends this with trained SVM/LR classifiers
|
| on MedSigLIP embeddings (with data augmentation) for improved
|
| accuracy, plus a novel 3-layer MLP regression head for continuous
|
| bilirubin prediction from frozen embeddings.
|
| """)
|
|
|
|
|
| st.markdown("---")
|
| col1, col2 = st.columns([1, 2])
|
| with col1:
|
| st.markdown("### 🔊 HeAR")
|
| st.info("google/hear-pytorch\n\nHAI-DEF Audio Model")
|
| with col2:
|
| info = hai_def["HeAR"]
|
| st.markdown(f"**Model**: {info['name']}")
|
| st.markdown(f"**Use Case**: {info['use']}")
|
| st.markdown(f"**Method**: {info['method']}")
|
| st.markdown(f"**Validated Performance**: {info['accuracy']}")
|
| st.markdown("""
|
| HeAR (Health Acoustic Representations) produces 512-dim embeddings
|
| from 2-second audio clips at 16kHz. NEXUS trains a linear classifier
|
| on HeAR embeddings for 5-class cry type classification (hungry,
|
| belly_pain, burping, discomfort, tired) and derives asphyxia risk
|
| from distress patterns.
|
| """)
|
|
|
|
|
| st.markdown("---")
|
| col1, col2 = st.columns([1, 2])
|
| with col1:
|
| st.markdown("### 🧠 MedGemma")
|
| st.info("google/medgemma-1.5-4b-it\n\nHAI-DEF Language Model")
|
| with col2:
|
| info = hai_def["MedGemma"]
|
| st.markdown(f"**Model**: {info['name']}")
|
| st.markdown(f"**Use Case**: {info['use']}")
|
| st.markdown(f"**Method**: {info['method']}")
|
| st.markdown(f"**Validated Performance**: {info['accuracy']}")
|
| st.markdown("""
|
| MedGemma 1.5 provides clinical reasoning capabilities via 4-bit NF4
|
| quantized inference (~2 GB VRAM). It synthesizes multi-modal findings
|
| into actionable recommendations following WHO IMNCI protocols,
|
| producing structured reasoning chains within the 6-agent pipeline.
|
| """)
|
|
|
|
|
| st.markdown("---")
|
| st.subheader("🏆 MedGemma Impact Challenge 2026")
|
| st.markdown("""
|
| NEXUS is being developed for the [MedGemma Impact Challenge](https://www.kaggle.com/competitions/medgemma-impact-challenge-2026)
|
| on Kaggle.
|
|
|
| **Competition Focus**: Solutions for resource-limited healthcare settings using HAI-DEF models.
|
|
|
| **NEXUS Impact**:
|
| - 📍 Target: Sub-Saharan Africa and South Asia
|
| - 👩⚕️ Users: Community Health Workers
|
| - 🎯 Goals: Reduce maternal/neonatal mortality
|
| - 📱 Deployment: Offline-capable mobile app
|
| """)
|
|
|
|
|
| def render_agentic_workflow():
|
| """Render the agentic workflow interface with reasoning traces."""
|
| st.header("Agentic Clinical Workflow")
|
| st.markdown(
|
| f"**6-Agent Pipeline** with step-by-step reasoning traces. "
|
| f"Each agent explains its clinical decision process, providing a full audit trail. "
|
| f'{_model_badge("MedSigLIP", "#388e3c")} '
|
| f'{_model_badge("HeAR", "#f57c00")} '
|
| f'{_model_badge("MedGemma", "#1976d2")}',
|
| unsafe_allow_html=True,
|
| )
|
|
|
|
|
| st.markdown("""
|
| <div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; flex-wrap: wrap; margin: 1rem 0;">
|
| <div style="background: #e3f2fd; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #1976d2;">Triage</div>
|
| <span style="font-size: 1.5rem;">→</span>
|
| <div style="background: #e8f5e9; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #388e3c;">Image (MedSigLIP)</div>
|
| <span style="font-size: 1.5rem;">→</span>
|
| <div style="background: #fff3e0; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #f57c00;">Audio (HeAR)</div>
|
| <span style="font-size: 1.5rem;">→</span>
|
| <div style="background: #f3e5f5; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #7b1fa2;">Protocol (WHO)</div>
|
| <span style="font-size: 1.5rem;">→</span>
|
| <div style="background: #fce4ec; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #c62828;">Referral</div>
|
| <span style="font-size: 1.5rem;">→</span>
|
| <div style="background: #e0f7fa; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #00838f;">Synthesis (MedGemma)</div>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
| st.markdown("---")
|
|
|
|
|
| col_left, col_right = st.columns([1, 1])
|
|
|
| with col_left:
|
| st.subheader("Patient & Inputs")
|
| patient_type = st.selectbox("Patient Type", ["newborn", "pregnant"], key="agentic_patient")
|
|
|
|
|
| st.markdown("**Danger Signs**")
|
| danger_signs = []
|
| if patient_type == "pregnant":
|
| sign_options = [
|
| ("Severe headache", "high"),
|
| ("Blurred vision", "high"),
|
| ("Convulsions", "critical"),
|
| ("Severe abdominal pain", "high"),
|
| ("Vaginal bleeding", "critical"),
|
| ("High fever", "high"),
|
| ("Severe pallor", "medium"),
|
| ]
|
| else:
|
| sign_options = [
|
| ("Not breathing at birth", "critical"),
|
| ("Convulsions", "critical"),
|
| ("Severe chest indrawing", "high"),
|
| ("Not feeding", "high"),
|
| ("High fever (>38C)", "high"),
|
| ("Hypothermia (<35.5C)", "high"),
|
| ("Lethargy / unconscious", "critical"),
|
| ("Umbilical redness", "medium"),
|
| ]
|
|
|
| selected_signs = st.multiselect(
|
| "Select present danger signs",
|
| [s[0] for s in sign_options],
|
| key="agentic_signs"
|
| )
|
| for label, severity in sign_options:
|
| if label in selected_signs:
|
| danger_signs.append({
|
| "id": label.lower().replace(" ", "_"),
|
| "label": label,
|
| "severity": severity,
|
| "present": True,
|
| })
|
|
|
|
|
| st.markdown("**Clinical Images**")
|
| conjunctiva_file = st.file_uploader(
|
| "Conjunctiva image (anemia)", type=["jpg", "jpeg", "png"],
|
| key="agentic_conjunctiva"
|
| )
|
| skin_file = st.file_uploader(
|
| "Skin image (jaundice)", type=["jpg", "jpeg", "png"],
|
| key="agentic_skin"
|
| )
|
| cry_file = st.file_uploader(
|
| "Cry audio", type=["wav", "mp3", "ogg"],
|
| key="agentic_cry"
|
| )
|
|
|
| with col_right:
|
| st.subheader("Workflow Execution")
|
|
|
| if st.button("Run Agentic Assessment", type="primary", key="run_agentic"):
|
| with st.spinner("Running 6-agent workflow..."):
|
| try:
|
| from nexus.agentic_workflow import (
|
| AgenticWorkflowEngine,
|
| AgentPatientInfo,
|
| DangerSign,
|
| WorkflowInput,
|
| )
|
|
|
|
|
| _temp_paths = []
|
| conjunctiva_path = None
|
| skin_path = None
|
| cry_path = None
|
|
|
| if conjunctiva_file:
|
| conjunctiva_path = _save_upload_to_temp(conjunctiva_file, ".jpg")
|
| _temp_paths.append(conjunctiva_path)
|
|
|
| if skin_file:
|
| skin_path = _save_upload_to_temp(skin_file, ".jpg")
|
| _temp_paths.append(skin_path)
|
|
|
| if cry_file:
|
| cry_path = _save_upload_to_temp(cry_file, ".wav")
|
| _temp_paths.append(cry_path)
|
|
|
|
|
| signs = [
|
| DangerSign(
|
| id=s["id"], label=s["label"],
|
| severity=s["severity"], present=True,
|
| )
|
| for s in danger_signs
|
| ]
|
|
|
| info = AgentPatientInfo(patient_type=patient_type)
|
| workflow_input = WorkflowInput(
|
| patient_type=patient_type,
|
| patient_info=info,
|
| danger_signs=signs,
|
| conjunctiva_image=conjunctiva_path,
|
| skin_image=skin_path,
|
| cry_audio=cry_path,
|
| )
|
|
|
|
|
| anemia_det, _ = load_anemia_detector()
|
| jaundice_det, _ = load_jaundice_detector()
|
| cry_ana, _ = load_cry_analyzer()
|
| synth, _ = load_clinical_synthesizer()
|
|
|
| engine = AgenticWorkflowEngine(
|
| anemia_detector=anemia_det,
|
| jaundice_detector=jaundice_det,
|
| cry_analyzer=cry_ana,
|
| synthesizer=synth,
|
| )
|
| result = engine.execute(workflow_input)
|
|
|
| st.session_state["agentic_result"] = result
|
| st.success("Workflow complete!")
|
|
|
| except Exception as e:
|
| st.error(f"Workflow error: {e}")
|
| finally:
|
| for p in _temp_paths:
|
| _cleanup_temp(p)
|
|
|
|
|
| if "agentic_result" in st.session_state:
|
| result = st.session_state["agentic_result"]
|
|
|
| st.markdown("---")
|
|
|
|
|
| severity_colors = {
|
| "GREEN": ("#d4edda", "#155724", "Routine care"),
|
| "YELLOW": ("#fff3cd", "#856404", "Close monitoring"),
|
| "RED": ("#f8d7da", "#721c24", "Urgent referral"),
|
| }
|
| bg, fg, desc = severity_colors.get(result.who_classification, ("#f8f9fa", "#000", "Unknown"))
|
|
|
| st.markdown(f"""
|
| <div style="background: {bg}; color: {fg}; padding: 1.5rem; border-radius: 10px; text-align: center; margin: 1rem 0;">
|
| <h2 style="margin: 0;">WHO Classification: {result.who_classification}</h2>
|
| <p style="margin: 0.5rem 0 0 0; font-size: 1.1rem;">{desc}</p>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| m1, m2, m3, m4 = st.columns(4)
|
| with m1:
|
| st.metric("Agents Run", len(result.agent_traces))
|
| with m2:
|
| st.metric("Total Time", f"{result.processing_time_ms:.0f} ms")
|
| with m3:
|
| referral_text = "Yes" if (result.referral_result and result.referral_result.referral_needed) else "No"
|
| st.metric("Referral Needed", referral_text)
|
| with m4:
|
| triage_score = result.triage_result.score if result.triage_result else 0
|
| st.metric("Triage Score", triage_score)
|
|
|
|
|
| st.subheader("Clinical Synthesis")
|
| st.info(result.clinical_synthesis)
|
|
|
| if result.immediate_actions:
|
| st.subheader("Immediate Actions")
|
| for action in result.immediate_actions:
|
| st.markdown(f"- {action}")
|
|
|
|
|
| st.markdown("---")
|
| st.subheader("Agent Pipeline Execution")
|
|
|
| agent_meta = {
|
| "TriageAgent": {"color": "#1976d2", "bg": "#e3f2fd", "icon": "1", "label": "Triage"},
|
| "ImageAnalysisAgent": {"color": "#388e3c", "bg": "#e8f5e9", "icon": "2", "label": "Image (MedSigLIP)"},
|
| "AudioAnalysisAgent": {"color": "#f57c00", "bg": "#fff3e0", "icon": "3", "label": "Audio (HeAR)"},
|
| "ProtocolAgent": {"color": "#7b1fa2", "bg": "#f3e5f5", "icon": "4", "label": "WHO Protocol"},
|
| "ReferralAgent": {"color": "#c62828", "bg": "#fce4ec", "icon": "5", "label": "Referral"},
|
| "SynthesisAgent": {"color": "#00838f", "bg": "#e0f7fa", "icon": "6", "label": "Synthesis (MedGemma)"},
|
| }
|
| status_symbols = {"success": "OK", "skipped": "SKIP", "error": "ERR"}
|
|
|
|
|
| trace_lookup = {t.agent_name: t for t in result.agent_traces}
|
|
|
|
|
| pipeline_html_parts = []
|
| for agent_name, meta in agent_meta.items():
|
| trace = trace_lookup.get(agent_name)
|
| if trace:
|
| status_sym = status_symbols.get(trace.status, "?")
|
| opacity = "1.0" if trace.status == "success" else "0.5"
|
| border_style = f"3px solid {meta['color']}" if trace.status == "success" else "2px dashed #999"
|
| time_label = f"{trace.processing_time_ms:.0f}ms"
|
| else:
|
| status_sym = "---"
|
| opacity = "0.3"
|
| border_style = "2px dashed #ccc"
|
| time_label = ""
|
|
|
| pipeline_html_parts.append(f"""
|
| <div style="background: {meta['bg']}; padding: 0.4rem 0.7rem; border-radius: 8px;
|
| border: {border_style}; opacity: {opacity}; text-align: center; min-width: 90px;">
|
| <div style="font-weight: bold; font-size: 0.8rem; color: {meta['color']};">{meta['label']}</div>
|
| <div style="font-size: 0.7rem; color: #666;">{status_sym} {time_label}</div>
|
| </div>
|
| """)
|
|
|
| pipeline_html = '<div style="display: flex; align-items: center; justify-content: center; gap: 0.3rem; flex-wrap: wrap; margin: 0.5rem 0;">'
|
| for i, part in enumerate(pipeline_html_parts):
|
| pipeline_html += part
|
| if i < len(pipeline_html_parts) - 1:
|
| pipeline_html += '<span style="font-size: 1.2rem; color: #999;">→</span>'
|
| pipeline_html += "</div>"
|
| st.markdown(pipeline_html, unsafe_allow_html=True)
|
|
|
|
|
| st.markdown("---")
|
| st.subheader("Agent Reasoning Traces")
|
|
|
| for trace in result.agent_traces:
|
| meta = agent_meta.get(trace.agent_name, {"color": "#666", "bg": "#f5f5f5", "label": trace.agent_name})
|
| status_emoji = {"success": "OK", "skipped": "SKIP", "error": "ERR"}.get(trace.status, "?")
|
|
|
| header_label = f"{meta['label']} [{status_emoji}] - {trace.confidence:.0%} confidence - {trace.processing_time_ms:.0f}ms"
|
| with st.expander(header_label, expanded=(trace.status == "success")):
|
|
|
| st.markdown(f"""
|
| <div style="background: {meta['bg']}; padding: 0.8rem 1rem; border-radius: 8px;
|
| border-left: 4px solid {meta['color']}; margin-bottom: 0.5rem;">
|
| <strong style="color: {meta['color']};">{trace.agent_name}</strong> |
|
| Status: <strong>{trace.status}</strong> |
|
| Confidence: <strong>{trace.confidence:.1%}</strong> |
|
| Time: <strong>{trace.processing_time_ms:.1f}ms</strong>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| if trace.reasoning:
|
| st.markdown("**Reasoning Chain:**")
|
| for i, step in enumerate(trace.reasoning, 1):
|
| st.markdown(f"**Step {i}.** {step}")
|
|
|
|
|
| if trace.findings:
|
| st.markdown("**Key Findings:**")
|
| st.json(trace.findings)
|
|
|
|
|
| st.markdown("---")
|
| col_chart, col_summary = st.columns([2, 1])
|
|
|
| with col_chart:
|
| st.subheader("Processing Time by Agent")
|
| import pandas as pd
|
| chart_data = pd.DataFrame({
|
| "Agent": [agent_meta.get(t.agent_name, {}).get("label", t.agent_name) for t in result.agent_traces],
|
| "Time (ms)": [t.processing_time_ms for t in result.agent_traces],
|
| })
|
| st.bar_chart(chart_data.set_index("Agent"))
|
|
|
| with col_summary:
|
| st.subheader("Workflow Summary")
|
| total_time = result.processing_time_ms
|
| successful = sum(1 for t in result.agent_traces if t.status == "success")
|
| skipped = sum(1 for t in result.agent_traces if t.status == "skipped")
|
| errors = sum(1 for t in result.agent_traces if t.status == "error")
|
| st.markdown(f"""
|
| | Metric | Value |
|
| |--------|-------|
|
| | Total agents | {len(result.agent_traces)} |
|
| | Successful | {successful} |
|
| | Skipped | {skipped} |
|
| | Errors | {errors} |
|
| | Total time | {total_time:.0f} ms |
|
| | Avg per agent | {total_time / max(len(result.agent_traces), 1):.0f} ms |
|
| """)
|
|
|
|
|
| if result.referral_result and result.referral_result.referral_needed:
|
| st.markdown("---")
|
| st.subheader("Referral Details")
|
| ref = result.referral_result
|
| r1, r2, r3 = st.columns(3)
|
| with r1:
|
| st.metric("Urgency", ref.urgency.upper())
|
| with r2:
|
| st.metric("Facility", ref.facility_level.title())
|
| with r3:
|
| st.metric("Timeframe", ref.timeframe)
|
| st.warning(f"Reason: {ref.reason}")
|
|
|
|
|
|
|
| def render_footer():
|
| """Render footer."""
|
| st.markdown("---")
|
| st.markdown("""
|
| <div style="text-align: center; color: #666; font-size: 0.9rem;">
|
| <p>NEXUS - Built with Google HAI-DEF for MedGemma Impact Challenge 2026</p>
|
| <p>⚠️ This is a screening tool only. Always confirm with laboratory tests.</p>
|
| </div>
|
| """, unsafe_allow_html=True)
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
| render_footer()
|
|
|