Md Shahabul Alam commited on
Commit ·
29db30b
1
Parent(s): 865fb68
Deploy NEXUS Streamlit demo to HuggingFace Spaces
Browse files- Dockerfile +43 -0
- README.md +189 -7
- app.py +57 -0
- models/linear_probes/anemia_classifier_metadata.json +30 -0
- models/linear_probes/bilirubin_regression_results.json +207 -0
- models/linear_probes/cry_classifier_metadata.json +31 -0
- models/linear_probes/jaundice_classifier_metadata.json +31 -0
- models/linear_probes/linear_probe_results.json +18 -0
- requirements_spaces.txt +31 -0
- src/demo/__init__.py +0 -0
- src/demo/streamlit_app.py +1189 -0
- src/nexus/__init__.py +10 -0
- src/nexus/agentic_workflow.py +1296 -0
- src/nexus/anemia_detector.py +580 -0
- src/nexus/clinical_synthesizer.py +548 -0
- src/nexus/cry_analyzer.py +662 -0
- src/nexus/hear_preprocessing.py +320 -0
- src/nexus/jaundice_detector.py +716 -0
- src/nexus/pipeline.py +663 -0
Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# HuggingFace Spaces Docker SDK — NEXUS Streamlit Demo
# Docs: https://huggingface.co/docs/hub/spaces-sdks-docker

FROM python:3.12-slim

# Create non-root user (required by HF Spaces)
RUN useradd -m -u 1000 user
ENV PATH="/home/user/.local/bin:$PATH"

# Install system dependencies for audio processing
RUN apt-get update && apt-get install -y --no-install-recommends \
    libsndfile1 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Switch to non-root user BEFORE pip install so packages land in
# /home/user/.local (which is already on PATH above). Previously the
# install ran as root into system site-packages, making the PATH entry dead.
USER user

# Copy requirements and install as user
COPY --chown=user ./requirements_spaces.txt requirements_spaces.txt
RUN pip install --no-cache-dir --upgrade -r requirements_spaces.txt

# Copy source code
COPY --chown=user ./src/ src/
COPY --chown=user ./models/ models/
COPY --chown=user ./app.py .

# Set environment
ENV PYTHONPATH=/app/src
ENV STREAMLIT_SERVER_PORT=7860
ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
ENV STREAMLIT_SERVER_HEADLESS=true
ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false

EXPOSE 7860

CMD ["python", "-m", "streamlit", "run", "src/demo/streamlit_app.py", \
     "--server.port=7860", \
     "--server.address=0.0.0.0", \
     "--server.headless=true", \
     "--browser.gatherUsageStats=false"]
README.md
CHANGED
|
@@ -1,12 +1,194 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: NEXUS
|
| 3 |
+
emoji: "\U0001FA7A"
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: true
|
| 9 |
+
license: cc-by-4.0
|
| 10 |
+
tags:
|
| 11 |
+
- medgemma
|
| 12 |
+
- medical-ai
|
| 13 |
+
- hai-def
|
| 14 |
+
- maternal-health
|
| 15 |
+
- neonatal-care
|
| 16 |
---
|
| 17 |
|
| 18 |
+
# NEXUS - AI-Powered Maternal-Neonatal Assessment Platform
|
| 19 |
+
|
| 20 |
+
> Non-invasive screening for maternal anemia, neonatal jaundice, and birth asphyxia using Google HAI-DEF models
|
| 21 |
+
|
| 22 |
+
[](https://creativecommons.org/licenses/by/4.0/)
|
| 23 |
+
[](https://www.kaggle.com/competitions/med-gemma-impact-challenge)
|
| 24 |
+
|
| 25 |
+
## Overview
|
| 26 |
+
|
| 27 |
+
NEXUS transforms smartphones into diagnostic screening tools for Community Health Workers in low-resource settings. Using 3 Google HAI-DEF models in a 6-agent clinical workflow, it provides non-invasive assessment for:
|
| 28 |
+
|
| 29 |
+
- **Maternal anemia** from conjunctiva images (MedSigLIP)
|
| 30 |
+
- **Neonatal jaundice** from skin images with bilirubin regression (MedSigLIP)
|
| 31 |
+
- **Birth asphyxia** from cry audio analysis (HeAR)
|
| 32 |
+
- **Clinical synthesis** with WHO IMNCI protocol alignment (MedGemma)
|
| 33 |
+
|
| 34 |
+
## HAI-DEF Models
|
| 35 |
+
|
| 36 |
+
| Model | HuggingFace ID | Purpose |
|
| 37 |
+
|-------|----------------|---------|
|
| 38 |
+
| **MedSigLIP** | `google/medsiglip-448` | Anemia + jaundice detection, bilirubin regression |
|
| 39 |
+
| **HeAR** | `google/hear-pytorch` | Cry audio analysis for birth asphyxia |
|
| 40 |
+
| **MedGemma 4B** | `google/medgemma-4b-it` | Clinical reasoning and synthesis |
|
| 41 |
+
|
| 42 |
+
## Architecture
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
6-Agent Clinical Workflow:
|
| 46 |
+
Triage -> Image Analysis (MedSigLIP) -> Audio Analysis (HeAR)
|
| 47 |
+
-> WHO Protocol -> Referral Decision -> Clinical Synthesis (MedGemma)
|
| 48 |
+
|
| 49 |
+
Each agent produces structured reasoning traces for a full audit trail.
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Quick Start
|
| 53 |
+
|
| 54 |
+
### Prerequisites
|
| 55 |
+
- Python 3.10+
|
| 56 |
+
- HuggingFace token (for gated HAI-DEF models)
|
| 57 |
+
|
| 58 |
+
### Setup
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
# Clone and install
|
| 62 |
+
git clone <repo-url>
|
| 63 |
+
cd nexus
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
|
| 66 |
+
# Set HuggingFace token (required for MedSigLIP, MedGemma)
|
| 67 |
+
export HF_TOKEN=hf_your_token_here
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### Run the Demo
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
# Streamlit interactive demo
|
| 74 |
+
PYTHONPATH=src streamlit run src/demo/streamlit_app.py
|
| 75 |
+
|
| 76 |
+
# FastAPI backend
|
| 77 |
+
PYTHONPATH=src uvicorn api.main:app --reload
|
| 78 |
+
|
| 79 |
+
# Run tests
|
| 80 |
+
PYTHONPATH=src python -m pytest tests/ -v
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Train Models
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Train linear probes (anemia + jaundice classifiers)
|
| 87 |
+
PYTHONPATH=src python scripts/training/train_linear_probes.py
|
| 88 |
+
|
| 89 |
+
# Train bilirubin regression head
|
| 90 |
+
PYTHONPATH=src python scripts/training/finetune_bilirubin_regression.py
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
### HuggingFace Spaces
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
# Local test of HF Spaces entry point
|
| 97 |
+
python app.py
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Project Structure
|
| 101 |
+
|
| 102 |
+
```
|
| 103 |
+
nexus/
|
| 104 |
+
├── src/nexus/ # Core platform
|
| 105 |
+
│ ├── anemia_detector.py # MedSigLIP anemia detection
|
| 106 |
+
│ ├── jaundice_detector.py # MedSigLIP jaundice + bilirubin regression
|
| 107 |
+
│ ├── cry_analyzer.py # HeAR cry analysis
|
| 108 |
+
│ ├── clinical_synthesizer.py # MedGemma clinical synthesis
|
| 109 |
+
│ ├── agentic_workflow.py # 6-agent workflow engine
|
| 110 |
+
│ └── pipeline.py # Unified assessment pipeline
|
| 111 |
+
├── src/demo/streamlit_app.py # Interactive Streamlit demo
|
| 112 |
+
├── api/main.py # FastAPI backend
|
| 113 |
+
├── scripts/
|
| 114 |
+
│ ├── training/
|
| 115 |
+
│ │ ├── train_linear_probes.py # MedSigLIP embedding classifiers
|
| 116 |
+
│ │ ├── finetune_bilirubin_regression.py # Novel bilirubin regression
|
| 117 |
+
│ │ ├── train_anemia.py # Anemia-specific training
|
| 118 |
+
│ │ ├── train_jaundice.py # Jaundice-specific training
|
| 119 |
+
│ │ └── train_cry.py # Cry classifier training
|
| 120 |
+
│ └── edge/
|
| 121 |
+
│ ├── quantize_models.py # INT8 quantization
|
| 122 |
+
│ └── export_embeddings.py # Pre-computed text embeddings
|
| 123 |
+
├── notebooks/
|
| 124 |
+
│ ├── 01_anemia_detection.ipynb
|
| 125 |
+
│ ├── 02_jaundice_detection.ipynb
|
| 126 |
+
│ ├── 03_cry_analysis.ipynb
|
| 127 |
+
│ └── 04_bilirubin_regression.ipynb # Novel task reproducibility
|
| 128 |
+
├── tests/
|
| 129 |
+
│ ├── test_pipeline.py # Pipeline tests
|
| 130 |
+
│ ├── test_agentic_workflow.py # Agentic workflow tests (41 tests)
|
| 131 |
+
│ └── test_hai_def_integration.py # HAI-DEF model compliance
|
| 132 |
+
├── models/
|
| 133 |
+
│ ├── linear_probes/ # Trained classifiers + regressor
|
| 134 |
+
│ └── edge/ # Quantized models + embeddings
|
| 135 |
+
├── data/
|
| 136 |
+
│ ├── raw/ # Raw datasets (Eyes-Defy-Anemia, NeoJaundice, CryCeleb)
|
| 137 |
+
│ └── protocols/ # WHO IMNCI protocols
|
| 138 |
+
├── submission/
|
| 139 |
+
│ ├── writeup.md # Competition writeup (3 pages)
|
| 140 |
+
│ └── video/ # Demo video script and assets
|
| 141 |
+
├── app.py # HuggingFace Spaces entry point
|
| 142 |
+
├── requirements.txt # Full dependencies
|
| 143 |
+
└── requirements_spaces.txt # HF Spaces minimal dependencies
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## Key Results
|
| 147 |
+
|
| 148 |
+
| Task | Method | Performance |
|
| 149 |
+
|------|--------|-------------|
|
| 150 |
+
| Anemia zero-shot | MedSigLIP (max-similarity, 8 prompts/class) | Screening capability |
|
| 151 |
+
| Jaundice classification | MedSigLIP linear probe | 68.9% accuracy |
|
| 152 |
+
| **Bilirubin regression** | **MedSigLIP + MLP head** | **MAE: 2.564 mg/dL, r=0.78** |
|
| 153 |
+
| Cry analysis | HeAR + acoustic features | Qualitative assessment |
|
| 154 |
+
| Clinical synthesis | MedGemma + WHO IMNCI | Protocol-aligned recommendations |
|
| 155 |
+
|
| 156 |
+
### Novel Task: Bilirubin Regression
|
| 157 |
+
Frozen MedSigLIP embeddings -> 2-layer MLP -> continuous bilirubin (mg/dL) prediction.
|
| 158 |
+
Trained on 2,235 NeoJaundice images with ground truth serum bilirubin.
|
| 159 |
+
**MAE: 2.564 mg/dL, Pearson r: 0.7783 (p < 1e-68)**
|
| 160 |
+
|
| 161 |
+
### Edge AI
|
| 162 |
+
- INT8 dynamic quantization: 812.6 MB -> 111.2 MB (7.31x compression)
|
| 163 |
+
- Pre-computed text embeddings: 12 KB (no text encoder on device)
|
| 164 |
+
- Total on-device: ~289 MB
|
| 165 |
+
|
| 166 |
+
## Competition Tracks
|
| 167 |
+
|
| 168 |
+
- **Main Track**: Comprehensive maternal-neonatal assessment platform
|
| 169 |
+
- **Agentic Workflow Prize**: 6-agent pipeline with reasoning traces and audit trail
|
| 170 |
+
|
| 171 |
+
## Tests
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
# All tests
|
| 175 |
+
PYTHONPATH=src python -m pytest tests/ -v
|
| 176 |
+
|
| 177 |
+
# Agentic workflow only (41 tests)
|
| 178 |
+
PYTHONPATH=src python -m pytest tests/test_agentic_workflow.py -v
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
## License
|
| 182 |
+
|
| 183 |
+
[CC BY 4.0](LICENSE)
|
| 184 |
+
|
| 185 |
+
## Acknowledgments
|
| 186 |
+
|
| 187 |
+
- Google Health AI Developer Foundations team
|
| 188 |
+
- NeoJaundice dataset (Figshare)
|
| 189 |
+
- Eyes-Defy-Anemia dataset (Kaggle)
|
| 190 |
+
- WHO IMNCI protocol guidelines
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
Built with Google HAI-DEF for the MedGemma Impact Challenge 2026
|
app.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
NEXUS - HuggingFace Spaces Entry Point

Launches the Streamlit demo for the NEXUS Maternal-Neonatal Care Platform.
Built with Google HAI-DEF models for the MedGemma Impact Challenge 2026.
"""

import os
import subprocess
import sys
from pathlib import Path

# Make src/ importable: the demo app and the nexus package live under src/.
ROOT = Path(__file__).parent
SRC_DIR = ROOT / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

# Streamlit defaults suitable for HF Spaces; values already present in the
# environment take precedence (setdefault never overwrites).
_STREAMLIT_DEFAULTS = {
    "STREAMLIT_SERVER_PORT": "7860",
    "STREAMLIT_SERVER_ADDRESS": "0.0.0.0",
    "STREAMLIT_SERVER_HEADLESS": "true",
    "STREAMLIT_BROWSER_GATHER_USAGE_STATS": "false",
}
for _name, _default in _STREAMLIT_DEFAULTS.items():
    os.environ.setdefault(_name, _default)
def main():
    """Launch the Streamlit demo as a child process.

    Resolves the app path, honours a platform-assigned ``PORT`` (HF Spaces
    provides one) over the module default, and runs Streamlit with the
    configured server options. Exits non-zero on any failure so the Space
    is reported as unhealthy.
    """
    app_path = SRC_DIR / "demo" / "streamlit_app.py"
    if not app_path.exists():
        # Diagnostics belong on stderr, not stdout.
        print(f"ERROR: Streamlit app not found at {app_path}", file=sys.stderr)
        sys.exit(1)

    # A platform-provided PORT takes precedence over our configured default.
    port = os.environ.get("PORT", os.environ["STREAMLIT_SERVER_PORT"])
    os.environ["STREAMLIT_SERVER_PORT"] = str(port)

    # Prepend src/ to PYTHONPATH instead of clobbering an existing value,
    # so the child process keeps any paths the caller already configured.
    existing = os.environ.get("PYTHONPATH", "")
    pythonpath = os.pathsep.join(p for p in (str(SRC_DIR), existing) if p)

    try:
        subprocess.run(
            [
                sys.executable, "-m", "streamlit", "run",
                str(app_path),
                f"--server.port={port}",
                f"--server.address={os.environ['STREAMLIT_SERVER_ADDRESS']}",
                f"--server.headless={os.environ['STREAMLIT_SERVER_HEADLESS']}",
                f"--browser.gatherUsageStats={os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS']}",
            ],
            check=True,
            env={**os.environ, "PYTHONPATH": pythonpath},
        )
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Streamlit process exited with code {e.returncode}", file=sys.stderr)
        sys.exit(e.returncode)
    except FileNotFoundError:
        print("ERROR: Streamlit not installed. Run: pip install streamlit", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
models/linear_probes/anemia_classifier_metadata.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "SVM_RBF",
|
| 3 |
+
"embedding_source": "MedSigLIP (google/medsiglip-448)",
|
| 4 |
+
"embedding_dim": 1152,
|
| 5 |
+
"num_classes": 2,
|
| 6 |
+
"classes": {
|
| 7 |
+
"healthy": 0,
|
| 8 |
+
"anemic": 1
|
| 9 |
+
},
|
| 10 |
+
"cv_accuracy_mean": 0.9994269340974211,
|
| 11 |
+
"cv_accuracy_std": 0.0011461318051575909,
|
| 12 |
+
"num_original_samples": 218,
|
| 13 |
+
"num_augmented_samples": 1744,
|
| 14 |
+
"augmentations_per_image": 7,
|
| 15 |
+
"all_results": {
|
| 16 |
+
"LogisticRegression": {
|
| 17 |
+
"mean_accuracy": 0.8985096993050752,
|
| 18 |
+
"std_accuracy": 0.008415256920621202
|
| 19 |
+
},
|
| 20 |
+
"SVM_RBF": {
|
| 21 |
+
"mean_accuracy": 0.9994269340974211,
|
| 22 |
+
"std_accuracy": 0.0011461318051575909
|
| 23 |
+
},
|
| 24 |
+
"SVM_Linear": {
|
| 25 |
+
"mean_accuracy": 0.8899186509896915,
|
| 26 |
+
"std_accuracy": 0.011746435929843532
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"seed": 42
|
| 30 |
+
}
|
models/linear_probes/bilirubin_regression_results.json
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mae": 2.564,
|
| 3 |
+
"rmse": 3.416,
|
| 4 |
+
"pearson_r": 0.7783,
|
| 5 |
+
"pearson_p": 1.7171921789198235e-69,
|
| 6 |
+
"bland_altman": {
|
| 7 |
+
"mean_diff": -0.506,
|
| 8 |
+
"std_diff": 3.379,
|
| 9 |
+
"loa_upper": 6.116,
|
| 10 |
+
"loa_lower": -7.129
|
| 11 |
+
},
|
| 12 |
+
"test_size": 336,
|
| 13 |
+
"train_size": 1563,
|
| 14 |
+
"val_size": 336,
|
| 15 |
+
"input_dim": 1152,
|
| 16 |
+
"hidden_dim": 256,
|
| 17 |
+
"epochs_trained": 58,
|
| 18 |
+
"best_val_loss": 3.7143,
|
| 19 |
+
"bilirubin_range": {
|
| 20 |
+
"min": 0.0,
|
| 21 |
+
"max": 25.7,
|
| 22 |
+
"mean": 11.2,
|
| 23 |
+
"std": 5.2
|
| 24 |
+
},
|
| 25 |
+
"history": {
|
| 26 |
+
"train_loss": [
|
| 27 |
+
17.543,
|
| 28 |
+
12.0442,
|
| 29 |
+
6.9412,
|
| 30 |
+
4.2175,
|
| 31 |
+
3.5428,
|
| 32 |
+
3.4781,
|
| 33 |
+
3.0782,
|
| 34 |
+
2.8347,
|
| 35 |
+
2.7914,
|
| 36 |
+
2.5293,
|
| 37 |
+
2.393,
|
| 38 |
+
2.2627,
|
| 39 |
+
2.1357,
|
| 40 |
+
2.1498,
|
| 41 |
+
1.875,
|
| 42 |
+
2.0569,
|
| 43 |
+
1.843,
|
| 44 |
+
1.7077,
|
| 45 |
+
1.7084,
|
| 46 |
+
1.6893,
|
| 47 |
+
1.7543,
|
| 48 |
+
2.0793,
|
| 49 |
+
2.1218,
|
| 50 |
+
2.1285,
|
| 51 |
+
2.0992,
|
| 52 |
+
1.9611,
|
| 53 |
+
1.93,
|
| 54 |
+
1.8854,
|
| 55 |
+
1.9694,
|
| 56 |
+
1.6901,
|
| 57 |
+
1.699,
|
| 58 |
+
1.7061,
|
| 59 |
+
1.5767,
|
| 60 |
+
1.6265,
|
| 61 |
+
1.5394,
|
| 62 |
+
1.4675,
|
| 63 |
+
1.3684,
|
| 64 |
+
1.4486,
|
| 65 |
+
1.2866,
|
| 66 |
+
1.3152,
|
| 67 |
+
1.2613,
|
| 68 |
+
1.1721,
|
| 69 |
+
1.1946,
|
| 70 |
+
1.2039,
|
| 71 |
+
1.1949,
|
| 72 |
+
1.129,
|
| 73 |
+
1.0557,
|
| 74 |
+
1.0699,
|
| 75 |
+
1.0325,
|
| 76 |
+
1.0427,
|
| 77 |
+
1.0431,
|
| 78 |
+
1.0722,
|
| 79 |
+
1.0071,
|
| 80 |
+
1.0187,
|
| 81 |
+
0.8847,
|
| 82 |
+
0.9988,
|
| 83 |
+
0.942,
|
| 84 |
+
0.9464
|
| 85 |
+
],
|
| 86 |
+
"val_loss": [
|
| 87 |
+
18.4316,
|
| 88 |
+
13.9118,
|
| 89 |
+
6.9486,
|
| 90 |
+
4.5588,
|
| 91 |
+
5.5443,
|
| 92 |
+
4.184,
|
| 93 |
+
4.8748,
|
| 94 |
+
4.0967,
|
| 95 |
+
4.0286,
|
| 96 |
+
4.1705,
|
| 97 |
+
4.0592,
|
| 98 |
+
3.921,
|
| 99 |
+
4.1161,
|
| 100 |
+
4.0279,
|
| 101 |
+
3.9931,
|
| 102 |
+
3.8783,
|
| 103 |
+
3.8742,
|
| 104 |
+
3.8394,
|
| 105 |
+
3.949,
|
| 106 |
+
3.8805,
|
| 107 |
+
3.8673,
|
| 108 |
+
3.9437,
|
| 109 |
+
4.1339,
|
| 110 |
+
4.3688,
|
| 111 |
+
4.5384,
|
| 112 |
+
4.0601,
|
| 113 |
+
3.9022,
|
| 114 |
+
3.7252,
|
| 115 |
+
3.9551,
|
| 116 |
+
3.9791,
|
| 117 |
+
3.7946,
|
| 118 |
+
4.0627,
|
| 119 |
+
3.815,
|
| 120 |
+
4.0698,
|
| 121 |
+
4.0345,
|
| 122 |
+
3.9504,
|
| 123 |
+
3.8177,
|
| 124 |
+
3.8626,
|
| 125 |
+
3.8044,
|
| 126 |
+
3.7743,
|
| 127 |
+
3.8432,
|
| 128 |
+
3.8456,
|
| 129 |
+
3.7143,
|
| 130 |
+
3.8196,
|
| 131 |
+
3.8955,
|
| 132 |
+
3.7218,
|
| 133 |
+
3.7605,
|
| 134 |
+
3.7768,
|
| 135 |
+
3.7581,
|
| 136 |
+
3.7667,
|
| 137 |
+
3.7499,
|
| 138 |
+
3.7481,
|
| 139 |
+
3.7286,
|
| 140 |
+
3.7502,
|
| 141 |
+
3.7814,
|
| 142 |
+
3.734,
|
| 143 |
+
3.7887,
|
| 144 |
+
3.7414
|
| 145 |
+
],
|
| 146 |
+
"val_mae": [
|
| 147 |
+
10.19,
|
| 148 |
+
7.908,
|
| 149 |
+
4.388,
|
| 150 |
+
3.118,
|
| 151 |
+
3.652,
|
| 152 |
+
2.965,
|
| 153 |
+
3.299,
|
| 154 |
+
2.901,
|
| 155 |
+
2.884,
|
| 156 |
+
2.947,
|
| 157 |
+
2.876,
|
| 158 |
+
2.814,
|
| 159 |
+
2.93,
|
| 160 |
+
2.866,
|
| 161 |
+
2.854,
|
| 162 |
+
2.792,
|
| 163 |
+
2.794,
|
| 164 |
+
2.77,
|
| 165 |
+
2.836,
|
| 166 |
+
2.798,
|
| 167 |
+
2.787,
|
| 168 |
+
2.814,
|
| 169 |
+
2.931,
|
| 170 |
+
3.052,
|
| 171 |
+
3.148,
|
| 172 |
+
2.89,
|
| 173 |
+
2.803,
|
| 174 |
+
2.691,
|
| 175 |
+
2.83,
|
| 176 |
+
2.837,
|
| 177 |
+
2.737,
|
| 178 |
+
2.884,
|
| 179 |
+
2.749,
|
| 180 |
+
2.901,
|
| 181 |
+
2.874,
|
| 182 |
+
2.829,
|
| 183 |
+
2.761,
|
| 184 |
+
2.778,
|
| 185 |
+
2.734,
|
| 186 |
+
2.721,
|
| 187 |
+
2.761,
|
| 188 |
+
2.774,
|
| 189 |
+
2.692,
|
| 190 |
+
2.749,
|
| 191 |
+
2.803,
|
| 192 |
+
2.699,
|
| 193 |
+
2.714,
|
| 194 |
+
2.719,
|
| 195 |
+
2.704,
|
| 196 |
+
2.716,
|
| 197 |
+
2.717,
|
| 198 |
+
2.704,
|
| 199 |
+
2.699,
|
| 200 |
+
2.711,
|
| 201 |
+
2.731,
|
| 202 |
+
2.701,
|
| 203 |
+
2.736,
|
| 204 |
+
2.706
|
| 205 |
+
]
|
| 206 |
+
}
|
| 207 |
+
}
|
models/linear_probes/cry_classifier_metadata.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "SVM_RBF",
|
| 3 |
+
"embedding_source": "HeAR (google/hear-pytorch)",
|
| 4 |
+
"embedding_dim": 512,
|
| 5 |
+
"num_classes": 5,
|
| 6 |
+
"classes": {
|
| 7 |
+
"belly_pain": 0,
|
| 8 |
+
"burping": 1,
|
| 9 |
+
"discomfort": 2,
|
| 10 |
+
"hungry": 3,
|
| 11 |
+
"tired": 4
|
| 12 |
+
},
|
| 13 |
+
"cv_accuracy_mean": 0.8380793119923554,
|
| 14 |
+
"cv_accuracy_std": 0.008077431438521396,
|
| 15 |
+
"num_samples": 457,
|
| 16 |
+
"all_results": {
|
| 17 |
+
"LogisticRegression": {
|
| 18 |
+
"mean_accuracy": 0.7985905398948876,
|
| 19 |
+
"std_accuracy": 0.028055714127978745
|
| 20 |
+
},
|
| 21 |
+
"SVM_RBF": {
|
| 22 |
+
"mean_accuracy": 0.8380793119923554,
|
| 23 |
+
"std_accuracy": 0.008077431438521396
|
| 24 |
+
},
|
| 25 |
+
"SVM_Linear": {
|
| 26 |
+
"mean_accuracy": 0.765862398471094,
|
| 27 |
+
"std_accuracy": 0.013071624843302853
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"seed": 42
|
| 31 |
+
}
|
models/linear_probes/jaundice_classifier_metadata.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "SVM_RBF",
|
| 3 |
+
"embedding_source": "MedSigLIP (google/medsiglip-448)",
|
| 4 |
+
"embedding_dim": 1152,
|
| 5 |
+
"num_classes": 2,
|
| 6 |
+
"classes": {
|
| 7 |
+
"normal": 0,
|
| 8 |
+
"jaundice": 1
|
| 9 |
+
},
|
| 10 |
+
"bilirubin_threshold": 5.0,
|
| 11 |
+
"cv_accuracy_mean": 0.967337807606264,
|
| 12 |
+
"cv_accuracy_std": 0.002197637886396911,
|
| 13 |
+
"num_original_samples": 2235,
|
| 14 |
+
"num_augmented_samples": 8940,
|
| 15 |
+
"augmentations_per_image": 3,
|
| 16 |
+
"all_results": {
|
| 17 |
+
"LogisticRegression": {
|
| 18 |
+
"mean_accuracy": 0.9422818791946309,
|
| 19 |
+
"std_accuracy": 0.004750953150245027
|
| 20 |
+
},
|
| 21 |
+
"SVM_RBF": {
|
| 22 |
+
"mean_accuracy": 0.967337807606264,
|
| 23 |
+
"std_accuracy": 0.002197637886396911
|
| 24 |
+
},
|
| 25 |
+
"SVM_Linear": {
|
| 26 |
+
"mean_accuracy": 0.9322147651006712,
|
| 27 |
+
"std_accuracy": 0.006743027683714353
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"seed": 42
|
| 31 |
+
}
|
models/linear_probes/linear_probe_results.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"anemia": {
|
| 3 |
+
"accuracy": 0.5227272727272727,
|
| 4 |
+
"precision": 0.5185185185185185,
|
| 5 |
+
"recall": 0.6363636363636364,
|
| 6 |
+
"f1": 0.5714285714285714,
|
| 7 |
+
"train_size": 174,
|
| 8 |
+
"test_size": 44
|
| 9 |
+
},
|
| 10 |
+
"jaundice": {
|
| 11 |
+
"accuracy": 0.6957494407158836,
|
| 12 |
+
"precision": 0.6854460093896714,
|
| 13 |
+
"recall": 0.6790697674418604,
|
| 14 |
+
"f1": 0.6822429906542056,
|
| 15 |
+
"train_size": 1788,
|
| 16 |
+
"test_size": 447
|
| 17 |
+
}
|
| 18 |
+
}
|
requirements_spaces.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NEXUS - HuggingFace Spaces Dependencies
|
| 2 |
+
# Minimal set for Streamlit demo deployment (CPU)
|
| 3 |
+
|
| 4 |
+
torch>=2.1.0
|
| 5 |
+
transformers>=4.44.0
|
| 6 |
+
accelerate>=0.25.0
|
| 7 |
+
safetensors>=0.4.0
|
| 8 |
+
sentencepiece>=0.1.99
|
| 9 |
+
huggingface_hub>=0.20.0
|
| 10 |
+
|
| 11 |
+
# Audio
|
| 12 |
+
librosa>=0.10.0
|
| 13 |
+
soundfile>=0.12.0
|
| 14 |
+
|
| 15 |
+
# Image
|
| 16 |
+
Pillow>=10.0.0
|
| 17 |
+
|
| 18 |
+
# Data
|
| 19 |
+
numpy>=1.24.0
|
| 20 |
+
pandas>=2.0.0
|
| 21 |
+
scipy>=1.11.0
|
| 22 |
+
scikit-learn>=1.3.0
|
| 23 |
+
|
| 24 |
+
# Demo
|
| 25 |
+
streamlit>=1.28.0
|
| 26 |
+
plotly>=5.18.0
|
| 27 |
+
|
| 28 |
+
# Utilities
|
| 29 |
+
pyyaml>=6.0.0
|
| 30 |
+
tqdm>=4.66.0
|
| 31 |
+
joblib>=1.3.0
|
src/demo/__init__.py
ADDED
|
File without changes
|
src/demo/streamlit_app.py
ADDED
|
@@ -0,0 +1,1189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
NEXUS Streamlit Demo Application

Interactive demo for the NEXUS Maternal-Neonatal Care Platform.
Built with Google HAI-DEF models for the MedGemma Impact Challenge.

HAI-DEF Models Used:
- MedSigLIP: Medical image analysis (anemia, jaundice detection)
- HeAR: Health acoustic representations (cry analysis)
- MedGemma: Clinical reasoning and synthesis
"""

import json
import os
import sys
import tempfile
from pathlib import Path

import streamlit as st

# Make the sibling packages (nexus, demo) importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Streamlit page configuration -- must run before any other st.* call.
st.set_page_config(
    page_title="NEXUS - Maternal-Neonatal Care",
    page_icon="👶",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Global CSS: headers, traffic-light risk boxes, metric cards, model badges.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.2rem;
        color: #666;
        text-align: center;
        margin-bottom: 2rem;
    }
    .risk-high {
        background-color: #ffcccc;
        border: 2px solid #ff0000;
        padding: 1rem;
        border-radius: 10px;
    }
    .risk-medium {
        background-color: #fff3cd;
        border: 2px solid #ffc107;
        padding: 1rem;
        border-radius: 10px;
    }
    .risk-low {
        background-color: #d4edda;
        border: 2px solid #28a745;
        padding: 1rem;
        border-radius: 10px;
    }
    .metric-card {
        background-color: #f8f9fa;
        padding: 1rem;
        border-radius: 10px;
        text-align: center;
    }
    .model-badge {
        display: inline-block;
        padding: 2px 10px;
        border-radius: 12px;
        font-size: 0.78rem;
        font-weight: 600;
        color: white;
        letter-spacing: 0.3px;
    }
    .stMetric > div {
        background-color: #f8f9fa;
        padding: 0.5rem;
        border-radius: 8px;
    }
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_anemia_detector():
    """Instantiate the MedSigLIP-based anemia detector once per session.

    Returns:
        tuple: ``(detector, None)`` on success, or ``(None, message)``
        describing why the model could not be imported or constructed.
    """
    try:
        from nexus.anemia_detector import AnemiaDetector
        return AnemiaDetector(), None
    except Exception as exc:
        return None, str(exc)
| 98 |
+
|
| 99 |
+
|
@st.cache_resource
def load_jaundice_detector():
    """Instantiate the MedSigLIP-based jaundice detector once per session.

    Returns:
        tuple: ``(detector, None)`` on success, or ``(None, message)``
        describing why the model could not be imported or constructed.
    """
    try:
        from nexus.jaundice_detector import JaundiceDetector
        return JaundiceDetector(), None
    except Exception as exc:
        return None, str(exc)
| 109 |
+
|
| 110 |
+
|
@st.cache_resource
def load_cry_analyzer():
    """Instantiate the HeAR-based cry analyzer once per session.

    Returns:
        tuple: ``(analyzer, None)`` on success, or ``(None, message)``
        describing why the model could not be imported or constructed.
    """
    try:
        from nexus.cry_analyzer import CryAnalyzer
        return CryAnalyzer(), None
    except Exception as exc:
        return None, str(exc)
| 120 |
+
|
| 121 |
+
|
@st.cache_resource
def load_clinical_synthesizer():
    """Load the MedGemma clinical synthesizer (cached per session).

    Honors the ``NEXUS_USE_MEDGEMMA`` environment variable: any value
    other than ``"false"`` (case-insensitive) enables MedGemma inference;
    ``"false"`` forces the synthesizer's non-MedGemma path.

    Returns:
        tuple: ``(synthesizer, None)`` on success, or ``(None, message)``
        describing why the model could not be loaded.
    """
    try:
        from nexus.clinical_synthesizer import ClinicalSynthesizer

        # `os` is imported at module level; the previous function-local
        # `import os` was redundant and has been removed.
        use_medgemma = os.environ.get("NEXUS_USE_MEDGEMMA", "true").lower() != "false"
        synthesizer = ClinicalSynthesizer(use_medgemma=use_medgemma)
        return synthesizer, None
    except Exception as e:
        return None, str(e)
| 133 |
+
|
| 134 |
+
|
def get_hai_def_info():
    """Return display metadata for each HAI-DEF model used by NEXUS.

    Keys are model family names; each value carries the full model name,
    use case, method summary, validated performance text, and badge styling.
    """
    medsiglip = {
        "name": "MedSigLIP (google/medsiglip-448)",
        "use": "Image analysis for anemia and jaundice detection + bilirubin regression",
        "method": "Zero-shot classification (max-similarity, 8 prompts/class) + trained SVM/LR classifiers on embeddings",
        "accuracy": "Anemia: trained classifier on augmented data, Jaundice: trained classifier on 2,235 images, Bilirubin: MAE 2.67 mg/dL (r=0.77)",
        "badge": "Vision",
        "badge_color": "#388e3c",
    }
    hear = {
        "name": "HeAR (google/hear-pytorch)",
        "use": "Infant cry analysis for asphyxia and cry type classification",
        "method": "512-dim health acoustic embeddings + trained linear classifier on donate-a-cry dataset (5-class: hungry, belly_pain, burping, discomfort, tired)",
        "accuracy": "Trained cry type classifier with asphyxia risk derivation from distress patterns",
        "badge": "Audio",
        "badge_color": "#f57c00",
    }
    medgemma = {
        "name": "MedGemma 1.5 4B (google/medgemma-1.5-4b-it)",
        "use": "Clinical reasoning and recommendation synthesis",
        "method": "4-bit NF4 quantized inference with WHO IMNCI protocol-aligned synthesis and 6-agent reasoning traces",
        "accuracy": "Protocol-aligned clinical recommendations with structured reasoning chains",
        "badge": "Language",
        "badge_color": "#1976d2",
    }
    return {"MedSigLIP": medsiglip, "HeAR": hear, "MedGemma": medgemma}
| 163 |
+
|
| 164 |
+
|
def main():
    """Top-level page: header, sidebar navigation, and tab dispatch."""
    # Header
    st.markdown('<div class="main-header">NEXUS</div>', unsafe_allow_html=True)
    st.markdown(
        '<div class="sub-header">AI-Powered Maternal-Neonatal Care Platform</div>',
        unsafe_allow_html=True
    )

    # Each sidebar choice maps to the function that renders its tab.
    # Insertion order defines the radio-button order.
    pages = {
        "Maternal Anemia Screening": render_anemia_screening,
        "Neonatal Jaundice Detection": render_jaundice_detection,
        "Cry Analysis": render_cry_analysis,
        "Combined Assessment": render_combined_assessment,
        "Agentic Workflow": render_agentic_workflow,
        "HAI-DEF Models Info": render_hai_def_info,
    }

    # Sidebar
    with st.sidebar:
        st.markdown("## 🏥 NEXUS")
        st.markdown("---")

        assessment_type = st.radio(
            "Select Assessment Type",
            list(pages),
            index=0,
        )

        st.markdown("---")
        st.markdown("### About NEXUS")
        st.markdown("""
        NEXUS uses AI to provide non-invasive screening for:
        - **Maternal Anemia** via conjunctiva imaging
        - **Neonatal Jaundice** via skin color analysis
        - **Birth Asphyxia** via cry pattern analysis

        Built with **Google HAI-DEF models** for the MedGemma Impact Challenge 2026.
        """)

        st.markdown("---")
        st.markdown("### Edge AI Mode")
        edge_mode = st.toggle("Enable Edge AI Mode", value=False, key="edge_mode")
        if edge_mode:
            st.success("Edge AI: INT8 quantized models + offline inference")
        else:
            st.info("Cloud mode: Full-precision HAI-DEF models")

        st.markdown("---")
        st.markdown("### HAI-DEF Models")
        st.markdown("""
        - **MedSigLIP**: Vision (trained classifiers)
        - **HeAR**: Audio (trained cry classifier)
        - **MedGemma 1.5**: Clinical AI (4-bit NF4)
        """)

    # Show Edge AI banner when enabled
    if edge_mode:
        render_edge_ai_banner()

    # Dispatch to the selected tab; the info page is the fallback,
    # matching the original if/elif chain's final `else`.
    pages.get(assessment_type, render_hai_def_info)()
| 237 |
+
|
| 238 |
+
|
def render_edge_ai_banner():
    """Render the Edge AI status banner, metric row, and details expander."""
    st.markdown("""
    <div style="background: linear-gradient(135deg, #1a237e 0%, #0d47a1 100%);
                color: white; padding: 1rem 1.5rem; border-radius: 10px; margin-bottom: 1rem;">
        <h4 style="margin:0; color: white;">Edge AI Mode Active</h4>
        <p style="margin: 0.3rem 0 0 0; opacity: 0.9; font-size: 0.9rem;">
            Running INT8 quantized models for offline-capable inference on low-resource devices.
        </p>
    </div>
    """, unsafe_allow_html=True)

    # One headline metric per quantized artifact, laid out in four columns.
    edge_metrics = [
        ("MedSigLIP INT8", "111.2 MB", "-86% memory"),
        ("Acoustic Model", "0.6 MB", "INT8 quantized"),
        ("Text Embeddings", "12 KB", "Pre-computed"),
        ("Total Edge Size", "~289 MB", "Offline-ready"),
    ]
    for column, (label, value, delta) in zip(st.columns(4), edge_metrics):
        with column:
            st.metric(label, value, delta)

    with st.expander("Edge AI Details"):
        st.markdown("""
        **Quantization**: Dynamic INT8 (PyTorch `quantize_dynamic`, qnnpack backend)

        | Component | Cloud (FP32) | Edge (INT8) | Compression |
        |-----------|-------------|-------------|-------------|
        | MedSigLIP Vision | 812.6 MB | 111.2 MB | **7.31x** |
        | Acoustic Model | 0.665 MB | 0.599 MB | 1.11x |
        | CPU Latency | 97.7 ms | ~65 ms (ARM est.) | ~1.5x faster |

        **Target Devices**: Android 8.0+, ARM Cortex-A53, 2GB RAM

        **Offline Capabilities**:
        - Image analysis via INT8 MedSigLIP + pre-computed binary text embeddings
        - Audio analysis via INT8 acoustic feature extractor
        - Clinical reasoning via rule-based WHO IMNCI protocols (no MedGemma required)
        """)
| 278 |
+
|
| 279 |
+
|
def _cleanup_temp(path: str) -> None:
    """Best-effort removal of a temporary file; never raises OSError.

    Accepts None/empty paths and already-deleted files silently, so it is
    safe to call unconditionally from `finally` blocks.
    """
    if not path:
        return
    try:
        if os.path.exists(path):
            os.unlink(path)
    except OSError:
        pass
| 287 |
+
|
| 288 |
+
|
def _save_upload_to_temp(uploaded_file, suffix: str) -> str:
    """Persist a Streamlit upload to a closed temporary file.

    Args:
        uploaded_file: object exposing ``getvalue() -> bytes``
            (Streamlit's UploadedFile).
        suffix: filename suffix for the temp file, e.g. ".jpg" or ".wav".

    Returns:
        Path to the written (and closed) temporary file. The caller owns
        the file and should remove it with ``_cleanup_temp``.

    Raises:
        Whatever the write raises; the partial file is removed first.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        handle.write(uploaded_file.getvalue())
        handle.close()
        return handle.name
    except Exception:
        # Close and discard the partially written file before propagating.
        handle.close()
        _cleanup_temp(handle.name)
        raise
| 300 |
+
|
| 301 |
+
|
def _model_badge(name: str, color: str) -> str:
    """Build the inline-styled HTML pill badge shown next to tab headers."""
    badge_style = (
        f"background:{color}; color:white; padding:2px 10px; "
        "border-radius:12px; font-size:0.78rem; font-weight:600; "
        "letter-spacing:0.3px;"
    )
    return f'<span style="{badge_style}">{name}</span>'
| 309 |
+
|
| 310 |
+
|
def render_anemia_screening():
    """Render the maternal anemia screening tab (conjunctiva image, MedSigLIP)."""
    st.header("Maternal Anemia Screening")
    st.markdown(
        "Upload a clear image of the inner eyelid (conjunctiva) for anemia screening. "
        f'{_model_badge("MedSigLIP", "#388e3c")}',
        unsafe_allow_html=True,
    )

    left, right = st.columns([1, 1])

    with left:
        st.subheader("Upload Image")
        upload = st.file_uploader(
            "Choose a conjunctiva image",
            type=["jpg", "jpeg", "png"],
            key="anemia_upload"
        )
        if upload:
            st.image(upload, caption="Uploaded Image", use_container_width=True)

    with right:
        st.subheader("Analysis Results")

        if not upload:
            st.info("👆 Upload an image to begin analysis")
            return

        with st.spinner("Analyzing image..."):
            image_path = None
            try:
                detector, load_err = load_anemia_detector()
                if detector is None:
                    st.error(f"Could not load model: {load_err}")
                    return

                image_path = _save_upload_to_temp(upload, ".jpg")
                result = detector.detect(image_path)
                color_info = detector.analyze_color_features(image_path)

                # Colour-coded result box keyed on the detector's risk level.
                box_class = f"risk-{result['risk_level']}"
                st.markdown(f'<div class="{box_class}">', unsafe_allow_html=True)
                if result["is_anemic"]:
                    st.error("⚠️ ANEMIA DETECTED")
                else:
                    st.success("✅ No Anemia Detected")
                st.markdown("</div>", unsafe_allow_html=True)

                # Headline metrics.
                col_a, col_b, col_c = st.columns(3)
                col_a.metric("Confidence", f"{result['confidence']:.1%}")
                col_b.metric("Risk Level", result['risk_level'].upper())
                col_c.metric("Est. Hemoglobin", f"{color_info['estimated_hemoglobin']} g/dL")

                st.markdown("### Recommendation")
                st.info(result["recommendation"])

                # Raw scores and colour features for debugging/validation.
                with st.expander("Technical Details"):
                    st.json({
                        "anemia_score": round(result["anemia_score"], 3),
                        "healthy_score": round(result["healthy_score"], 3),
                        "red_ratio": round(color_info["red_ratio"], 3),
                        "pallor_index": round(color_info["pallor_index"], 3),
                    })
            except Exception as exc:
                st.error(f"Error analyzing image: {exc}")
            finally:
                _cleanup_temp(image_path)
| 389 |
+
|
| 390 |
+
|
def render_jaundice_detection():
    """Render the neonatal jaundice detection tab (skin/sclera image, MedSigLIP)."""
    st.header("Neonatal Jaundice Detection")
    st.markdown(
        "Upload an image of the newborn's skin or sclera for jaundice assessment. "
        f'{_model_badge("MedSigLIP", "#388e3c")}',
        unsafe_allow_html=True,
    )

    left, right = st.columns([1, 1])

    with left:
        st.subheader("Upload Image")
        upload = st.file_uploader(
            "Choose a neonatal image",
            type=["jpg", "jpeg", "png"],
            key="jaundice_upload"
        )
        if upload:
            st.image(upload, caption="Uploaded Image", use_container_width=True)

        # Optional demographics. NOTE(review): these inputs are collected
        # but not consumed anywhere in this function -- presumably for a
        # future pipeline hook; confirm before removing.
        st.subheader("Patient Information (Optional)")
        age_days = st.number_input("Age (days)", min_value=0, max_value=28, value=3)
        birth_weight = st.number_input("Birth weight (grams)", min_value=500, max_value=5000, value=3000)

    with right:
        st.subheader("Analysis Results")

        if not upload:
            st.info("👆 Upload an image to begin analysis")
            return

        with st.spinner("Analyzing image..."):
            image_path = None
            try:
                detector, load_err = load_jaundice_detector()
                if detector is None:
                    st.error(f"Could not load model: {load_err}")
                    return

                image_path = _save_upload_to_temp(upload, ".jpg")
                result = detector.detect(image_path)
                zone_info = detector.analyze_kramer_zones(image_path)

                # Box colour: red when phototherapy is indicated, amber for
                # any mild/moderate jaundice, green otherwise.
                if result["needs_phototherapy"]:
                    box_class = "risk-high"
                elif result["severity"] in ["moderate", "mild"]:
                    box_class = "risk-medium"
                else:
                    box_class = "risk-low"
                st.markdown(f'<div class="{box_class}">', unsafe_allow_html=True)
                if result["has_jaundice"]:
                    st.warning(f"⚠️ JAUNDICE DETECTED - {result['severity'].upper()}")
                else:
                    st.success("✅ No Significant Jaundice")
                st.markdown("</div>", unsafe_allow_html=True)

                # Headline metrics; prefer the ML-regression bilirubin value
                # when the detector provided one.
                col_a, col_b, col_c = st.columns(3)
                with col_a:
                    bili_value = result.get('estimated_bilirubin_ml', result.get('estimated_bilirubin', 0))
                    bili_method = result.get('bilirubin_method', 'Color Analysis')
                    st.metric("Est. Bilirubin", f"{bili_value} mg/dL")
                    st.caption(f"Method: {bili_method}")
                with col_b:
                    st.metric("Severity", result['severity'].upper())
                with col_c:
                    st.metric("Kramer Zone", zone_info['kramer_zone'])

                if result["needs_phototherapy"]:
                    st.error("🔆 PHOTOTHERAPY RECOMMENDED")

                st.markdown("### Recommendation")
                st.info(result["recommendation"])

                with st.expander("Kramer Zone Analysis"):
                    st.write(f"**Zone**: {zone_info['kramer_zone']} - {zone_info['zone_description']}")
                    st.write(f"**Yellow Index**: {zone_info['yellow_index']}")
                    st.progress(min(zone_info['yellow_index'] * 2, 1.0))

                with st.expander("Technical Details"):
                    details = {
                        "jaundice_score": round(result["jaundice_score"], 3),
                        "confidence": round(result["confidence"], 3),
                        "model": result.get("model", "unknown"),
                        "model_type": result.get("model_type", "unknown"),
                        "bilirubin_method": result.get("bilirubin_method", "Color Analysis"),
                    }
                    if result.get("estimated_bilirubin_ml") is not None:
                        details["bilirubin_ml"] = result["estimated_bilirubin_ml"]
                        details["bilirubin_color"] = result["estimated_bilirubin"]
                    st.json(details)
            except Exception as exc:
                st.error(f"Error analyzing image: {exc}")
            finally:
                _cleanup_temp(image_path)
| 494 |
+
|
| 495 |
+
|
def render_cry_analysis():
    """Render the infant cry analysis tab (audio recording, HeAR)."""
    st.header("Infant Cry Analysis")
    st.markdown(
        "Upload an audio recording of the infant's cry for analysis. "
        f'{_model_badge("HeAR", "#f57c00")}',
        unsafe_allow_html=True,
    )

    left, right = st.columns([1, 1])

    with left:
        st.subheader("Upload Audio")
        upload = st.file_uploader(
            "Choose a cry audio file",
            type=["wav", "mp3", "ogg"],
            key="cry_upload"
        )
        if upload:
            st.audio(upload)

    with right:
        st.subheader("Analysis Results")

        if not upload:
            st.info("👆 Upload an audio file to begin analysis")
            return

        with st.spinner("Analyzing cry..."):
            audio_path = None
            try:
                analyzer, load_err = load_cry_analyzer()
                if analyzer is None:
                    st.error(f"Could not load model: {load_err}")
                    return

                audio_path = _save_upload_to_temp(upload, ".wav")
                result = analyzer.analyze(audio_path)

                # Colour-coded result box keyed on the analyzer's risk level.
                box_class = f"risk-{result['risk_level']}"
                st.markdown(f'<div class="{box_class}">', unsafe_allow_html=True)
                if result["is_abnormal"]:
                    st.error("⚠️ ABNORMAL CRY PATTERN DETECTED")
                else:
                    st.success("✅ Normal Cry Pattern")
                st.markdown("</div>", unsafe_allow_html=True)

                # Headline metrics.
                col_a, col_b, col_c = st.columns(3)
                col_a.metric("Asphyxia Risk", f"{result['asphyxia_risk']:.1%}")
                col_b.metric("Cry Type", result['cry_type'].title())
                col_c.metric("F0 (Pitch)", f"{result['features']['f0_mean']:.0f} Hz")

                st.markdown("### Recommendation")
                st.info(result["recommendation"])

                with st.expander("Acoustic Features"):
                    st.json(result["features"])
            except Exception as exc:
                st.error(f"Error analyzing audio: {exc}")
            finally:
                _cleanup_temp(audio_path)
| 568 |
+
|
| 569 |
+
|
def _run_combined_modality(upload, suffix, loader, run_model):
    """Run one modality's model over an uploaded file for the combined tab.

    Args:
        upload: Streamlit UploadedFile to analyze.
        suffix: temp-file suffix (".jpg" or ".wav").
        loader: cached loader returning ``(model, error_message)``.
        run_model: callable ``(model, path) -> findings dict``.

    Returns:
        The findings dict, or None if loading/analysis failed (an error
        message is shown in the UI in that case).

    The upload is written via ``_save_upload_to_temp`` (which closes the
    file before the model reads the path, so the OS buffer is flushed) and
    the temp copy is always removed. This fixes the previous version's
    ``NamedTemporaryFile(delete=False)`` leak and its read-while-open
    pattern, where the model could observe a partially written file.
    """
    tmp_path = None
    try:
        model, load_err = loader()
        if model is None:
            st.error(f"Model error: {load_err}")
            return None
        tmp_path = _save_upload_to_temp(upload, suffix)
        return run_model(model, tmp_path)
    except Exception as e:
        st.error(f"Error: {e}")
        return None
    finally:
        _cleanup_temp(tmp_path)


def render_combined_assessment():
    """Render the combined assessment tab using the Clinical Synthesizer.

    Runs each modality the user supplies (anemia image, jaundice image,
    cry audio), stores the findings in session state, then lets the user
    synthesize them into one recommendation with MedGemma.
    """
    st.header("Combined Clinical Assessment")
    st.markdown(
        "Upload multiple inputs for a comprehensive assessment using **MedGemma Clinical Synthesizer**. "
        "This combines findings from all HAI-DEF models to provide integrated clinical recommendations. "
        f'{_model_badge("MedSigLIP", "#388e3c")} '
        f'{_model_badge("HeAR", "#f57c00")} '
        f'{_model_badge("MedGemma", "#1976d2")}',
        unsafe_allow_html=True,
    )

    # Reset findings each time this tab is rendered to prevent
    # stale data from previous patients contaminating results.
    st.session_state.findings = {
        "anemia": None,
        "jaundice": None,
        "cry": None
    }

    col1, col2, col3 = st.columns(3)

    with col1:
        st.subheader("🩸 Anemia Screening")
        anemia_file = st.file_uploader(
            "Conjunctiva image",
            type=["jpg", "jpeg", "png"],
            key="combined_anemia"
        )
        if anemia_file:
            st.image(anemia_file, use_container_width=True)
            with st.spinner("Analyzing..."):
                result = _run_combined_modality(
                    anemia_file, ".jpg", load_anemia_detector,
                    lambda model, path: model.detect(path),
                )
                if result is not None:
                    st.session_state.findings["anemia"] = result
                    if result["is_anemic"]:
                        st.error(f"Anemia: {result['risk_level'].upper()}")
                    else:
                        st.success("No Anemia")

    with col2:
        st.subheader("👶 Jaundice Detection")
        jaundice_file = st.file_uploader(
            "Neonatal skin image",
            type=["jpg", "jpeg", "png"],
            key="combined_jaundice"
        )
        if jaundice_file:
            st.image(jaundice_file, use_container_width=True)
            with st.spinner("Analyzing..."):
                result = _run_combined_modality(
                    jaundice_file, ".jpg", load_jaundice_detector,
                    lambda model, path: model.detect(path),
                )
                if result is not None:
                    st.session_state.findings["jaundice"] = result
                    if result["has_jaundice"]:
                        st.warning(f"Jaundice: {result['severity'].upper()}")
                    else:
                        st.success("No Jaundice")

    with col3:
        st.subheader("🔊 Cry Analysis")
        cry_file = st.file_uploader(
            "Cry audio",
            type=["wav", "mp3", "ogg"],
            key="combined_cry"
        )
        if cry_file:
            st.audio(cry_file)
            with st.spinner("Analyzing..."):
                # Previous version raised RuntimeError after st.error just to
                # skip analysis, showing the same failure twice; the helper
                # now reports each failure exactly once.
                result = _run_combined_modality(
                    cry_file, ".wav", load_cry_analyzer,
                    lambda model, path: model.analyze(path),
                )
                if result is not None:
                    st.session_state.findings["cry"] = result
                    if result["is_abnormal"]:
                        st.error(f"Abnormal Cry: {result['risk_level'].upper()}")
                    else:
                        st.success("Normal Cry")

    # Clinical Synthesis Section
    st.markdown("---")
    st.subheader("🏥 Clinical Synthesis (MedGemma)")

    # Check if any findings are available
    has_findings = any(v is not None for v in st.session_state.findings.values())

    if not has_findings:
        st.info("👆 Upload at least one input (image or audio) to generate clinical synthesis")
        return

    if st.button("Generate Clinical Synthesis", type="primary"):
        with st.spinner("Synthesizing findings with MedGemma..."):
            try:
                synthesizer, load_err = load_clinical_synthesizer()
                if synthesizer is None:
                    st.error(f"Could not load synthesizer: {load_err}")
                    return

                # Only forward the modalities that produced findings.
                findings = {
                    key: value
                    for key, value in st.session_state.findings.items()
                    if value
                }

                synthesis = synthesizer.synthesize(findings)

                # Severity banner (GREEN/YELLOW/RED traffic-light scheme).
                severity_level = synthesis.get("severity_level", "GREEN")
                severity_colors = {
                    "GREEN": ("🟢", "#d4edda", "#155724"),
                    "YELLOW": ("🟡", "#fff3cd", "#856404"),
                    "RED": ("🔴", "#f8d7da", "#721c24")
                }
                emoji, bg_color, text_color = severity_colors.get(severity_level, ("⚪", "#f8f9fa", "#000"))

                st.markdown(f"""
                <div style="background-color: {bg_color}; padding: 1.5rem; border-radius: 10px; margin: 1rem 0;">
                    <h3 style="color: {text_color}; margin: 0;">{emoji} Severity: {severity_level}</h3>
                    <p style="color: {text_color}; font-size: 1.1rem; margin-top: 0.5rem;">{synthesis.get('severity_description', '')}</p>
                </div>
                """, unsafe_allow_html=True)

                # Summary
                st.markdown("### Summary")
                st.info(synthesis.get("summary", "No summary available"))

                # Actions
                if synthesis.get("immediate_actions"):
                    st.markdown("### Immediate Actions")
                    for action in synthesis["immediate_actions"]:
                        st.markdown(f"- {action}")

                # Referral
                col_a, col_b = st.columns(2)
                with col_a:
                    st.markdown("### Referral Status")
                    if synthesis.get("referral_needed"):
                        st.error(f"⚠️ REFERRAL NEEDED: {synthesis.get('referral_urgency', 'standard').upper()}")
                    else:
                        st.success("✅ No referral needed")

                with col_b:
                    st.markdown("### Follow-up")
                    st.info(synthesis.get("follow_up", "Schedule routine follow-up"))

                # Technical details
                with st.expander("Technical Details"):
                    model_name = synthesis.get("model", "unknown")
                    st.json({
                        "model": model_name,
                        "model_id": synthesis.get("model_id", ""),
                        "generated_at": synthesis.get("generated_at"),
                        "urgent_conditions": synthesis.get("urgent_conditions", []),
                    })
                    if model_name and "Fallback" not in str(model_name):
                        st.success(f"Synthesis powered by {model_name}")
                    elif "Fallback" in str(model_name):
                        st.warning("Using rule-based fallback (MedGemma unavailable)")

            except Exception as e:
                st.error(f"Error generating synthesis: {e}")
| 754 |
+
|
| 755 |
+
|
def render_hai_def_info():
    """Render the HAI-DEF models information tab."""
    st.header("Google HAI-DEF Models")
    st.markdown("""
    NEXUS is built using **Google Health AI Developer Foundations (HAI-DEF)** models,
    designed specifically for healthcare applications in resource-limited settings.
    """)

    hai_def = get_hai_def_info()

    # (section heading, info-box text, hai_def key, narrative) per model.
    model_sections = [
        (
            "### 🖼️ MedSigLIP",
            "google/medsiglip-448\n\nHAI-DEF Vision Model",
            "MedSigLIP",
            """
        MedSigLIP enables zero-shot medical image classification using
        text prompts. NEXUS extends this with trained SVM/LR classifiers
        on MedSigLIP embeddings (with data augmentation) for improved
        accuracy, plus a novel 3-layer MLP regression head for continuous
        bilirubin prediction from frozen embeddings.
        """,
        ),
        (
            "### 🔊 HeAR",
            "google/hear-pytorch\n\nHAI-DEF Audio Model",
            "HeAR",
            """
        HeAR (Health Acoustic Representations) produces 512-dim embeddings
        from 2-second audio clips at 16kHz. NEXUS trains a linear classifier
        on HeAR embeddings for 5-class cry type classification (hungry,
        belly_pain, burping, discomfort, tired) and derives asphyxia risk
        from distress patterns.
        """,
        ),
        (
            "### 🧠 MedGemma",
            "google/medgemma-1.5-4b-it\n\nHAI-DEF Language Model",
            "MedGemma",
            """
        MedGemma 1.5 provides clinical reasoning capabilities via 4-bit NF4
        quantized inference (~2 GB VRAM). It synthesizes multi-modal findings
        into actionable recommendations following WHO IMNCI protocols,
        producing structured reasoning chains within the 6-agent pipeline.
        """,
        ),
    ]

    for heading, badge_text, key, narrative in model_sections:
        st.markdown("---")
        col1, col2 = st.columns([1, 2])
        with col1:
            st.markdown(heading)
            st.info(badge_text)
        with col2:
            info = hai_def[key]
            st.markdown(f"**Model**: {info['name']}")
            st.markdown(f"**Use Case**: {info['use']}")
            st.markdown(f"**Method**: {info['method']}")
            st.markdown(f"**Validated Performance**: {info['accuracy']}")
            st.markdown(narrative)

    # Competition Info
    st.markdown("---")
    st.subheader("🏆 MedGemma Impact Challenge 2026")
    st.markdown("""
    NEXUS is being developed for the [MedGemma Impact Challenge](https://www.kaggle.com/competitions/medgemma-impact-challenge-2026)
    on Kaggle.

    **Competition Focus**: Solutions for resource-limited healthcare settings using HAI-DEF models.

    **NEXUS Impact**:
    - 📍 Target: Sub-Saharan Africa and South Asia
    - 👩⚕️ Users: Community Health Workers
    - 🎯 Goals: Reduce maternal/neonatal mortality
    - 📱 Deployment: Offline-capable mobile app
    """)
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def render_agentic_workflow():
|
| 843 |
+
"""Render the agentic workflow interface with reasoning traces."""
|
| 844 |
+
st.header("Agentic Clinical Workflow")
|
| 845 |
+
st.markdown(
|
| 846 |
+
f"**6-Agent Pipeline** with step-by-step reasoning traces. "
|
| 847 |
+
f"Each agent explains its clinical decision process, providing a full audit trail. "
|
| 848 |
+
f'{_model_badge("MedSigLIP", "#388e3c")} '
|
| 849 |
+
f'{_model_badge("HeAR", "#f57c00")} '
|
| 850 |
+
f'{_model_badge("MedGemma", "#1976d2")}',
|
| 851 |
+
unsafe_allow_html=True,
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
# Pipeline diagram
|
| 855 |
+
st.markdown("""
|
| 856 |
+
<div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; flex-wrap: wrap; margin: 1rem 0;">
|
| 857 |
+
<div style="background: #e3f2fd; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #1976d2;">Triage</div>
|
| 858 |
+
<span style="font-size: 1.5rem;">→</span>
|
| 859 |
+
<div style="background: #e8f5e9; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #388e3c;">Image (MedSigLIP)</div>
|
| 860 |
+
<span style="font-size: 1.5rem;">→</span>
|
| 861 |
+
<div style="background: #fff3e0; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #f57c00;">Audio (HeAR)</div>
|
| 862 |
+
<span style="font-size: 1.5rem;">→</span>
|
| 863 |
+
<div style="background: #f3e5f5; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #7b1fa2;">Protocol (WHO)</div>
|
| 864 |
+
<span style="font-size: 1.5rem;">→</span>
|
| 865 |
+
<div style="background: #fce4ec; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #c62828;">Referral</div>
|
| 866 |
+
<span style="font-size: 1.5rem;">→</span>
|
| 867 |
+
<div style="background: #e0f7fa; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #00838f;">Synthesis (MedGemma)</div>
|
| 868 |
+
</div>
|
| 869 |
+
""", unsafe_allow_html=True)
|
| 870 |
+
|
| 871 |
+
st.markdown("---")
|
| 872 |
+
|
| 873 |
+
# Input section
|
| 874 |
+
col_left, col_right = st.columns([1, 1])
|
| 875 |
+
|
| 876 |
+
with col_left:
|
| 877 |
+
st.subheader("Patient & Inputs")
|
| 878 |
+
patient_type = st.selectbox("Patient Type", ["newborn", "pregnant"], key="agentic_patient")
|
| 879 |
+
|
| 880 |
+
# Danger signs
|
| 881 |
+
st.markdown("**Danger Signs**")
|
| 882 |
+
danger_signs = []
|
| 883 |
+
if patient_type == "pregnant":
|
| 884 |
+
sign_options = [
|
| 885 |
+
("Severe headache", "high"),
|
| 886 |
+
("Blurred vision", "high"),
|
| 887 |
+
("Convulsions", "critical"),
|
| 888 |
+
("Severe abdominal pain", "high"),
|
| 889 |
+
("Vaginal bleeding", "critical"),
|
| 890 |
+
("High fever", "high"),
|
| 891 |
+
("Severe pallor", "medium"),
|
| 892 |
+
]
|
| 893 |
+
else:
|
| 894 |
+
sign_options = [
|
| 895 |
+
("Not breathing at birth", "critical"),
|
| 896 |
+
("Convulsions", "critical"),
|
| 897 |
+
("Severe chest indrawing", "high"),
|
| 898 |
+
("Not feeding", "high"),
|
| 899 |
+
("High fever (>38C)", "high"),
|
| 900 |
+
("Hypothermia (<35.5C)", "high"),
|
| 901 |
+
("Lethargy / unconscious", "critical"),
|
| 902 |
+
("Umbilical redness", "medium"),
|
| 903 |
+
]
|
| 904 |
+
|
| 905 |
+
selected_signs = st.multiselect(
|
| 906 |
+
"Select present danger signs",
|
| 907 |
+
[s[0] for s in sign_options],
|
| 908 |
+
key="agentic_signs"
|
| 909 |
+
)
|
| 910 |
+
for label, severity in sign_options:
|
| 911 |
+
if label in selected_signs:
|
| 912 |
+
danger_signs.append({
|
| 913 |
+
"id": label.lower().replace(" ", "_"),
|
| 914 |
+
"label": label,
|
| 915 |
+
"severity": severity,
|
| 916 |
+
"present": True,
|
| 917 |
+
})
|
| 918 |
+
|
| 919 |
+
# Image uploads
|
| 920 |
+
st.markdown("**Clinical Images**")
|
| 921 |
+
conjunctiva_file = st.file_uploader(
|
| 922 |
+
"Conjunctiva image (anemia)", type=["jpg", "jpeg", "png"],
|
| 923 |
+
key="agentic_conjunctiva"
|
| 924 |
+
)
|
| 925 |
+
skin_file = st.file_uploader(
|
| 926 |
+
"Skin image (jaundice)", type=["jpg", "jpeg", "png"],
|
| 927 |
+
key="agentic_skin"
|
| 928 |
+
)
|
| 929 |
+
cry_file = st.file_uploader(
|
| 930 |
+
"Cry audio", type=["wav", "mp3", "ogg"],
|
| 931 |
+
key="agentic_cry"
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
with col_right:
|
| 935 |
+
st.subheader("Workflow Execution")
|
| 936 |
+
|
| 937 |
+
if st.button("Run Agentic Assessment", type="primary", key="run_agentic"):
|
| 938 |
+
with st.spinner("Running 6-agent workflow..."):
|
| 939 |
+
try:
|
| 940 |
+
from nexus.agentic_workflow import (
|
| 941 |
+
AgenticWorkflowEngine,
|
| 942 |
+
AgentPatientInfo,
|
| 943 |
+
DangerSign,
|
| 944 |
+
WorkflowInput,
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
# Save uploaded files (track for cleanup)
|
| 948 |
+
_temp_paths = []
|
| 949 |
+
conjunctiva_path = None
|
| 950 |
+
skin_path = None
|
| 951 |
+
cry_path = None
|
| 952 |
+
|
| 953 |
+
if conjunctiva_file:
|
| 954 |
+
conjunctiva_path = _save_upload_to_temp(conjunctiva_file, ".jpg")
|
| 955 |
+
_temp_paths.append(conjunctiva_path)
|
| 956 |
+
|
| 957 |
+
if skin_file:
|
| 958 |
+
skin_path = _save_upload_to_temp(skin_file, ".jpg")
|
| 959 |
+
_temp_paths.append(skin_path)
|
| 960 |
+
|
| 961 |
+
if cry_file:
|
| 962 |
+
cry_path = _save_upload_to_temp(cry_file, ".wav")
|
| 963 |
+
_temp_paths.append(cry_path)
|
| 964 |
+
|
| 965 |
+
# Build workflow input
|
| 966 |
+
signs = [
|
| 967 |
+
DangerSign(
|
| 968 |
+
id=s["id"], label=s["label"],
|
| 969 |
+
severity=s["severity"], present=True,
|
| 970 |
+
)
|
| 971 |
+
for s in danger_signs
|
| 972 |
+
]
|
| 973 |
+
|
| 974 |
+
info = AgentPatientInfo(patient_type=patient_type)
|
| 975 |
+
workflow_input = WorkflowInput(
|
| 976 |
+
patient_type=patient_type,
|
| 977 |
+
patient_info=info,
|
| 978 |
+
danger_signs=signs,
|
| 979 |
+
conjunctiva_image=conjunctiva_path,
|
| 980 |
+
skin_image=skin_path,
|
| 981 |
+
cry_audio=cry_path,
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
# Run workflow — reuse cached model instances when available
|
| 985 |
+
anemia_det, _ = load_anemia_detector()
|
| 986 |
+
jaundice_det, _ = load_jaundice_detector()
|
| 987 |
+
cry_ana, _ = load_cry_analyzer()
|
| 988 |
+
synth, _ = load_clinical_synthesizer()
|
| 989 |
+
|
| 990 |
+
engine = AgenticWorkflowEngine(
|
| 991 |
+
anemia_detector=anemia_det,
|
| 992 |
+
jaundice_detector=jaundice_det,
|
| 993 |
+
cry_analyzer=cry_ana,
|
| 994 |
+
synthesizer=synth,
|
| 995 |
+
)
|
| 996 |
+
result = engine.execute(workflow_input)
|
| 997 |
+
|
| 998 |
+
st.session_state["agentic_result"] = result
|
| 999 |
+
st.success("Workflow complete!")
|
| 1000 |
+
|
| 1001 |
+
except Exception as e:
|
| 1002 |
+
st.error(f"Workflow error: {e}")
|
| 1003 |
+
finally:
|
| 1004 |
+
for p in _temp_paths:
|
| 1005 |
+
_cleanup_temp(p)
|
| 1006 |
+
|
| 1007 |
+
# Results display
|
| 1008 |
+
if "agentic_result" in st.session_state:
|
| 1009 |
+
result = st.session_state["agentic_result"]
|
| 1010 |
+
|
| 1011 |
+
st.markdown("---")
|
| 1012 |
+
|
| 1013 |
+
# Overall classification
|
| 1014 |
+
severity_colors = {
|
| 1015 |
+
"GREEN": ("#d4edda", "#155724", "Routine care"),
|
| 1016 |
+
"YELLOW": ("#fff3cd", "#856404", "Close monitoring"),
|
| 1017 |
+
"RED": ("#f8d7da", "#721c24", "Urgent referral"),
|
| 1018 |
+
}
|
| 1019 |
+
bg, fg, desc = severity_colors.get(result.who_classification, ("#f8f9fa", "#000", "Unknown"))
|
| 1020 |
+
|
| 1021 |
+
st.markdown(f"""
|
| 1022 |
+
<div style="background: {bg}; color: {fg}; padding: 1.5rem; border-radius: 10px; text-align: center; margin: 1rem 0;">
|
| 1023 |
+
<h2 style="margin: 0;">WHO Classification: {result.who_classification}</h2>
|
| 1024 |
+
<p style="margin: 0.5rem 0 0 0; font-size: 1.1rem;">{desc}</p>
|
| 1025 |
+
</div>
|
| 1026 |
+
""", unsafe_allow_html=True)
|
| 1027 |
+
|
| 1028 |
+
# Key metrics
|
| 1029 |
+
m1, m2, m3, m4 = st.columns(4)
|
| 1030 |
+
with m1:
|
| 1031 |
+
st.metric("Agents Run", len(result.agent_traces))
|
| 1032 |
+
with m2:
|
| 1033 |
+
st.metric("Total Time", f"{result.processing_time_ms:.0f} ms")
|
| 1034 |
+
with m3:
|
| 1035 |
+
referral_text = "Yes" if (result.referral_result and result.referral_result.referral_needed) else "No"
|
| 1036 |
+
st.metric("Referral Needed", referral_text)
|
| 1037 |
+
with m4:
|
| 1038 |
+
triage_score = result.triage_result.score if result.triage_result else 0
|
| 1039 |
+
st.metric("Triage Score", triage_score)
|
| 1040 |
+
|
| 1041 |
+
# Clinical synthesis
|
| 1042 |
+
st.subheader("Clinical Synthesis")
|
| 1043 |
+
st.info(result.clinical_synthesis)
|
| 1044 |
+
|
| 1045 |
+
if result.immediate_actions:
|
| 1046 |
+
st.subheader("Immediate Actions")
|
| 1047 |
+
for action in result.immediate_actions:
|
| 1048 |
+
st.markdown(f"- {action}")
|
| 1049 |
+
|
| 1050 |
+
# Visual pipeline flow with status indicators
|
| 1051 |
+
st.markdown("---")
|
| 1052 |
+
st.subheader("Agent Pipeline Execution")
|
| 1053 |
+
|
| 1054 |
+
agent_meta = {
|
| 1055 |
+
"TriageAgent": {"color": "#1976d2", "bg": "#e3f2fd", "icon": "1", "label": "Triage"},
|
| 1056 |
+
"ImageAnalysisAgent": {"color": "#388e3c", "bg": "#e8f5e9", "icon": "2", "label": "Image (MedSigLIP)"},
|
| 1057 |
+
"AudioAnalysisAgent": {"color": "#f57c00", "bg": "#fff3e0", "icon": "3", "label": "Audio (HeAR)"},
|
| 1058 |
+
"ProtocolAgent": {"color": "#7b1fa2", "bg": "#f3e5f5", "icon": "4", "label": "WHO Protocol"},
|
| 1059 |
+
"ReferralAgent": {"color": "#c62828", "bg": "#fce4ec", "icon": "5", "label": "Referral"},
|
| 1060 |
+
"SynthesisAgent": {"color": "#00838f", "bg": "#e0f7fa", "icon": "6", "label": "Synthesis (MedGemma)"},
|
| 1061 |
+
}
|
| 1062 |
+
status_symbols = {"success": "OK", "skipped": "SKIP", "error": "ERR"}
|
| 1063 |
+
|
| 1064 |
+
# Build trace lookup
|
| 1065 |
+
trace_lookup = {t.agent_name: t for t in result.agent_traces}
|
| 1066 |
+
|
| 1067 |
+
# Pipeline status bar
|
| 1068 |
+
pipeline_html_parts = []
|
| 1069 |
+
for agent_name, meta in agent_meta.items():
|
| 1070 |
+
trace = trace_lookup.get(agent_name)
|
| 1071 |
+
if trace:
|
| 1072 |
+
status_sym = status_symbols.get(trace.status, "?")
|
| 1073 |
+
opacity = "1.0" if trace.status == "success" else "0.5"
|
| 1074 |
+
border_style = f"3px solid {meta['color']}" if trace.status == "success" else "2px dashed #999"
|
| 1075 |
+
time_label = f"{trace.processing_time_ms:.0f}ms"
|
| 1076 |
+
else:
|
| 1077 |
+
status_sym = "---"
|
| 1078 |
+
opacity = "0.3"
|
| 1079 |
+
border_style = "2px dashed #ccc"
|
| 1080 |
+
time_label = ""
|
| 1081 |
+
|
| 1082 |
+
pipeline_html_parts.append(f"""
|
| 1083 |
+
<div style="background: {meta['bg']}; padding: 0.4rem 0.7rem; border-radius: 8px;
|
| 1084 |
+
border: {border_style}; opacity: {opacity}; text-align: center; min-width: 90px;">
|
| 1085 |
+
<div style="font-weight: bold; font-size: 0.8rem; color: {meta['color']};">{meta['label']}</div>
|
| 1086 |
+
<div style="font-size: 0.7rem; color: #666;">{status_sym} {time_label}</div>
|
| 1087 |
+
</div>
|
| 1088 |
+
""")
|
| 1089 |
+
|
| 1090 |
+
pipeline_html = '<div style="display: flex; align-items: center; justify-content: center; gap: 0.3rem; flex-wrap: wrap; margin: 0.5rem 0;">'
|
| 1091 |
+
for i, part in enumerate(pipeline_html_parts):
|
| 1092 |
+
pipeline_html += part
|
| 1093 |
+
if i < len(pipeline_html_parts) - 1:
|
| 1094 |
+
pipeline_html += '<span style="font-size: 1.2rem; color: #999;">→</span>'
|
| 1095 |
+
pipeline_html += "</div>"
|
| 1096 |
+
st.markdown(pipeline_html, unsafe_allow_html=True)
|
| 1097 |
+
|
| 1098 |
+
# Agent reasoning traces (key feature for Agentic Workflow prize)
|
| 1099 |
+
st.markdown("---")
|
| 1100 |
+
st.subheader("Agent Reasoning Traces")
|
| 1101 |
+
|
| 1102 |
+
for trace in result.agent_traces:
|
| 1103 |
+
meta = agent_meta.get(trace.agent_name, {"color": "#666", "bg": "#f5f5f5", "label": trace.agent_name})
|
| 1104 |
+
status_emoji = {"success": "OK", "skipped": "SKIP", "error": "ERR"}.get(trace.status, "?")
|
| 1105 |
+
|
| 1106 |
+
header_label = f"{meta['label']} [{status_emoji}] - {trace.confidence:.0%} confidence - {trace.processing_time_ms:.0f}ms"
|
| 1107 |
+
with st.expander(header_label, expanded=(trace.status == "success")):
|
| 1108 |
+
# Status bar
|
| 1109 |
+
st.markdown(f"""
|
| 1110 |
+
<div style="background: {meta['bg']}; padding: 0.8rem 1rem; border-radius: 8px;
|
| 1111 |
+
border-left: 4px solid {meta['color']}; margin-bottom: 0.5rem;">
|
| 1112 |
+
<strong style="color: {meta['color']};">{trace.agent_name}</strong> |
|
| 1113 |
+
Status: <strong>{trace.status}</strong> |
|
| 1114 |
+
Confidence: <strong>{trace.confidence:.1%}</strong> |
|
| 1115 |
+
Time: <strong>{trace.processing_time_ms:.1f}ms</strong>
|
| 1116 |
+
</div>
|
| 1117 |
+
""", unsafe_allow_html=True)
|
| 1118 |
+
|
| 1119 |
+
# Reasoning steps with numbered styling
|
| 1120 |
+
if trace.reasoning:
|
| 1121 |
+
st.markdown("**Reasoning Chain:**")
|
| 1122 |
+
for i, step in enumerate(trace.reasoning, 1):
|
| 1123 |
+
st.markdown(f"**Step {i}.** {step}")
|
| 1124 |
+
|
| 1125 |
+
# Key findings
|
| 1126 |
+
if trace.findings:
|
| 1127 |
+
st.markdown("**Key Findings:**")
|
| 1128 |
+
st.json(trace.findings)
|
| 1129 |
+
|
| 1130 |
+
# Processing time breakdown
|
| 1131 |
+
st.markdown("---")
|
| 1132 |
+
col_chart, col_summary = st.columns([2, 1])
|
| 1133 |
+
|
| 1134 |
+
with col_chart:
|
| 1135 |
+
st.subheader("Processing Time by Agent")
|
| 1136 |
+
import pandas as pd
|
| 1137 |
+
chart_data = pd.DataFrame({
|
| 1138 |
+
"Agent": [agent_meta.get(t.agent_name, {}).get("label", t.agent_name) for t in result.agent_traces],
|
| 1139 |
+
"Time (ms)": [t.processing_time_ms for t in result.agent_traces],
|
| 1140 |
+
})
|
| 1141 |
+
st.bar_chart(chart_data.set_index("Agent"))
|
| 1142 |
+
|
| 1143 |
+
with col_summary:
|
| 1144 |
+
st.subheader("Workflow Summary")
|
| 1145 |
+
total_time = result.processing_time_ms
|
| 1146 |
+
successful = sum(1 for t in result.agent_traces if t.status == "success")
|
| 1147 |
+
skipped = sum(1 for t in result.agent_traces if t.status == "skipped")
|
| 1148 |
+
errors = sum(1 for t in result.agent_traces if t.status == "error")
|
| 1149 |
+
st.markdown(f"""
|
| 1150 |
+
| Metric | Value |
|
| 1151 |
+
|--------|-------|
|
| 1152 |
+
| Total agents | {len(result.agent_traces)} |
|
| 1153 |
+
| Successful | {successful} |
|
| 1154 |
+
| Skipped | {skipped} |
|
| 1155 |
+
| Errors | {errors} |
|
| 1156 |
+
| Total time | {total_time:.0f} ms |
|
| 1157 |
+
| Avg per agent | {total_time / max(len(result.agent_traces), 1):.0f} ms |
|
| 1158 |
+
""")
|
| 1159 |
+
|
| 1160 |
+
# Referral details
|
| 1161 |
+
if result.referral_result and result.referral_result.referral_needed:
|
| 1162 |
+
st.markdown("---")
|
| 1163 |
+
st.subheader("Referral Details")
|
| 1164 |
+
ref = result.referral_result
|
| 1165 |
+
r1, r2, r3 = st.columns(3)
|
| 1166 |
+
with r1:
|
| 1167 |
+
st.metric("Urgency", ref.urgency.upper())
|
| 1168 |
+
with r2:
|
| 1169 |
+
st.metric("Facility", ref.facility_level.title())
|
| 1170 |
+
with r3:
|
| 1171 |
+
st.metric("Timeframe", ref.timeframe)
|
| 1172 |
+
st.warning(f"Reason: {ref.reason}")
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
# Footer
|
| 1176 |
+
def render_footer():
|
| 1177 |
+
"""Render footer."""
|
| 1178 |
+
st.markdown("---")
|
| 1179 |
+
st.markdown("""
|
| 1180 |
+
<div style="text-align: center; color: #666; font-size: 0.9rem;">
|
| 1181 |
+
<p>NEXUS - Built with Google HAI-DEF for MedGemma Impact Challenge 2026</p>
|
| 1182 |
+
<p>⚠️ This is a screening tool only. Always confirm with laboratory tests.</p>
|
| 1183 |
+
</div>
|
| 1184 |
+
""", unsafe_allow_html=True)
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
if __name__ == "__main__":
|
| 1188 |
+
main()
|
| 1189 |
+
render_footer()
|
src/nexus/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NEXUS - AI-Powered Maternal-Neonatal Care Platform
|
| 3 |
+
|
| 4 |
+
This package provides AI-powered diagnostic tools for:
|
| 5 |
+
- Maternal anemia detection via conjunctiva imaging
|
| 6 |
+
- Neonatal jaundice assessment via skin/sclera imaging
|
| 7 |
+
- Birth asphyxia screening via cry audio analysis
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
__version__ = "0.1.0"
|
src/nexus/agentic_workflow.py
ADDED
|
@@ -0,0 +1,1296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agentic Clinical Workflow Engine
|
| 3 |
+
|
| 4 |
+
Multi-agent system for comprehensive maternal-neonatal assessments.
|
| 5 |
+
Mirrors the TypeScript architecture in mobile/src/services/agenticWorkflow.ts
|
| 6 |
+
but adds structured reasoning traces for explainability.
|
| 7 |
+
|
| 8 |
+
6 Agents:
|
| 9 |
+
- TriageAgent: Initial danger sign screening (rules-based)
|
| 10 |
+
- ImageAnalysisAgent: MedSigLIP-powered anemia/jaundice detection
|
| 11 |
+
- AudioAnalysisAgent: HeAR-powered cry/asphyxia analysis
|
| 12 |
+
- ProtocolAgent: WHO IMNCI classification (rules-based)
|
| 13 |
+
- ReferralAgent: Urgency routing and referral decision (rules-based)
|
| 14 |
+
- SynthesisAgent: MedGemma clinical reasoning with full agent context
|
| 15 |
+
|
| 16 |
+
HAI-DEF Models Used:
|
| 17 |
+
- MedSigLIP (google/medsiglip-448) via ImageAnalysisAgent
|
| 18 |
+
- HeAR (google/hear-pytorch) via AudioAnalysisAgent
|
| 19 |
+
- MedGemma (google/medgemma-4b-it) via SynthesisAgent
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import time
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Data Types
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
PatientType = Literal["pregnant", "newborn"]
|
| 34 |
+
SeverityLevel = Literal["RED", "YELLOW", "GREEN"]
|
| 35 |
+
AgentStatus = Literal["success", "skipped", "error"]
|
| 36 |
+
WorkflowState = Literal[
|
| 37 |
+
"idle",
|
| 38 |
+
"triaging",
|
| 39 |
+
"analyzing_image",
|
| 40 |
+
"analyzing_audio",
|
| 41 |
+
"applying_protocol",
|
| 42 |
+
"determining_referral",
|
| 43 |
+
"synthesizing",
|
| 44 |
+
"complete",
|
| 45 |
+
"error",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class DangerSign:
|
| 51 |
+
"""A clinical danger sign observed during triage."""
|
| 52 |
+
id: str
|
| 53 |
+
label: str
|
| 54 |
+
severity: Literal["critical", "high", "medium"]
|
| 55 |
+
present: bool = False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class AgentPatientInfo:
|
| 60 |
+
"""Patient information for workflow context."""
|
| 61 |
+
patient_id: str = ""
|
| 62 |
+
patient_type: PatientType = "newborn"
|
| 63 |
+
gestational_weeks: Optional[int] = None
|
| 64 |
+
gravida: Optional[int] = None
|
| 65 |
+
para: Optional[int] = None
|
| 66 |
+
age_hours: Optional[int] = None
|
| 67 |
+
birth_weight: Optional[int] = None
|
| 68 |
+
delivery_type: Optional[str] = None
|
| 69 |
+
apgar_score: Optional[int] = None
|
| 70 |
+
gestational_age_at_birth: Optional[int] = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
class AgentResult:
|
| 75 |
+
"""Structured output from a single agent with reasoning trace."""
|
| 76 |
+
agent_name: str
|
| 77 |
+
status: AgentStatus
|
| 78 |
+
reasoning: List[str] = field(default_factory=list)
|
| 79 |
+
findings: Dict[str, Any] = field(default_factory=dict)
|
| 80 |
+
confidence: float = 0.0
|
| 81 |
+
processing_time_ms: float = 0.0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
|
| 85 |
+
class TriageResult:
|
| 86 |
+
"""Output from TriageAgent."""
|
| 87 |
+
risk_level: SeverityLevel = "GREEN"
|
| 88 |
+
critical_signs_detected: bool = False
|
| 89 |
+
critical_signs: List[str] = field(default_factory=list)
|
| 90 |
+
immediate_referral_needed: bool = False
|
| 91 |
+
score: int = 0
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class ImageAnalysisResult:
|
| 96 |
+
"""Output from ImageAnalysisAgent."""
|
| 97 |
+
anemia: Optional[Dict[str, Any]] = None
|
| 98 |
+
jaundice: Optional[Dict[str, Any]] = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
|
| 102 |
+
class AudioAnalysisResult:
|
| 103 |
+
"""Output from AudioAnalysisAgent."""
|
| 104 |
+
cry: Optional[Dict[str, Any]] = None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass
|
| 108 |
+
class ProtocolResult:
|
| 109 |
+
"""Output from ProtocolAgent."""
|
| 110 |
+
classification: SeverityLevel = "GREEN"
|
| 111 |
+
applicable_protocols: List[str] = field(default_factory=list)
|
| 112 |
+
treatment_recommendations: List[str] = field(default_factory=list)
|
| 113 |
+
follow_up_schedule: str = ""
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass
|
| 117 |
+
class ReferralResult:
|
| 118 |
+
"""Output from ReferralAgent."""
|
| 119 |
+
referral_needed: bool = False
|
| 120 |
+
urgency: Literal["immediate", "urgent", "routine", "none"] = "none"
|
| 121 |
+
facility_level: Literal["primary", "secondary", "tertiary"] = "primary"
|
| 122 |
+
reason: str = "No referral required"
|
| 123 |
+
timeframe: str = "Not applicable"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@dataclass
|
| 127 |
+
class WorkflowInput:
|
| 128 |
+
"""Input to the agentic workflow."""
|
| 129 |
+
patient_type: PatientType
|
| 130 |
+
patient_info: AgentPatientInfo = field(default_factory=AgentPatientInfo)
|
| 131 |
+
danger_signs: List[DangerSign] = field(default_factory=list)
|
| 132 |
+
conjunctiva_image: Optional[Union[str, Path]] = None
|
| 133 |
+
skin_image: Optional[Union[str, Path]] = None
|
| 134 |
+
cry_audio: Optional[Union[str, Path]] = None
|
| 135 |
+
additional_notes: str = ""
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
|
| 139 |
+
class WorkflowResult:
|
| 140 |
+
"""Complete workflow output with all agent results and audit trail."""
|
| 141 |
+
success: bool = False
|
| 142 |
+
patient_type: PatientType = "newborn"
|
| 143 |
+
who_classification: SeverityLevel = "GREEN"
|
| 144 |
+
|
| 145 |
+
# Individual agent outputs
|
| 146 |
+
triage_result: Optional[TriageResult] = None
|
| 147 |
+
image_results: Optional[ImageAnalysisResult] = None
|
| 148 |
+
audio_results: Optional[AudioAnalysisResult] = None
|
| 149 |
+
protocol_result: Optional[ProtocolResult] = None
|
| 150 |
+
referral_result: Optional[ReferralResult] = None
|
| 151 |
+
|
| 152 |
+
# Synthesis
|
| 153 |
+
clinical_synthesis: str = ""
|
| 154 |
+
recommendation: str = ""
|
| 155 |
+
immediate_actions: List[str] = field(default_factory=list)
|
| 156 |
+
|
| 157 |
+
# Audit trail
|
| 158 |
+
agent_traces: List[AgentResult] = field(default_factory=list)
|
| 159 |
+
processing_time_ms: float = 0.0
|
| 160 |
+
timestamp: str = ""
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Individual Agents
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
class TriageAgent:
    """
    Initial risk stratification based on danger signs, patient info, and
    clinical decision tree logic.

    Decision tree considers:
    - Danger sign severity and combinations
    - Patient demographics (age, weight, gestational age)
    - Comorbidity patterns (multiple conditions increase risk)
    - Time-sensitive factors (e.g., jaundice < 24hrs = always RED)
    """

    def process(
        self,
        patient_type: PatientType,
        danger_signs: List[DangerSign],
        patient_info: AgentPatientInfo,
    ) -> tuple[TriageResult, AgentResult]:
        """Score the patient and map the score to RED/YELLOW/GREEN.

        The score is additive: each present danger sign contributes by
        severity (critical +30, high +15, medium +5), >=3 signs add a
        comorbidity bonus, and demographic factors add per the branches
        below. score >= 30 or any critical sign -> RED; score >= 15 ->
        YELLOW; otherwise GREEN.

        Returns:
            (TriageResult, AgentResult): the structured triage outcome and
            an audit trace whose ``reasoning`` lists every scoring step.
        """
        start = time.time()
        reasoning: List[str] = []
        score = 0  # additive risk score; thresholds below map it to a level
        critical_signs: List[str] = []
        risk_modifiers: List[str] = []

        reasoning.append(f"[STEP 1/5] Initiating clinical triage for {patient_type} patient")

        # Step 1: Evaluate danger signs with clinical context
        present_signs = [s for s in danger_signs if s.present]
        reasoning.append(f"[STEP 2/5] Evaluating {len(present_signs)} present danger signs out of {len(danger_signs)} assessed")

        for sign in present_signs:
            if sign.severity == "critical":
                score += 30
                critical_signs.append(sign.label)
                reasoning.append(f" CRITICAL: '{sign.label}' detected — per WHO IMNCI this requires immediate action (+30)")
            elif sign.severity == "high":
                score += 15
                reasoning.append(f" HIGH: '{sign.label}' detected — warrants close monitoring (+15)")
            elif sign.severity == "medium":
                score += 5
                reasoning.append(f" MEDIUM: '{sign.label}' detected — noted for assessment (+5)")
            # NOTE(review): other severity values (e.g. "low") score 0 by design.

        # Comorbidity check: multiple conditions compound risk
        if len(present_signs) >= 3:
            combo_bonus = 10
            score += combo_bonus
            risk_modifiers.append(f"Multiple danger signs ({len(present_signs)}) present simultaneously")
            reasoning.append(f" COMORBIDITY: {len(present_signs)} danger signs present — compounding risk (+{combo_bonus})")

        # Step 2: Patient-specific demographic risk assessment
        reasoning.append(f"[STEP 3/5] Assessing demographic risk factors")

        if patient_type == "pregnant":
            # Gestational-age risk bands: <28 extreme preterm, 28-36 preterm,
            # >42 post-term; 37-42 is treated as normal.
            if patient_info.gestational_weeks is not None:
                ga = patient_info.gestational_weeks
                if ga < 28:
                    score += 15
                    risk_modifiers.append(f"Extreme preterm ({ga} weeks)")
                    reasoning.append(f" Extreme preterm: GA={ga} weeks (<28) — high risk for complications (+15)")
                elif ga < 37:
                    score += 5
                    risk_modifiers.append(f"Preterm ({ga} weeks)")
                    reasoning.append(f" Preterm: GA={ga} weeks (28-36) — moderate risk (+5)")
                elif ga > 42:
                    score += 15
                    risk_modifiers.append(f"Post-term ({ga} weeks)")
                    reasoning.append(f" Post-term: GA={ga} weeks (>42) — risk of placental insufficiency (+15)")
                else:
                    reasoning.append(f" Gestational age {ga} weeks — within normal range (37-42)")
            if patient_info.gravida is not None and patient_info.gravida >= 5:
                score += 5
                risk_modifiers.append(f"Grand multigravida (G{patient_info.gravida})")
                reasoning.append(f" Grand multigravida: G{patient_info.gravida} — increased obstetric risk (+5)")

        elif patient_type == "newborn":
            # Birth weight (grams): <1500 very low, <2500 low.
            if patient_info.birth_weight is not None:
                bw = patient_info.birth_weight
                if bw < 1500:
                    score += 20
                    risk_modifiers.append(f"Very low birth weight ({bw}g)")
                    reasoning.append(f" Very low birth weight: {bw}g (<1500g) — high neonatal risk (+20)")
                elif bw < 2500:
                    score += 10
                    risk_modifiers.append(f"Low birth weight ({bw}g)")
                    reasoning.append(f" Low birth weight: {bw}g (<2500g) — moderate risk (+10)")
                else:
                    reasoning.append(f" Birth weight {bw}g — within normal range")

            # APGAR: <4 severe depression, <7 moderate.
            if patient_info.apgar_score is not None:
                apgar = patient_info.apgar_score
                if apgar < 4:
                    score += 25
                    risk_modifiers.append(f"Severe depression (APGAR {apgar})")
                    reasoning.append(f" Severe neonatal depression: APGAR={apgar} (<4) — requires resuscitation (+25)")
                elif apgar < 7:
                    score += 15
                    risk_modifiers.append(f"Moderate depression (APGAR {apgar})")
                    reasoning.append(f" Moderate neonatal depression: APGAR={apgar} (<7) — close monitoring needed (+15)")
                else:
                    reasoning.append(f" APGAR score {apgar} — within normal range")

            # Postnatal age: first 6 hours are the highest-vulnerability window.
            if patient_info.age_hours is not None:
                age = patient_info.age_hours
                if age < 6:
                    score += 10
                    risk_modifiers.append(f"Critical neonatal period ({age}h)")
                    reasoning.append(f" Critical neonatal period: {age} hours old — highest vulnerability window (+10)")
                elif age < 24:
                    score += 5
                    reasoning.append(f" First day of life: {age} hours — increased monitoring needed (+5)")

            if patient_info.gestational_age_at_birth is not None and patient_info.gestational_age_at_birth < 37:
                score += 10
                risk_modifiers.append(f"Premature birth ({patient_info.gestational_age_at_birth} weeks)")
                reasoning.append(f" Premature birth at {patient_info.gestational_age_at_birth} weeks — increased susceptibility (+10)")

        # Step 3: Clinical decision tree — any critical sign forces RED
        # regardless of total score.
        reasoning.append(f"[STEP 4/5] Applying clinical decision tree")

        if score >= 30 or len(critical_signs) > 0:
            risk_level: SeverityLevel = "RED"
            reasoning.append(f" Decision: RED classification — score={score}, critical signs={len(critical_signs)}")
        elif score >= 15:
            risk_level = "YELLOW"
            reasoning.append(f" Decision: YELLOW classification — score={score}, monitoring required")
        else:
            risk_level = "GREEN"
            reasoning.append(f" Decision: GREEN classification — score={score}, routine care")

        critical_detected = len(critical_signs) > 0
        # Immediate referral only when RED *and* at least one critical sign;
        # a score-only RED is "urgent" rather than "immediate".
        immediate_referral = risk_level == "RED" and critical_detected

        # Step 4: Summary with clinical rationale
        reasoning.append(f"[STEP 5/5] Triage conclusion")
        reasoning.append(f" Total triage score: {score}")
        reasoning.append(f" Risk classification: {risk_level} ({self._risk_rationale(risk_level)})")
        if risk_modifiers:
            reasoning.append(f" Risk modifiers: {'; '.join(risk_modifiers)}")
        if immediate_referral:
            reasoning.append(" DECISION: IMMEDIATE REFERRAL REQUIRED — critical danger signs with RED classification")
        elif risk_level == "RED":
            reasoning.append(" DECISION: URGENT referral recommended — RED classification without critical signs")

        elapsed = (time.time() - start) * 1000

        result = TriageResult(
            risk_level=risk_level,
            critical_signs_detected=critical_detected,
            critical_signs=critical_signs,
            immediate_referral_needed=immediate_referral,
            score=score,
        )

        # Confidence is fixed at 1.0: this agent is deterministic rule logic,
        # not a model prediction.
        trace = AgentResult(
            agent_name="TriageAgent",
            status="success",
            reasoning=reasoning,
            findings={
                "risk_level": risk_level,
                "score": score,
                "critical_signs": critical_signs,
                "risk_modifiers": risk_modifiers,
                "immediate_referral": immediate_referral,
            },
            confidence=1.0,
            processing_time_ms=elapsed,
        )

        return result, trace

    @staticmethod
    def _risk_rationale(level: str) -> str:
        """Return a one-line clinical rationale for a RED/YELLOW/GREEN level.

        Unknown levels map to the empty string.
        """
        return {
            "RED": "immediate intervention required per WHO IMNCI",
            "YELLOW": "close monitoring with 24-48h follow-up",
            "GREEN": "routine care with standard follow-up schedule",
        }.get(level, "")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class ImageAnalysisAgent:
    """
    Visual analysis using MedSigLIP for anemia and jaundice detection.

    HAI-DEF Model: MedSigLIP (google/medsiglip-448)
    Reuses existing AnemiaDetector and JaundiceDetector instances.
    """

    def __init__(
        self,
        anemia_detector: Optional[Any] = None,
        jaundice_detector: Optional[Any] = None,
    ):
        # Detectors may be injected (e.g. pre-loaded, shared instances) or
        # left as None to be constructed lazily on first use.
        self._anemia_detector = anemia_detector
        self._jaundice_detector = jaundice_detector

    def _get_anemia_detector(self) -> Any:
        """Lazily construct and cache the AnemiaDetector on first use."""
        if self._anemia_detector is None:
            # Imported here to defer the (model-loading) cost until needed.
            from .anemia_detector import AnemiaDetector
            self._anemia_detector = AnemiaDetector()
        return self._anemia_detector

    def _get_jaundice_detector(self) -> Any:
        """Lazily construct and cache the JaundiceDetector on first use."""
        if self._jaundice_detector is None:
            # Imported here to defer the (model-loading) cost until needed.
            from .jaundice_detector import JaundiceDetector
            self._jaundice_detector = JaundiceDetector()
        return self._jaundice_detector

    def process(
        self,
        patient_type: PatientType,
        conjunctiva_image: Optional[Union[str, Path]] = None,
        skin_image: Optional[Union[str, Path]] = None,
    ) -> tuple[ImageAnalysisResult, AgentResult]:
        """Run anemia and/or jaundice detection on the provided images.

        Each analysis runs only when its image path is provided; a missing
        image simply skips that branch. Detector exceptions are caught and
        converted into placeholder error dicts so the workflow can continue.

        Returns:
            (ImageAnalysisResult, AgentResult): per-condition result dicts
            plus an audit trace. Trace confidence is the mean of the
            per-analysis confidences (0.0 when nothing ran).
        """
        start = time.time()
        reasoning: List[str] = []
        result = ImageAnalysisResult()
        confidence_scores: List[float] = []

        reasoning.append(f"Starting image analysis for {patient_type} patient")

        # Anemia screening (both maternal and newborn)
        if conjunctiva_image:
            reasoning.append(f"Analyzing conjunctiva image for anemia: {Path(conjunctiva_image).name}")
            try:
                detector = self._get_anemia_detector()
                anemia_result = detector.detect(conjunctiva_image)
                result.anemia = anemia_result
                conf = anemia_result.get("confidence", 0)
                confidence_scores.append(conf)

                if anemia_result.get("is_anemic"):
                    reasoning.append(
                        f"ANEMIA DETECTED: confidence={conf:.1%}, "
                        f"risk_level={anemia_result.get('risk_level', 'unknown')}"
                    )
                else:
                    reasoning.append(f"No anemia detected (confidence={conf:.1%})")

                reasoning.append(f"Model used: {anemia_result.get('model', 'MedSigLIP')}")
            except Exception as e:
                # Fail soft: record a placeholder dict (model="error") so the
                # downstream protocol agent sees a well-formed structure.
                reasoning.append(f"Anemia analysis failed: {e}")
                result.anemia = {
                    "is_anemic": False,
                    "confidence": 0.0,
                    "risk_level": "low",
                    "recommendation": "Analysis failed - please retry",
                    "anemia_score": 0.0,
                    "healthy_score": 0.0,
                    "model": "error",
                }
        else:
            reasoning.append("No conjunctiva image provided - skipping anemia screening")

        # Jaundice detection (newborn or if skin image provided)
        if skin_image:
            reasoning.append(f"Analyzing skin image for jaundice: {Path(skin_image).name}")
            try:
                detector = self._get_jaundice_detector()
                jaundice_result = detector.detect(skin_image)
                result.jaundice = jaundice_result
                conf = jaundice_result.get("confidence", 0)
                confidence_scores.append(conf)

                if jaundice_result.get("has_jaundice"):
                    reasoning.append(
                        f"JAUNDICE DETECTED: severity={jaundice_result.get('severity', 'unknown')}, "
                        f"estimated bilirubin={jaundice_result.get('estimated_bilirubin', 'N/A')} mg/dL, "
                        f"phototherapy={'needed' if jaundice_result.get('needs_phototherapy') else 'not needed'}"
                    )
                else:
                    reasoning.append(f"No significant jaundice detected (confidence={conf:.1%})")

                reasoning.append(f"Model used: {jaundice_result.get('model', 'MedSigLIP')}")
            except Exception as e:
                # Fail soft, mirroring the anemia error placeholder above.
                reasoning.append(f"Jaundice analysis failed: {e}")
                result.jaundice = {
                    "has_jaundice": False,
                    "confidence": 0.0,
                    "severity": "none",
                    "estimated_bilirubin": 0.0,
                    "needs_phototherapy": False,
                    "recommendation": "Analysis failed - please retry",
                    "model": "error",
                }
        else:
            reasoning.append("No skin image provided - skipping jaundice detection")

        # NOTE(review): a failed analysis still populates result.anemia /
        # result.jaundice with an error dict, so has_findings is True and the
        # trace status reads "success" even after an exception — the only
        # failure signal is model="error" and confidence 0.0. Confirm whether
        # downstream consumers rely on this before changing it.
        has_findings = result.anemia is not None or result.jaundice is not None
        elapsed = (time.time() - start) * 1000
        avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.0

        trace = AgentResult(
            agent_name="ImageAnalysisAgent",
            status="success" if has_findings else "skipped",
            reasoning=reasoning,
            findings={
                "anemia_detected": result.anemia.get("is_anemic", False) if result.anemia else None,
                "jaundice_detected": result.jaundice.get("has_jaundice", False) if result.jaundice else None,
            },
            confidence=avg_confidence,
            processing_time_ms=elapsed,
        )

        return result, trace
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
class AudioAnalysisAgent:
    """
    Acoustic analysis using HeAR for cry pattern and asphyxia detection.

    HAI-DEF Model: HeAR (google/hear-pytorch)
    Reuses existing CryAnalyzer instance.
    """

    def __init__(self, cry_analyzer: Optional[Any] = None):
        # Analyzer may be injected (shared, pre-loaded) or lazily built.
        self._cry_analyzer = cry_analyzer

    def _get_cry_analyzer(self) -> Any:
        """Lazily construct and cache the CryAnalyzer on first use."""
        if self._cry_analyzer is None:
            # Imported here to defer the (model-loading) cost until needed.
            from .cry_analyzer import CryAnalyzer
            self._cry_analyzer = CryAnalyzer()
        return self._cry_analyzer

    def process(
        self,
        cry_audio: Optional[Union[str, Path]] = None,
    ) -> tuple[AudioAnalysisResult, AgentResult]:
        """Analyze a newborn cry recording for abnormal patterns.

        Returns:
            (AudioAnalysisResult, AgentResult): the cry result dict plus an
            audit trace. Trace status is "skipped" when no audio is given,
            "error" when the analyzer raised, otherwise "success".

        Bug fix vs. previous version: the except path used to leave status
        at "success" (the error placeholder dict made ``result.cry`` truthy,
        so ``"success" if result.cry else "error"`` could never yield
        "error"); failure is now tracked explicitly.
        """
        start = time.time()
        reasoning: List[str] = []
        result = AudioAnalysisResult()

        if not cry_audio:
            reasoning.append("No cry audio provided - skipping audio analysis")
            elapsed = (time.time() - start) * 1000
            trace = AgentResult(
                agent_name="AudioAnalysisAgent",
                status="skipped",
                reasoning=reasoning,
                findings={},
                confidence=0.0,
                processing_time_ms=elapsed,
            )
            return result, trace

        reasoning.append(f"Analyzing cry audio: {Path(cry_audio).name}")

        analysis_failed = False  # set True when the analyzer raises
        try:
            analyzer = self._get_cry_analyzer()
            cry_result = analyzer.analyze(cry_audio)
            result.cry = cry_result

            risk = cry_result.get("asphyxia_risk", 0)
            reasoning.append(f"Model used: {cry_result.get('model', 'HeAR')}")
            reasoning.append(f"Cry type detected: {cry_result.get('cry_type', 'unknown')}")
            reasoning.append(f"Asphyxia risk score: {risk:.1%}")

            features = cry_result.get("features", {})
            if features:
                reasoning.append(
                    f"Acoustic features: F0={features.get('f0_mean', 0):.0f}Hz, "
                    f"duration={features.get('duration', 0):.1f}s, "
                    f"voiced_ratio={features.get('voiced_ratio', 0):.2f}"
                )

            if cry_result.get("is_abnormal"):
                reasoning.append(
                    f"ABNORMAL CRY PATTERN: risk_level={cry_result.get('risk_level', 'unknown')}"
                )
            else:
                reasoning.append("Normal cry pattern detected")

            # Higher confidence when risk score is far from 0.5 (clear result),
            # clamped to [0.5, 1.0].
            confidence = 0.5 + abs(risk - 0.5)
            confidence = max(0.5, min(1.0, confidence))

        except Exception as e:
            # Fail soft: record a placeholder dict so downstream consumers
            # see a well-formed structure, but mark the trace as an error.
            analysis_failed = True
            reasoning.append(f"Cry analysis failed: {e}")
            result.cry = {
                "is_abnormal": False,
                "asphyxia_risk": 0.0,
                "cry_type": "unknown",
                "risk_level": "low",
                "recommendation": "Analysis failed - please retry",
                "features": {},
                "model": "error",
            }
            confidence = 0.0

        elapsed = (time.time() - start) * 1000

        trace = AgentResult(
            agent_name="AudioAnalysisAgent",
            status="error" if analysis_failed else "success",
            reasoning=reasoning,
            findings={
                "is_abnormal": result.cry.get("is_abnormal", False) if result.cry else None,
                "asphyxia_risk": result.cry.get("asphyxia_risk", 0) if result.cry else None,
            },
            confidence=confidence,
            processing_time_ms=elapsed,
        )

        return result, trace
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class ProtocolAgent:
    """
    Applies WHO IMNCI guidelines with clinical reasoning for severity
    classification and evidence-based treatment recommendations.

    Reasoning process:
    1. Evaluate each condition against WHO IMNCI thresholds
    2. Check for protocol conflicts (e.g., anemia + jaundice comorbidity)
    3. Apply condition-specific treatment algorithms
    4. Generate time-bound follow-up schedule
    """

    def process(
        self,
        patient_type: PatientType,
        triage: TriageResult,
        image: ImageAnalysisResult,
        audio: Optional[AudioAnalysisResult] = None,
    ) -> tuple[ProtocolResult, AgentResult]:
        """Map the detection results onto WHO IMNCI protocols.

        Starts from the triage risk level and only ever escalates it
        (YELLOW/RED); no branch below downgrades an existing RED.

        Returns:
            (ProtocolResult, AgentResult): the classification, applied
            protocols, treatment recommendations, and follow-up schedule,
            plus an audit trace with step-by-step reasoning.

        Cleanup vs. previous version: removed dead code in the jaundice
        branch (an unused ``age_hours`` variable, a no-op ``hasattr`` check,
        and a no-op ``getattr(self, '_patient_info', ...)`` walrus) that
        falsely implied age-specific thresholds were being computed.
        """
        start = time.time()
        reasoning: List[str] = []
        protocols: List[str] = []
        recommendations: List[str] = []
        classification: SeverityLevel = triage.risk_level
        conditions_found: List[str] = []

        reasoning.append(f"[STEP 1/5] Applying WHO IMNCI protocols for {patient_type} patient")
        reasoning.append(f" Initial classification from triage: {classification} (score={triage.score})")

        # ---- Maternal protocols ----
        if patient_type == "pregnant":
            protocols.append("WHO IMNCI Maternal Care")
            reasoning.append(f"[STEP 2/5] Evaluating maternal conditions")

            if image.anemia and image.anemia.get("is_anemic"):
                protocols.append("Anemia Management Protocol")
                conditions_found.append("anemia")
                est_hb = image.anemia.get("estimated_hemoglobin", 0)
                risk_level = image.anemia.get("risk_level", "unknown")

                reasoning.append(f" Anemia detected: risk={risk_level}, est. Hb={est_hb} g/dL")

                # WHO thresholds: pregnant women Hb<11 = anemia, Hb<7 = severe
                # (Non-pregnant women Hb<12; neonates vary by age)
                severe_threshold = 7.0
                moderate_threshold = 11.0
                reasoning.append(f" Using WHO maternal thresholds: severe<{severe_threshold}, moderate<{moderate_threshold} g/dL")

                # est_hb may be 0/None when the detector gave no estimate;
                # the truthiness check skips the numeric thresholds then.
                if est_hb and est_hb < severe_threshold:
                    classification = "RED"
                    recommendations.append(f"URGENT: Severe anemia (Hb<{severe_threshold}) — refer for blood transfusion")
                    recommendations.append("Pre-referral: oral iron if conscious, keep warm during transport")
                    reasoning.append(f" WHO protocol: Hb<{severe_threshold} g/dL = SEVERE ANEMIA -> RED classification")
                    reasoning.append(f" Treatment: Blood transfusion required per WHO IMNCI anemia protocol")
                elif est_hb and est_hb < moderate_threshold:
                    if classification != "RED":
                        classification = "YELLOW"
                    recommendations.append("Initiate iron supplementation (60mg elemental iron + 400mcg folic acid daily)")
                    recommendations.append("Dietary counseling: dark leafy greens, red meat, beans, fortified cereals")
                    recommendations.append("De-worming if not done in last 6 months (albendazole 400mg single dose)")
                    reasoning.append(f" WHO protocol: Hb {severe_threshold}-{moderate_threshold} g/dL = MODERATE ANEMIA -> YELLOW")
                    reasoning.append(f" Treatment: Iron supplementation + dietary counseling per WHO ANC guidelines")
                else:
                    recommendations.append("Monitor hemoglobin levels, encourage iron-rich diet")
                    reasoning.append(f" Mild anemia or screening positive — continue monitoring")

            if triage.critical_signs_detected:
                protocols.append("Emergency Obstetric Care Protocol")
                recommendations.append("Immediate assessment for emergency obstetric conditions")
                reasoning.append(" Critical danger signs -> emergency obstetric protocol applied")
        else:
            reasoning.append(f"[STEP 2/5] Patient is newborn — skipping maternal protocols")

        # ---- Newborn protocols ----
        if patient_type == "newborn":
            protocols.append("WHO IMNCI Newborn Care")
            reasoning.append(f"[STEP 3/5] Evaluating neonatal conditions")

            # Jaundice — with age-specific AAP/WHO thresholds
            if image.jaundice and image.jaundice.get("has_jaundice"):
                protocols.append("Neonatal Jaundice Protocol")
                conditions_found.append("jaundice")
                est_bili = image.jaundice.get("estimated_bilirubin", 0)
                est_bili_ml = image.jaundice.get("estimated_bilirubin_ml")
                severity = image.jaundice.get("severity", "unknown")
                # Prefer the ML regression estimate when available; fall back
                # to the color-analysis estimate otherwise.
                bili_value = est_bili_ml if est_bili_ml is not None else est_bili

                reasoning.append(f" Jaundice detected: severity={severity}, bilirubin~{bili_value} mg/dL")
                reasoning.append(f" Bilirubin method: {image.jaundice.get('bilirubin_method', 'color analysis')}")

                # Age-specific phototherapy thresholds (AAP 2004 / WHO)
                # For low-risk term newborns (>= 38 weeks):
                #   Age(h)   Phototherapy   Exchange
                #   24       12             19
                #   48       15             22
                #   72       18             24
                #   96+      20             25
                # Patient age is not available to this agent, so we use the
                # most conservative (>96h) column as fixed defaults.
                photo_threshold = 20.0  # default (>96h)
                exchange_threshold = 25.0
                reasoning.append(f" Using phototherapy threshold={photo_threshold} mg/dL, exchange={exchange_threshold} mg/dL")

                if bili_value and bili_value > exchange_threshold:
                    classification = "RED"
                    recommendations.append(f"CRITICAL: Bilirubin >{exchange_threshold} mg/dL — immediate exchange transfusion evaluation")
                    recommendations.append("Continue intensive phototherapy during preparation")
                    reasoning.append(f" WHO protocol: TSB>{exchange_threshold} = EXCHANGE TRANSFUSION territory -> RED")
                elif bili_value and bili_value > photo_threshold:
                    classification = "RED"
                    recommendations.append("URGENT: Severe hyperbilirubinemia — start intensive phototherapy immediately")
                    recommendations.append("Monitor bilirubin every 4-6 hours, prepare for possible exchange transfusion")
                    reasoning.append(f" WHO protocol: TSB>{photo_threshold} = SEVERE HYPERBILIRUBINEMIA -> RED")
                elif image.jaundice.get("needs_phototherapy"):
                    # Detector's own (age-aware) phototherapy flag.
                    if classification != "RED":
                        classification = "YELLOW"
                    recommendations.append("Initiate phototherapy (standard irradiance)")
                    recommendations.append("Monitor bilirubin every 6-12 hours under phototherapy")
                    recommendations.append("Ensure adequate breastfeeding (8-12 feeds per day)")
                    reasoning.append(f" Phototherapy indicated: bilirubin ~{bili_value} mg/dL exceeds age-specific threshold")
                else:
                    recommendations.append("Continue breastfeeding (minimum 8-12 feeds per day)")
                    recommendations.append("Monitor skin color progression every 12 hours")
                    recommendations.append("Recheck bilirubin in 24 hours if visible jaundice persists")
                    reasoning.append(f" Mild jaundice ({bili_value} mg/dL) — monitoring and breastfeeding")

            # Cry / asphyxia
            if audio and audio.cry and audio.cry.get("is_abnormal"):
                protocols.append("Birth Asphyxia Assessment Protocol")
                conditions_found.append("abnormal_cry")
                asphyxia_risk = audio.cry.get("asphyxia_risk", 0)
                cry_type = audio.cry.get("cry_type", "unknown")

                reasoning.append(f" Abnormal cry: type={cry_type}, asphyxia_risk={asphyxia_risk:.1%}")

                if asphyxia_risk > 0.7:
                    classification = "RED"
                    recommendations.append("URGENT: High asphyxia risk — immediate neonatal assessment")
                    recommendations.append("Check airway, breathing, circulation (ABC)")
                    recommendations.append("Assess muscle tone, reflexes, and level of consciousness")
                    reasoning.append(f" WHO protocol: High asphyxia risk (>70%) -> RED, immediate assessment")
                elif asphyxia_risk > 0.4:
                    if classification != "RED":
                        classification = "YELLOW"
                    recommendations.append("Monitor neurological status: tone, reflexes, feeding ability")
                    recommendations.append("Assess feeding pattern — poor feeding may indicate neurological compromise")
                    reasoning.append(f" Moderate asphyxia risk ({asphyxia_risk:.1%}) -> YELLOW, close monitoring")
                else:
                    reasoning.append(f" Low asphyxia risk ({asphyxia_risk:.1%}) — documented but not concerning")

            # Neonatal anemia
            if image.anemia and image.anemia.get("is_anemic"):
                protocols.append("Neonatal Anemia Protocol")
                conditions_found.append("neonatal_anemia")
                recommendations.append("Check hematocrit and reticulocyte count")
                recommendations.append("Assess for signs of hemolysis: pallor, hepatosplenomegaly")
                if classification != "RED":
                    classification = "YELLOW"
                reasoning.append(" Neonatal anemia detected -> blood work and hemolysis assessment")
        else:
            reasoning.append(f"[STEP 3/5] Patient is pregnant — skipping neonatal protocols")

        # Step 4: Comorbidity analysis and protocol conflict resolution
        reasoning.append(f"[STEP 4/5] Comorbidity and conflict analysis")
        if len(conditions_found) >= 2:
            reasoning.append(f" Multiple conditions detected: {', '.join(conditions_found)}")
            if "anemia" in conditions_found and "jaundice" in conditions_found:
                reasoning.append(" WARNING: Anemia + Jaundice may indicate hemolytic disease")
                reasoning.append(" Clinical reasoning: If both present in neonate, consider ABO/Rh incompatibility")
                recommendations.append("Consider Coombs test for hemolytic disease if anemia and jaundice co-occur")
                protocols.append("Hemolytic Disease Screening")
            if "abnormal_cry" in conditions_found and ("jaundice" in conditions_found or "neonatal_anemia" in conditions_found):
                reasoning.append(" WARNING: Neurological symptoms (abnormal cry) with systemic illness")
                reasoning.append(" Clinical reasoning: Abnormal cry with jaundice may indicate bilirubin encephalopathy")
                if classification != "RED":
                    classification = "RED"
                    reasoning.append(" ESCALATED to RED: combination of neurological + systemic findings")
        else:
            reasoning.append(f" Single condition or no conditions — no comorbidity conflicts")

        # Step 5: Follow-up schedule
        reasoning.append(f"[STEP 5/5] Determining follow-up schedule")

        if classification == "RED":
            follow_up = "Immediate referral — reassess after higher-level care"
            reasoning.append(f" RED: Immediate referral required, no outpatient follow-up")
        elif classification == "YELLOW":
            follow_up = "Follow-up in 2-3 days, or immediately if condition worsens"
            reasoning.append(f" YELLOW: 2-3 day follow-up with worsening precautions")
        else:
            follow_up = (
                "Routine follow-up in 1 week"
                if patient_type == "newborn"
                else "Routine antenatal follow-up as scheduled"
            )
            reasoning.append(f" GREEN: Routine follow-up — {follow_up}")

        reasoning.append(f" Final WHO IMNCI classification: {classification}")
        reasoning.append(f" Protocols applied ({len(protocols)}): {', '.join(protocols)}")

        elapsed = (time.time() - start) * 1000

        result = ProtocolResult(
            classification=classification,
            applicable_protocols=protocols,
            treatment_recommendations=recommendations,
            follow_up_schedule=follow_up,
        )

        # Confidence is fixed at 1.0: deterministic guideline logic, not a
        # model prediction.
        trace = AgentResult(
            agent_name="ProtocolAgent",
            status="success",
            reasoning=reasoning,
            findings={
                "classification": classification,
                "protocols_count": len(protocols),
                "recommendations_count": len(recommendations),
                "conditions_found": conditions_found,
            },
            confidence=1.0,
            processing_time_ms=elapsed,
        )

        return result, trace
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class ReferralAgent:
    """
    Clinical referral decision agent with structured reasoning.

    Considers:
    - Triage severity and critical danger signs
    - Protocol classification and specific condition thresholds
    - Facility capability requirements (phototherapy, transfusion, NICU)
    - Transport safety and pre-referral treatment
    - Generates structured referral note for receiving facility
    """

    def process(
        self,
        patient_type: PatientType,
        triage: TriageResult,
        protocol: ProtocolResult,
        image: ImageAnalysisResult,
        audio: Optional[AudioAnalysisResult] = None,
    ) -> tuple[ReferralResult, AgentResult]:
        """Decide whether, how urgently, and to which facility level to refer.

        Args:
            patient_type: "pregnant" or "newborn" — selects condition-specific rules.
            triage: Upstream triage output (danger signs, immediate-referral flag).
            protocol: WHO IMNCI classification and recommendations.
            image: Anemia/jaundice image findings (dict payloads, may be None).
            audio: Optional cry analysis (asphyxia risk), newborns only.

        Returns:
            (ReferralResult, AgentResult) — the decision plus a reasoning trace.
        """
        start = time.time()
        reasoning: List[str] = []
        # Decision state; rules below only ever escalate urgency/facility,
        # never downgrade (except the YELLOW no-trigger case in step 3).
        referral_needed = False
        urgency: Literal["immediate", "urgent", "routine", "none"] = "none"
        facility_level: Literal["primary", "secondary", "tertiary"] = "primary"
        reasons: List[str] = []
        pre_referral_actions: List[str] = []
        capabilities_needed: List[str] = []

        reasoning.append(f"[STEP 1/4] Evaluating referral necessity for {patient_type} patient")

        # Step 1: Evaluate critical/immediate triggers
        if triage.immediate_referral_needed:
            referral_needed = True
            urgency = "immediate"
            facility_level = "tertiary"
            reasons.append(f"Critical danger signs: {', '.join(triage.critical_signs)}")
            capabilities_needed.append("Emergency care")
            reasoning.append(f" TRIGGER: Critical danger signs ({', '.join(triage.critical_signs)}) -> IMMEDIATE referral to tertiary")

        # Step 2: Protocol-driven referral assessment
        reasoning.append(f"[STEP 2/4] Assessing condition-specific referral criteria")

        if protocol.classification == "RED":
            referral_needed = True
            # Guard: do not downgrade an already-immediate referral.
            if urgency != "immediate":
                urgency = "urgent"
            if facility_level == "primary":
                facility_level = "secondary"
            reasoning.append(f" RED classification -> referral required (minimum: urgent to secondary)")

        # Condition-specific evaluation with facility capability matching
        if patient_type == "pregnant":
            if image.anemia and image.anemia.get("is_anemic"):
                # Default 99 g/dL is a sentinel meaning "unknown/high" so the
                # severe-anemia branch below is not taken without an estimate.
                est_hb = image.anemia.get("estimated_hemoglobin", 99)
                if est_hb < 7:
                    # WHO severe anemia threshold (Hb < 7 g/dL) — transfusion care.
                    referral_needed = True
                    if urgency != "immediate":
                        urgency = "urgent"
                    facility_level = "secondary"
                    reasons.append(f"Severe anemia (est. Hb={est_hb} g/dL) — blood transfusion needed")
                    capabilities_needed.append("Blood bank / transfusion services")
                    pre_referral_actions.append("Oral iron if conscious and able to swallow")
                    pre_referral_actions.append("Keep patient warm during transport")
                    pre_referral_actions.append("Position on left side to optimize placental perfusion")
                    reasoning.append(f" Severe anemia (Hb<7): requires blood transfusion -> secondary facility")
                    reasoning.append(f" Pre-referral: oral iron, warmth, left lateral position")

        if patient_type == "newborn":
            if image.jaundice and image.jaundice.get("needs_phototherapy"):
                referral_needed = True
                if urgency != "immediate":
                    urgency = "urgent"
                if facility_level != "tertiary":
                    facility_level = "secondary"
                # Two key spellings supported for backward compatibility with
                # older jaundice payloads; falls back to 0 when absent.
                est_bili = image.jaundice.get("estimated_bilirubin_ml") or image.jaundice.get("estimated_bilirubin", 0)
                reasons.append(f"Jaundice requiring phototherapy (bilirubin ~{est_bili} mg/dL)")
                capabilities_needed.append("Phototherapy unit")
                pre_referral_actions.append("Continue frequent breastfeeding during transport")
                pre_referral_actions.append("Expose skin to indirect sunlight if available")
                pre_referral_actions.append("Keep baby warm — avoid hypothermia")
                reasoning.append(f" Phototherapy needed (bilirubin ~{est_bili} mg/dL): requires phototherapy unit -> secondary")

                if est_bili and est_bili > 20:
                    # Exchange-transfusion territory: unconditional escalation.
                    urgency = "immediate"
                    facility_level = "tertiary"
                    capabilities_needed.append("Exchange transfusion capability")
                    reasoning.append(f" Severe hyperbilirubinemia (>20 mg/dL): may need exchange transfusion -> tertiary")

            if audio and audio.cry and audio.cry.get("asphyxia_risk", 0) > 0.7:
                referral_needed = True
                urgency = "immediate"
                facility_level = "tertiary"
                reasons.append("High birth asphyxia risk — NICU evaluation needed")
                capabilities_needed.append("NICU / neonatal resuscitation")
                pre_referral_actions.append("Maintain clear airway")
                pre_referral_actions.append("Provide warmth and gentle stimulation")
                pre_referral_actions.append("Monitor breathing during transport")
                reasoning.append(f" High asphyxia risk (>70%): requires NICU -> IMMEDIATE to tertiary")

            elif audio and audio.cry and audio.cry.get("asphyxia_risk", 0) > 0.4:
                # Moderate risk only sets urgency/facility when nothing stronger
                # already triggered; the reason is still recorded either way.
                if not referral_needed:
                    referral_needed = True
                    urgency = "routine"
                    facility_level = "secondary"
                reasons.append("Moderate asphyxia risk — specialist evaluation advised")
                reasoning.append(f" Moderate asphyxia risk: specialist evaluation -> routine referral to secondary")

        # Step 3: Synthesize and verify referral decision
        reasoning.append(f"[STEP 3/4] Synthesizing referral decision")

        if protocol.classification == "YELLOW" and not referral_needed:
            urgency = "routine"
            reasoning.append(f" YELLOW classification without specific referral triggers -> routine follow-up")

        # Determine timeframe
        timeframe_map = {
            "immediate": "Within 1 hour — arrange emergency transport",
            "urgent": "Within 4-6 hours — arrange priority transport",
            "routine": "Within 24-48 hours — schedule outpatient referral",
            "none": "Not applicable — manage at current facility",
        }
        timeframe = timeframe_map[urgency]

        # Step 4: Generate referral summary
        reasoning.append(f"[STEP 4/4] Referral decision summary")
        reason_text = "; ".join(reasons) if reasons else "No referral required"

        if referral_needed:
            reasoning.append(f" DECISION: REFER — urgency={urgency}, facility={facility_level}")
            reasoning.append(f" Reasons: {reason_text}")
            reasoning.append(f" Timeframe: {timeframe}")
            if capabilities_needed:
                reasoning.append(f" Required capabilities: {', '.join(capabilities_needed)}")
            if pre_referral_actions:
                reasoning.append(f" Pre-referral actions: {'; '.join(pre_referral_actions)}")
        else:
            reasoning.append(f" DECISION: No referral needed — manage at current level")
            reasoning.append(f" Follow protocol recommendations and scheduled follow-up")

        elapsed = (time.time() - start) * 1000

        result = ReferralResult(
            referral_needed=referral_needed,
            urgency=urgency,
            facility_level=facility_level,
            reason=reason_text,
            timeframe=timeframe,
        )

        # Capabilities and pre-referral actions are surfaced via the trace
        # findings (ReferralResult has no fields for them).
        trace = AgentResult(
            agent_name="ReferralAgent",
            status="success",
            reasoning=reasoning,
            findings={
                "referral_needed": referral_needed,
                "urgency": urgency,
                "facility_level": facility_level,
                "capabilities_needed": capabilities_needed,
                "pre_referral_actions": pre_referral_actions,
            },
            confidence=1.0,
            processing_time_ms=elapsed,
        )

        return result, trace
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
class SynthesisAgent:
    """
    Clinical reasoning and synthesis using MedGemma.

    HAI-DEF Model: MedGemma (google/medgemma-4b-it)
    Reuses existing ClinicalSynthesizer instance.
    Passes full agent reasoning context to MedGemma for richer synthesis.
    """

    def __init__(self, synthesizer: Optional[Any] = None):
        # Accept an injected synthesizer, or defer construction to first use.
        self._synthesizer = synthesizer

    def _get_synthesizer(self) -> Any:
        """Return the synthesizer, lazily constructing it on first access."""
        if self._synthesizer is None:
            # Local import keeps heavy model loading off module import time.
            from .clinical_synthesizer import ClinicalSynthesizer
            self._synthesizer = ClinicalSynthesizer()
        return self._synthesizer

    def process(
        self,
        patient_type: PatientType,
        triage: TriageResult,
        image: ImageAnalysisResult,
        audio: Optional[AudioAnalysisResult],
        protocol: ProtocolResult,
        referral: ReferralResult,
        agent_traces: List[AgentResult],
    ) -> tuple[Dict[str, Any], AgentResult]:
        """Fuse all agent outputs into a clinical synthesis.

        Builds a findings payload (modality results + upstream agent context +
        a condensed reasoning summary), runs it through the synthesizer, and
        falls back to a deterministic summary if synthesis fails.

        Returns:
            (synthesis dict, AgentResult trace) pair.
        """
        start = time.time()
        notes: List[str] = ["Synthesizing all agent findings with MedGemma"]

        # Collect per-modality findings, noting which ones contributed.
        findings: Dict[str, Any] = {}
        modality_sources = (
            ("anemia", image.anemia, "Including anemia findings in synthesis"),
            ("jaundice", image.jaundice, "Including jaundice findings in synthesis"),
            ("cry", audio.cry if audio else None, "Including cry analysis findings in synthesis"),
        )
        for key, payload, note in modality_sources:
            if payload:
                findings[key] = payload
                notes.append(note)

        # Attach upstream agent context so the model sees the whole picture.
        findings["patient_info"] = {"type": patient_type}
        findings["agent_context"] = {
            "triage_score": triage.score,
            "triage_risk": triage.risk_level,
            "critical_signs": triage.critical_signs,
            "protocol_classification": protocol.classification,
            "applicable_protocols": protocol.applicable_protocols,
            "referral_needed": referral.referral_needed,
            "referral_urgency": referral.urgency,
        }

        # Condense each agent's last few reasoning steps into a prompt-sized summary.
        findings["agent_reasoning_summary"] = "\n".join(
            f"{t.agent_name}: {'; '.join(t.reasoning[-3:])}" for t in agent_traces
        )
        notes.append(f"Passing {len(agent_traces)} agent traces as context")

        try:
            synthesis = self._get_synthesizer().synthesize(findings)
            notes.append(f"Synthesis completed using: {synthesis.get('model', 'unknown')}")
            notes.append(f"Severity level: {synthesis.get('severity_level', 'N/A')}")
            notes.append(f"Referral needed: {synthesis.get('referral_needed', 'N/A')}")
            # Slightly higher confidence when the real MedGemma model answered.
            confidence = 0.85 if "MedGemma" in synthesis.get("model", "") else 0.75
        except Exception as e:
            # Deterministic fallback assembled purely from agent context.
            notes.append(f"Synthesis failed: {e}")
            synthesis = {
                "summary": f"Assessment for {patient_type} patient. Classification: {protocol.classification}.",
                "severity_level": protocol.classification,
                "severity_description": f"WHO IMNCI {protocol.classification} classification",
                "immediate_actions": protocol.treatment_recommendations or ["Continue routine care"],
                "referral_needed": referral.referral_needed,
                "referral_urgency": referral.urgency,
                "follow_up": protocol.follow_up_schedule,
                "urgent_conditions": triage.critical_signs,
                "model": "Fallback (agent context)",
                "generated_at": datetime.now().isoformat(),
            }
            confidence = 0.6

        elapsed = (time.time() - start) * 1000

        trace = AgentResult(
            agent_name="SynthesisAgent",
            status="success",
            reasoning=notes,
            findings={
                "model": synthesis.get("model", "unknown"),
                "severity_level": synthesis.get("severity_level", "unknown"),
            },
            confidence=confidence,
            processing_time_ms=elapsed,
        )

        return synthesis, trace
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
# ---------------------------------------------------------------------------
|
| 1079 |
+
# Workflow Engine
|
| 1080 |
+
# ---------------------------------------------------------------------------
|
| 1081 |
+
|
| 1082 |
+
# Observer signature for workflow progress: invoked with the new state and a
# progress percentage (0-100) after each state transition.
WorkflowCallback = Callable[[WorkflowState, float], None]
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
class AgenticWorkflowEngine:
    """
    Orchestrates the 6-agent clinical workflow pipeline.

    Pipeline: Triage -> Image -> Audio -> Protocol -> Referral -> Synthesis
    Early-exit on critical danger signs (RED + critical -> skip to Synthesis)

    Each agent emits a structured AgentResult with reasoning traces
    that form a complete audit trail of the clinical decision process.
    """

    # Execution order of the pipeline stages (informational).
    AGENTS = [
        "TriageAgent",
        "ImageAnalysisAgent",
        "AudioAnalysisAgent",
        "ProtocolAgent",
        "ReferralAgent",
        "SynthesisAgent",
    ]

    def __init__(
        self,
        anemia_detector: Optional[Any] = None,
        jaundice_detector: Optional[Any] = None,
        cry_analyzer: Optional[Any] = None,
        synthesizer: Optional[Any] = None,
        on_state_change: Optional[WorkflowCallback] = None,
    ):
        """Wire up the agent chain; all model dependencies are injectable.

        Args:
            anemia_detector / jaundice_detector: passed to ImageAnalysisAgent.
            cry_analyzer: passed to AudioAnalysisAgent.
            synthesizer: passed to SynthesisAgent (lazily built if None).
            on_state_change: optional progress callback (state, percent).
        """
        self._triage = TriageAgent()
        self._image = ImageAnalysisAgent(anemia_detector, jaundice_detector)
        self._audio = AudioAnalysisAgent(cry_analyzer)
        self._protocol = ProtocolAgent()
        self._referral = ReferralAgent()
        self._synthesis = SynthesisAgent(synthesizer)
        self._state: WorkflowState = "idle"
        self._on_state_change = on_state_change

    def _transition(self, state: WorkflowState, progress: float) -> None:
        """Record the new state and notify the observer, if any."""
        self._state = state
        if self._on_state_change:
            self._on_state_change(state, progress)

    @property
    def state(self) -> WorkflowState:
        # Current pipeline stage (read-only).
        return self._state

    def execute(self, workflow_input: WorkflowInput) -> WorkflowResult:
        """
        Execute the full agentic workflow pipeline.

        Args:
            workflow_input: Complete input with patient info, images, audio, danger signs.

        Returns:
            WorkflowResult with all agent outputs, reasoning traces, and clinical synthesis.
            On any agent failure, returns a fail-safe result (success=False,
            RED classification, advice to seek immediate consultation).
        """
        start = time.time()
        agent_traces: List[AgentResult] = []
        patient_type = workflow_input.patient_type

        try:
            # Step 1: Triage (10% progress)
            self._transition("triaging", 10.0)
            triage_result, triage_trace = self._triage.process(
                patient_type,
                workflow_input.danger_signs,
                workflow_input.patient_info,
            )
            agent_traces.append(triage_trace)

            # Early exit for critical cases — skip analysis, refer immediately.
            if triage_result.immediate_referral_needed:
                self._transition("complete", 100.0)
                return self._build_early_referral(
                    workflow_input, triage_result, agent_traces, start
                )

            # Step 2: Image Analysis (30% progress)
            self._transition("analyzing_image", 30.0)
            image_result, image_trace = self._image.process(
                patient_type,
                workflow_input.conjunctiva_image,
                workflow_input.skin_image,
            )
            agent_traces.append(image_trace)

            # Step 3: Audio Analysis (50% progress)
            self._transition("analyzing_audio", 50.0)
            audio_result, audio_trace = self._audio.process(
                workflow_input.cry_audio,
            )
            agent_traces.append(audio_trace)

            # Step 4: Protocol Application (70% progress)
            self._transition("applying_protocol", 70.0)
            protocol_result, protocol_trace = self._protocol.process(
                patient_type, triage_result, image_result, audio_result
            )
            agent_traces.append(protocol_trace)

            # Step 5: Referral Decision (85% progress)
            self._transition("determining_referral", 85.0)
            referral_result, referral_trace = self._referral.process(
                patient_type, triage_result, protocol_result,
                image_result, audio_result,
            )
            agent_traces.append(referral_trace)

            # Step 6: Clinical Synthesis with MedGemma (95% progress)
            self._transition("synthesizing", 95.0)
            synthesis, synthesis_trace = self._synthesis.process(
                patient_type, triage_result, image_result,
                audio_result, protocol_result, referral_result,
                agent_traces,
            )
            agent_traces.append(synthesis_trace)

            # Build final result
            self._transition("complete", 100.0)
            elapsed = (time.time() - start) * 1000

            # BUGFIX: use `or` instead of dict.get's default so an *empty*
            # immediate_actions list (key present, value []) also falls back
            # instead of raising IndexError on [0]. Mirrors the fallback style
            # already used in SynthesisAgent.
            actions = synthesis.get("immediate_actions") or []

            return WorkflowResult(
                success=True,
                patient_type=patient_type,
                who_classification=protocol_result.classification,
                triage_result=triage_result,
                image_results=image_result,
                audio_results=audio_result,
                protocol_result=protocol_result,
                referral_result=referral_result,
                clinical_synthesis=synthesis.get("summary", ""),
                recommendation=(actions or ["Continue routine care"])[0],
                immediate_actions=actions,
                agent_traces=agent_traces,
                processing_time_ms=elapsed,
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            # Fail safe: surface the error but recommend escalation rather
            # than silently reporting a healthy assessment.
            self._transition("error", 0.0)
            elapsed = (time.time() - start) * 1000
            error_trace = AgentResult(
                agent_name="WorkflowEngine",
                status="error",
                reasoning=[f"Workflow failed: {e}"],
                findings={"error": str(e)},
                confidence=0.0,
                processing_time_ms=elapsed,
            )
            agent_traces.append(error_trace)

            return WorkflowResult(
                success=False,
                patient_type=patient_type,
                who_classification="RED",
                agent_traces=agent_traces,
                clinical_synthesis=f"Workflow error: {e}. Please retry or seek immediate medical consultation.",
                recommendation="Seek immediate medical consultation due to assessment error",
                immediate_actions=["Seek immediate medical consultation"],
                processing_time_ms=elapsed,
                timestamp=datetime.now().isoformat(),
            )

    def _build_early_referral(
        self,
        workflow_input: WorkflowInput,
        triage: TriageResult,
        agent_traces: List[AgentResult],
        start_time: float,
    ) -> WorkflowResult:
        """Build result for early-exit when critical danger signs are detected."""
        elapsed = (time.time() - start_time) * 1000

        critical_text = ", ".join(triage.critical_signs)
        synthesis_text = (
            f"URGENT: Critical danger signs detected ({critical_text}). "
            f"Immediate referral to higher-level facility is required. "
            f"This patient requires emergency care that cannot be provided at the current level."
        )

        # Image/audio stages were skipped, so their results are empty defaults.
        return WorkflowResult(
            success=True,
            patient_type=workflow_input.patient_type,
            who_classification="RED",
            triage_result=triage,
            image_results=ImageAnalysisResult(),
            audio_results=AudioAnalysisResult(),
            protocol_result=ProtocolResult(
                classification="RED",
                applicable_protocols=["Emergency Referral Protocol"],
                treatment_recommendations=["IMMEDIATE REFERRAL REQUIRED"],
                follow_up_schedule="After emergency care",
            ),
            referral_result=ReferralResult(
                referral_needed=True,
                urgency="immediate",
                facility_level="tertiary",
                reason=f"Critical danger signs detected: {critical_text}",
                timeframe="Immediately - within 1 hour",
            ),
            clinical_synthesis=synthesis_text,
            recommendation="IMMEDIATE REFERRAL to tertiary care facility",
            immediate_actions=[
                "Arrange emergency transport",
                "Call receiving facility",
                "Provide pre-referral treatment as per protocol",
                "Accompany patient with referral note",
            ],
            agent_traces=agent_traces,
            processing_time_ms=elapsed,
            timestamp=datetime.now().isoformat(),
        )
|
src/nexus/anemia_detector.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anemia Detector Module
|
| 3 |
+
|
| 4 |
+
Uses MedSigLIP from Google HAI-DEF for anemia detection from conjunctiva images.
|
| 5 |
+
Implements zero-shot classification with medical text prompts per NEXUS_MASTER_PLAN.md.
|
| 6 |
+
|
| 7 |
+
HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
|
| 8 |
+
Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# Optional-dependency guard: record whether transformers is importable so the
# module can still be imported without it; AnemiaDetector.__init__ raises a
# clear ImportError when HAS_TRANSFORMERS is False.
try:
    from transformers import AutoProcessor, AutoModel
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# HAI-DEF MedSigLIP model IDs to try in order of preference
# (AnemiaDetector.__init__ iterates this list and uses the first that loads).
MEDSIGLIP_MODEL_IDS = [
    "google/medsiglip-448",  # MedSigLIP - official HAI-DEF model
    "google/siglip-base-patch16-224",  # SigLIP 224 - fallback
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AnemiaDetector:
|
| 33 |
+
"""
|
| 34 |
+
Detects anemia from conjunctiva (inner eyelid) images using MedSigLIP.
|
| 35 |
+
|
| 36 |
+
Uses zero-shot classification with medical prompts for detection.
|
| 37 |
+
HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
|
| 38 |
+
Fallback: siglip-base-patch16-224
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
# Medical text prompts for zero-shot classification (optimized for MedSigLIP)
|
| 42 |
+
# Expanded prompt set with specific clinical language for better discrimination
|
| 43 |
+
ANEMIC_PROMPTS = [
|
| 44 |
+
"pale conjunctiva with visible pallor indicating anemia",
|
| 45 |
+
"conjunctival pallor grade 2 or higher with reduced vascularity",
|
| 46 |
+
"white or very pale inner eyelid mucosa suggesting low hemoglobin",
|
| 47 |
+
"conjunctiva showing significant pallor and poor blood perfusion",
|
| 48 |
+
"anemic eye with pale pink to white palpebral conjunctiva",
|
| 49 |
+
"inner eyelid lacking red coloration consistent with severe anemia",
|
| 50 |
+
"conjunctiva with washed out appearance and faint vascular pattern",
|
| 51 |
+
"pale mucous membrane of the lower eyelid suggesting iron deficiency",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
HEALTHY_PROMPTS = [
|
| 55 |
+
"healthy red conjunctiva with rich vascular pattern",
|
| 56 |
+
"well-perfused bright pink inner eyelid with visible blood vessels",
|
| 57 |
+
"normal conjunctiva showing deep red-pink coloration",
|
| 58 |
+
"conjunctiva with healthy blood supply and strong red color",
|
| 59 |
+
"richly vascularized palpebral conjunctiva with normal hemoglobin",
|
| 60 |
+
"inner eyelid with vibrant red-pink mucosa and clear vessels",
|
| 61 |
+
"non-anemic conjunctiva showing robust red perfusion",
|
| 62 |
+
"conjunctival mucosa with normal deep pink to red appearance",
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
    def __init__(
        self,
        model_name: Optional[str] = None,  # Auto-select MedSigLIP
        device: Optional[str] = None,
        threshold: float = 0.5,
    ):
        """
        Initialize the Anemia Detector with MedSigLIP.

        Tries each candidate model in order (gated MedSigLIP first, public
        SigLIP as fallback) and keeps the first one that loads successfully.

        Args:
            model_name: HuggingFace model name (auto-selects HAI-DEF MedSigLIP if None)
            device: Device to run model on (auto-detected if None)
            threshold: Classification threshold for anemia detection

        Raises:
            ImportError: if the transformers library is not installed.
            RuntimeError: if none of the candidate models could be loaded.
        """
        if not HAS_TRANSFORMERS:
            raise ImportError("transformers library required. Install with: pip install transformers")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self._model_loaded = False
        self.classifier = None  # Can be set by pipeline for trained classification

        # Determine which models to try: an explicit name overrides the
        # built-in preference list entirely (no fallback in that case).
        models_to_try = [model_name] if model_name else MEDSIGLIP_MODEL_IDS

        # HuggingFace token for gated models (MedSigLIP requires access approval).
        hf_token = os.environ.get("HF_TOKEN")

        # Try loading models in order of preference; first success wins.
        for candidate_model in models_to_try:
            print(f"Loading HAI-DEF model: {candidate_model}")
            try:
                self.processor = AutoProcessor.from_pretrained(
                    candidate_model, token=hf_token
                )
                self.model = AutoModel.from_pretrained(
                    candidate_model, token=hf_token
                ).to(self.device)
                self.model_name = candidate_model
                self._model_loaded = True
                print(f"Successfully loaded: {candidate_model}")
                break
            except Exception as e:
                # Broad catch is intentional: gated access, network, or cache
                # failures should all fall through to the next candidate.
                print(f"Warning: Could not load {candidate_model}: {e}")
                continue

        if not self._model_loaded:
            raise RuntimeError(
                f"Could not load any MedSigLIP model. Tried: {models_to_try}. "
                "Install transformers and ensure internet access."
            )

        # Inference-only usage: freeze dropout/batch-norm behavior.
        self.model.eval()

        # Pre-compute text embeddings for efficiency (done once, reused per image).
        self._precompute_text_embeddings()

        # Try to auto-load trained classifier (optional; zero-shot otherwise).
        self._auto_load_classifier()

        # Indicate which model variant is being used
        is_medsiglip = "medsiglip" in self.model_name
        model_type = "MedSigLIP" if is_medsiglip else "SigLIP (fallback)"
        classifier_status = "with trained classifier" if self.classifier else "zero-shot"
        print(f"Anemia Detector (HAI-DEF {model_type}, {classifier_status}) initialized on {self.device}")
|
| 130 |
+
|
| 131 |
+
def _auto_load_classifier(self) -> None:
|
| 132 |
+
"""Auto-load trained anemia classifier if available."""
|
| 133 |
+
if self.classifier is not None:
|
| 134 |
+
return # Already set externally
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
import joblib
|
| 138 |
+
except ImportError:
|
| 139 |
+
return
|
| 140 |
+
|
| 141 |
+
default_paths = [
|
| 142 |
+
Path(__file__).parent.parent.parent / "models" / "linear_probes" / "anemia_classifier.joblib",
|
| 143 |
+
Path("models/linear_probes/anemia_classifier.joblib"),
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
for path in default_paths:
|
| 147 |
+
if path.exists():
|
| 148 |
+
try:
|
| 149 |
+
self.classifier = joblib.load(path)
|
| 150 |
+
print(f"Auto-loaded anemia classifier from {path}")
|
| 151 |
+
return
|
| 152 |
+
except Exception as e:
|
| 153 |
+
print(f"Warning: Could not load classifier from {path}: {e}")
|
| 154 |
+
|
| 155 |
+
# Logit temperature for softmax conversion (lower = more spread, higher = sharper)
|
| 156 |
+
LOGIT_SCALE = 30.0
|
| 157 |
+
|
| 158 |
+
def _precompute_text_embeddings(self) -> None:
|
| 159 |
+
"""Pre-compute text embeddings for zero-shot classification using SigLIP.
|
| 160 |
+
|
| 161 |
+
Stores individual prompt embeddings for max-similarity scoring,
|
| 162 |
+
which outperforms mean-pooled embeddings for medical image classification.
|
| 163 |
+
"""
|
| 164 |
+
all_prompts = self.ANEMIC_PROMPTS + self.HEALTHY_PROMPTS
|
| 165 |
+
|
| 166 |
+
with torch.no_grad():
|
| 167 |
+
# SigLIP uses different API than CLIP
|
| 168 |
+
inputs = self.processor(
|
| 169 |
+
text=all_prompts,
|
| 170 |
+
return_tensors="pt",
|
| 171 |
+
padding="max_length",
|
| 172 |
+
truncation=True,
|
| 173 |
+
).to(self.device)
|
| 174 |
+
|
| 175 |
+
# Get text embeddings - support multiple output APIs
|
| 176 |
+
if hasattr(self.model, 'get_text_features'):
|
| 177 |
+
text_embeddings = self.model.get_text_features(**inputs)
|
| 178 |
+
else:
|
| 179 |
+
outputs = self.model(**inputs)
|
| 180 |
+
if hasattr(outputs, 'text_embeds'):
|
| 181 |
+
text_embeddings = outputs.text_embeds
|
| 182 |
+
elif hasattr(outputs, 'text_model_output'):
|
| 183 |
+
text_embeddings = outputs.text_model_output.pooler_output
|
| 184 |
+
else:
|
| 185 |
+
text_outputs = self.model.text_model(**inputs)
|
| 186 |
+
text_embeddings = text_outputs.pooler_output
|
| 187 |
+
|
| 188 |
+
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
| 189 |
+
|
| 190 |
+
# Store individual embeddings for max-similarity scoring
|
| 191 |
+
n_anemic = len(self.ANEMIC_PROMPTS)
|
| 192 |
+
self.anemic_embeddings_all = text_embeddings[:n_anemic] # (N, D)
|
| 193 |
+
self.healthy_embeddings_all = text_embeddings[n_anemic:] # (M, D)
|
| 194 |
+
|
| 195 |
+
# Also keep mean embeddings as fallback
|
| 196 |
+
self.anemic_embeddings = self.anemic_embeddings_all.mean(dim=0, keepdim=True)
|
| 197 |
+
self.healthy_embeddings = self.healthy_embeddings_all.mean(dim=0, keepdim=True)
|
| 198 |
+
self.anemic_embeddings = self.anemic_embeddings / self.anemic_embeddings.norm(dim=-1, keepdim=True)
|
| 199 |
+
self.healthy_embeddings = self.healthy_embeddings / self.healthy_embeddings.norm(dim=-1, keepdim=True)
|
| 200 |
+
|
| 201 |
+
def preprocess_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
|
| 202 |
+
"""
|
| 203 |
+
Preprocess image for analysis.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
image: Path to image or PIL Image
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
Preprocessed PIL Image
|
| 210 |
+
"""
|
| 211 |
+
if isinstance(image, (str, Path)):
|
| 212 |
+
image = Image.open(image).convert("RGB")
|
| 213 |
+
elif not isinstance(image, Image.Image):
|
| 214 |
+
raise ValueError(f"Expected str, Path, or PIL Image, got {type(image)}")
|
| 215 |
+
|
| 216 |
+
return image
|
| 217 |
+
|
| 218 |
+
    def detect(self, image: Union[str, Path, Image.Image]) -> Dict:
        """
        Detect anemia from conjunctiva image.

        Uses trained classifier if available, otherwise falls back to
        zero-shot classification with MedSigLIP.

        Args:
            image: Conjunctiva image (path or PIL Image)

        Returns:
            Dictionary containing:
            - is_anemic: Boolean indicating anemia detection
            - confidence: Confidence score (0-1)
            - anemia_score: Raw anemia probability
            - healthy_score: Raw healthy probability
            - risk_level: "high", "medium", or "low"
            - recommendation: Clinical recommendation
            - model: Underlying HF model id
            - model_type: Human-readable "base model + method" label
        """
        # Preprocess image
        pil_image = self.preprocess_image(image)

        # Get image embedding using SigLIP
        with torch.no_grad():
            inputs = self.processor(
                images=pil_image,
                return_tensors="pt",
            ).to(self.device)

            # Get image embeddings - support multiple output APIs
            # (transformers versions differ in how vision features are exposed).
            if hasattr(self.model, 'get_image_features'):
                image_embedding = self.model.get_image_features(**inputs)
            else:
                outputs = self.model(**inputs)
                if hasattr(outputs, 'image_embeds'):
                    image_embedding = outputs.image_embeds
                elif hasattr(outputs, 'vision_model_output'):
                    image_embedding = outputs.vision_model_output.pooler_output
                else:
                    vision_outputs = self.model.vision_model(**inputs)
                    image_embedding = vision_outputs.pooler_output

            # L2-normalize so similarity scoring uses cosine similarity.
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

        # Use trained classifier if available, otherwise zero-shot
        if self.classifier is not None:
            anemia_prob, healthy_prob, model_method = self._classify_with_trained_model(image_embedding)
        else:
            anemia_prob, healthy_prob, model_method = self._classify_zero_shot(image_embedding)

        # Determine risk level.
        # NOTE(review): risk tiers use fixed 0.7/0.5 cutoffs while is_anemic
        # below uses self.threshold — confirm this asymmetry is intentional.
        if anemia_prob > 0.7:
            risk_level = "high"
            recommendation = "URGENT: Refer for blood test immediately. High likelihood of anemia."
        elif anemia_prob > 0.5:
            risk_level = "medium"
            recommendation = "Schedule blood test within 48 hours. Moderate anemia indicators present."
        else:
            risk_level = "low"
            recommendation = "No immediate concern. Routine follow-up recommended."

        # Label which backbone actually loaded (MedSigLIP vs generic SigLIP).
        is_medsiglip = "medsiglip" in self.model_name
        base_model = "MedSigLIP (HAI-DEF)" if is_medsiglip else "SigLIP (fallback)"

        return {
            "is_anemic": anemia_prob > self.threshold,
            "confidence": max(anemia_prob, healthy_prob),
            "anemia_score": anemia_prob,
            "healthy_score": healthy_prob,
            "risk_level": risk_level,
            "recommendation": recommendation,
            "model": self.model_name,
            "model_type": f"{base_model} + {model_method}",
        }
|
| 292 |
+
|
| 293 |
+
def _classify_with_trained_model(self, image_embedding: torch.Tensor) -> Tuple[float, float, str]:
|
| 294 |
+
"""
|
| 295 |
+
Classify using trained classifier on embeddings.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
image_embedding: Normalized image embedding from MedSigLIP
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
Tuple of (anemia_prob, healthy_prob, method_name)
|
| 302 |
+
"""
|
| 303 |
+
# Convert embedding to numpy for sklearn classifiers
|
| 304 |
+
embedding_np = image_embedding.cpu().numpy().reshape(1, -1)
|
| 305 |
+
|
| 306 |
+
# Handle different classifier types
|
| 307 |
+
if hasattr(self.classifier, 'predict_proba'):
|
| 308 |
+
# Sklearn classifier with probability support
|
| 309 |
+
proba = self.classifier.predict_proba(embedding_np)
|
| 310 |
+
# Assume binary: [healthy, anemic] or [anemic, healthy]
|
| 311 |
+
if proba.shape[1] >= 2:
|
| 312 |
+
# Check classifier classes to determine order
|
| 313 |
+
if hasattr(self.classifier, 'classes_'):
|
| 314 |
+
classes = list(self.classifier.classes_)
|
| 315 |
+
if 1 in classes:
|
| 316 |
+
anemia_idx = classes.index(1)
|
| 317 |
+
else:
|
| 318 |
+
anemia_idx = 1 # Default assumption
|
| 319 |
+
else:
|
| 320 |
+
anemia_idx = 1
|
| 321 |
+
anemia_prob = float(proba[0, anemia_idx])
|
| 322 |
+
healthy_prob = 1.0 - anemia_prob
|
| 323 |
+
else:
|
| 324 |
+
anemia_prob = float(proba[0, 0])
|
| 325 |
+
healthy_prob = 1.0 - anemia_prob
|
| 326 |
+
return anemia_prob, healthy_prob, "Trained Classifier"
|
| 327 |
+
|
| 328 |
+
elif hasattr(self.classifier, 'predict'):
|
| 329 |
+
# Classifier without probability - use binary prediction
|
| 330 |
+
prediction = self.classifier.predict(embedding_np)
|
| 331 |
+
anemia_prob = float(prediction[0])
|
| 332 |
+
healthy_prob = 1.0 - anemia_prob
|
| 333 |
+
return anemia_prob, healthy_prob, "Trained Classifier (binary)"
|
| 334 |
+
|
| 335 |
+
elif isinstance(self.classifier, nn.Module):
|
| 336 |
+
# PyTorch classifier
|
| 337 |
+
self.classifier.eval()
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
logits = self.classifier(image_embedding)
|
| 340 |
+
probs = torch.softmax(logits, dim=-1)
|
| 341 |
+
if probs.shape[-1] >= 2:
|
| 342 |
+
anemia_prob = probs[0, 1].item()
|
| 343 |
+
healthy_prob = probs[0, 0].item()
|
| 344 |
+
else:
|
| 345 |
+
anemia_prob = probs[0, 0].item()
|
| 346 |
+
healthy_prob = 1.0 - anemia_prob
|
| 347 |
+
return anemia_prob, healthy_prob, "Trained Classifier (PyTorch)"
|
| 348 |
+
|
| 349 |
+
else:
|
| 350 |
+
# Unknown classifier type - fall back to zero-shot
|
| 351 |
+
print(f"Warning: Unknown classifier type {type(self.classifier)}, using zero-shot")
|
| 352 |
+
return self._classify_zero_shot(image_embedding)
|
| 353 |
+
|
| 354 |
+
def _classify_zero_shot(self, image_embedding: torch.Tensor) -> Tuple[float, float, str]:
|
| 355 |
+
"""
|
| 356 |
+
Classify using zero-shot with max-similarity scoring.
|
| 357 |
+
|
| 358 |
+
Uses the maximum cosine similarity across all prompts per class
|
| 359 |
+
rather than mean-pooled embeddings, which provides better
|
| 360 |
+
discrimination for medical image classification.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
image_embedding: Normalized image embedding from MedSigLIP
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Tuple of (anemia_prob, healthy_prob, method_name)
|
| 367 |
+
"""
|
| 368 |
+
# Max-similarity: take the best-matching prompt per class
|
| 369 |
+
anemia_sims = (image_embedding @ self.anemic_embeddings_all.T).squeeze(0)
|
| 370 |
+
healthy_sims = (image_embedding @ self.healthy_embeddings_all.T).squeeze(0)
|
| 371 |
+
|
| 372 |
+
# Ensure at least 1-D for .max() to work on single-image inputs
|
| 373 |
+
if anemia_sims.dim() == 0:
|
| 374 |
+
anemia_sims = anemia_sims.unsqueeze(0)
|
| 375 |
+
if healthy_sims.dim() == 0:
|
| 376 |
+
healthy_sims = healthy_sims.unsqueeze(0)
|
| 377 |
+
|
| 378 |
+
anemia_sim = anemia_sims.max().item()
|
| 379 |
+
healthy_sim = healthy_sims.max().item()
|
| 380 |
+
|
| 381 |
+
# Convert to probabilities with tuned temperature
|
| 382 |
+
logits = torch.tensor([anemia_sim, healthy_sim], device="cpu") * self.LOGIT_SCALE
|
| 383 |
+
probs = torch.softmax(logits, dim=0)
|
| 384 |
+
anemia_prob = probs[0].item()
|
| 385 |
+
healthy_prob = probs[1].item()
|
| 386 |
+
|
| 387 |
+
return anemia_prob, healthy_prob, "Zero-Shot"
|
| 388 |
+
|
| 389 |
+
    def detect_batch(
        self,
        images: List[Union[str, Path, Image.Image]],
        batch_size: int = 8,
    ) -> List[Dict]:
        """
        Detect anemia from multiple images.

        Embeds images in batches for throughput, then scores each embedding
        individually (trained classifier when attached, zero-shot otherwise).

        Args:
            images: List of conjunctiva images
            batch_size: Batch size for processing

        Returns:
            List of detection results (one dict per input image; note the
            batch dicts omit the "model"/"model_type" keys that detect() adds)
        """
        results = []

        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]

            # Process batch
            pil_images = [self.preprocess_image(img) for img in batch]

            with torch.no_grad():
                inputs = self.processor(
                    images=pil_images,
                    return_tensors="pt",
                    padding=True,
                ).to(self.device)

                # Get image embeddings - support multiple output APIs
                # (transformers versions differ in how vision features surface).
                if hasattr(self.model, 'get_image_features'):
                    image_embeddings = self.model.get_image_features(**inputs)
                else:
                    outputs = self.model(**inputs)
                    if hasattr(outputs, 'image_embeds'):
                        image_embeddings = outputs.image_embeds
                    elif hasattr(outputs, 'vision_model_output'):
                        image_embeddings = outputs.vision_model_output.pooler_output
                    else:
                        vision_outputs = self.model.vision_model(**inputs)
                        image_embeddings = vision_outputs.pooler_output

                # L2-normalize so dot products below are cosine similarities.
                image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

            # Compute max-similarities for each image
            for j, img_emb in enumerate(image_embeddings):
                # Restore the (1, D) shape the classify helpers expect.
                img_emb = img_emb.unsqueeze(0)

                # Use trained classifier if available, otherwise zero-shot
                if self.classifier is not None:
                    anemia_prob, healthy_prob, _ = self._classify_with_trained_model(img_emb)
                    # Skip zero-shot path below
                    if anemia_prob > 0.7:
                        risk_level = "high"
                        recommendation = "URGENT: Refer for blood test immediately."
                    elif anemia_prob > 0.5:
                        risk_level = "medium"
                        recommendation = "Schedule blood test within 48 hours."
                    else:
                        risk_level = "low"
                        recommendation = "No immediate concern."

                    results.append({
                        "is_anemic": anemia_prob > self.threshold,
                        "confidence": max(anemia_prob, healthy_prob),
                        "anemia_score": anemia_prob,
                        "healthy_score": healthy_prob,
                        "risk_level": risk_level,
                        "recommendation": recommendation,
                    })
                    continue

                # Zero-shot: max cosine similarity across each class's prompts.
                anemia_sims = (img_emb @ self.anemic_embeddings_all.T).squeeze(0)
                healthy_sims = (img_emb @ self.healthy_embeddings_all.T).squeeze(0)

                # Keep 1-D so .max() works even with a single prompt per class.
                if anemia_sims.dim() == 0:
                    anemia_sims = anemia_sims.unsqueeze(0)
                if healthy_sims.dim() == 0:
                    healthy_sims = healthy_sims.unsqueeze(0)

                anemia_sim = anemia_sims.max().item()
                healthy_sim = healthy_sims.max().item()

                # Temperature-scaled softmax over the two class scores.
                logits = torch.tensor([anemia_sim, healthy_sim], device="cpu") * self.LOGIT_SCALE
                probs = torch.softmax(logits, dim=0)
                anemia_prob = probs[0].item()
                healthy_prob = probs[1].item()

                # Same 0.7/0.5 risk tiers as detect(), with shortened messages.
                if anemia_prob > 0.7:
                    risk_level = "high"
                    recommendation = "URGENT: Refer for blood test immediately."
                elif anemia_prob > 0.5:
                    risk_level = "medium"
                    recommendation = "Schedule blood test within 48 hours."
                else:
                    risk_level = "low"
                    recommendation = "No immediate concern."

                results.append({
                    "is_anemic": anemia_prob > self.threshold,
                    "confidence": max(anemia_prob, healthy_prob),
                    "anemia_score": anemia_prob,
                    "healthy_score": healthy_prob,
                    "risk_level": risk_level,
                    "recommendation": recommendation,
                })

        return results
|
| 498 |
+
|
| 499 |
+
def analyze_color_features(self, image: Union[str, Path, Image.Image]) -> Dict:
|
| 500 |
+
"""
|
| 501 |
+
Analyze color features of conjunctiva image.
|
| 502 |
+
|
| 503 |
+
This provides interpretable features based on medical literature
|
| 504 |
+
that correlates pallor with anemia.
|
| 505 |
+
|
| 506 |
+
Args:
|
| 507 |
+
image: Conjunctiva image
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
Dictionary with color analysis results
|
| 511 |
+
"""
|
| 512 |
+
pil_image = self.preprocess_image(image)
|
| 513 |
+
img_array = np.array(pil_image)
|
| 514 |
+
|
| 515 |
+
# Extract RGB channels
|
| 516 |
+
r_channel = img_array[:, :, 0].astype(float)
|
| 517 |
+
g_channel = img_array[:, :, 1].astype(float)
|
| 518 |
+
b_channel = img_array[:, :, 2].astype(float)
|
| 519 |
+
|
| 520 |
+
# Calculate color statistics
|
| 521 |
+
mean_r = np.mean(r_channel)
|
| 522 |
+
mean_g = np.mean(g_channel)
|
| 523 |
+
mean_b = np.mean(b_channel)
|
| 524 |
+
|
| 525 |
+
# Red ratio (higher in healthy, lower in anemic)
|
| 526 |
+
total_intensity = mean_r + mean_g + mean_b
|
| 527 |
+
red_ratio = mean_r / total_intensity if total_intensity > 0 else 0
|
| 528 |
+
|
| 529 |
+
# Pallor index (higher means more pale/anemic)
|
| 530 |
+
# Based on reduced red-to-green ratio in anemic conjunctiva
|
| 531 |
+
pallor_index = 1 - (mean_r / (mean_g + 1e-6))
|
| 532 |
+
pallor_index = max(0, min(1, (pallor_index + 0.5) / 1.5))
|
| 533 |
+
|
| 534 |
+
# Hemoglobin estimation (rough approximation)
|
| 535 |
+
# Normal Hb: 12-16 g/dL for women, 14-18 for men
|
| 536 |
+
# This is a rough estimate based on color analysis
|
| 537 |
+
estimated_hb = 8 + (red_ratio * 12)
|
| 538 |
+
|
| 539 |
+
return {
|
| 540 |
+
"mean_red": mean_r,
|
| 541 |
+
"mean_green": mean_g,
|
| 542 |
+
"mean_blue": mean_b,
|
| 543 |
+
"red_ratio": red_ratio,
|
| 544 |
+
"pallor_index": pallor_index,
|
| 545 |
+
"estimated_hemoglobin": round(estimated_hb, 1),
|
| 546 |
+
"interpretation": "Low hemoglobin" if pallor_index > 0.5 else "Normal hemoglobin",
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def test_detector():
    """Smoke-test the anemia detector against a few sample dataset images."""
    print("Testing Anemia Detector...")

    detector = AnemiaDetector()

    # Sample images shipped by download_datasets.py, relative to the repo root.
    data_dir = Path(__file__).parent.parent.parent / "data" / "raw" / "eyes-defy-anemia"

    if not data_dir.exists():
        print(f"Dataset not found at {data_dir}")
        print("Please run download_datasets.py first")
        return

    for img_path in list(data_dir.rglob("*.jpg"))[:3]:
        print(f"\nAnalyzing: {img_path.name}")
        result = detector.detect(img_path)
        print(f"  Anemia detected: {result['is_anemic']}")
        print(f"  Confidence: {result['confidence']:.2%}")
        print(f"  Risk level: {result['risk_level']}")
        print(f"  Recommendation: {result['recommendation']}")

        # Interpretable color-based features alongside the model output.
        color_info = detector.analyze_color_features(img_path)
        print(f"  Estimated Hb: {color_info['estimated_hemoglobin']} g/dL")
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
# Allow running this module directly as a smoke test.
if __name__ == "__main__":
    test_detector()
|
src/nexus/clinical_synthesizer.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Clinical Synthesizer Module
|
| 3 |
+
|
| 4 |
+
Uses MedGemma from Google HAI-DEF for clinical reasoning and synthesis.
|
| 5 |
+
Combines findings from MedSigLIP (images) and HeAR (audio) into actionable recommendations.
|
| 6 |
+
|
| 7 |
+
HAI-DEF Model: MedGemma 4B (google/medgemma-4b-it or google/medgemma-1.5-4b-it)
|
| 8 |
+
Supports 4-bit quantization via BitsAndBytes for low-VRAM deployment.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from typing import Dict, Optional, List
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 17 |
+
HAS_TRANSFORMERS = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
HAS_TRANSFORMERS = False
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from transformers import BitsAndBytesConfig
|
| 23 |
+
HAS_BITSANDBYTES = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
HAS_BITSANDBYTES = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ClinicalSynthesizer:
    """
    Synthesizes clinical findings using MedGemma.

    HAI-DEF Model: MedGemma 4B (google/medgemma-4b-it or google/medgemma-1.5-4b-it)
    Method: Prompt engineering (no fine-tuning required)
    Quantization: 4-bit NF4 via BitsAndBytes for low-VRAM deployment

    Output:
        - Integrated diagnosis suggestions
        - Severity assessment (GREEN/YELLOW/RED)
        - Treatment recommendations (WHO IMNCI)
        - Referral decision with urgency
        - CHW-friendly explanations
    """

    # WHO IMNCI severity colors, mapped to short CHW-facing descriptions.
    SEVERITY_LEVELS = {
        "GREEN": "Routine care - no immediate concern",
        "YELLOW": "Close monitoring - may need referral",
        "RED": "Urgent referral - immediate action required",
    }

    # MedGemma model candidates in preference order
    MEDGEMMA_MODEL_IDS = [
        "google/medgemma-1.5-4b-it",  # Newer, better performance
        "google/medgemma-4b-it",  # Original HAI-DEF model
    ]
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
model_name: Optional[str] = None,
|
| 60 |
+
device: Optional[str] = None,
|
| 61 |
+
use_medgemma: bool = True,
|
| 62 |
+
use_4bit: bool = True,
|
| 63 |
+
):
|
| 64 |
+
"""
|
| 65 |
+
Initialize the Clinical Synthesizer with MedGemma.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
model_name: HuggingFace model name for MedGemma (auto-selects if None)
|
| 69 |
+
device: Device to run model on
|
| 70 |
+
use_medgemma: Whether to use MedGemma (True) or rule-based (False)
|
| 71 |
+
use_4bit: Whether to use 4-bit quantization (reduces VRAM from ~8GB to ~2GB)
|
| 72 |
+
"""
|
| 73 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 74 |
+
self._user_model_name = model_name # None if user didn't specify
|
| 75 |
+
self.model_name = model_name or self.MEDGEMMA_MODEL_IDS[-1]
|
| 76 |
+
self.model = None
|
| 77 |
+
self.tokenizer = None
|
| 78 |
+
self.use_medgemma = use_medgemma
|
| 79 |
+
self.use_4bit = use_4bit
|
| 80 |
+
self._medgemma_available = False
|
| 81 |
+
|
| 82 |
+
if use_medgemma and HAS_TRANSFORMERS:
|
| 83 |
+
self._load_medgemma()
|
| 84 |
+
else:
|
| 85 |
+
print("MedGemma not available. Using rule-based clinical synthesis.")
|
| 86 |
+
self.use_medgemma = False
|
| 87 |
+
|
| 88 |
+
print(f"Clinical Synthesizer (HAI-DEF MedGemma) initialized")
|
| 89 |
+
|
| 90 |
+
    def _load_medgemma(self) -> None:
        """Load MedGemma model from HuggingFace with 4-bit quantization.

        Tries model candidates in preference order:
        1. google/medgemma-1.5-4b-it (newer, better performance)
        2. google/medgemma-4b-it (original HAI-DEF model)

        Uses BitsAndBytes NF4 quantization to reduce VRAM from ~8GB to ~2GB,
        which fixes CUDA OOM errors on consumer GPUs.

        On total failure, clears model/tokenizer and flips use_medgemma off
        so synthesize() falls back to the rule-based path.
        """
        import os
        # MedGemma is gated on the Hub; token is read from the environment.
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print("Warning: HF_TOKEN not set. MedGemma is a gated model and requires authentication.")
            print("Set HF_TOKEN environment variable with your HuggingFace token.")

        # Determine models to try — if user explicitly passed a model_name,
        # only try that one; otherwise try all candidates in preference order.
        models_to_try = [self._user_model_name] if self._user_model_name else self.MEDGEMMA_MODEL_IDS

        # Build quantization config for 4-bit loading
        # (only meaningful on CUDA with bitsandbytes installed).
        bnb_config = None
        if self.use_4bit and self.device == "cuda" and HAS_BITSANDBYTES:
            try:
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_compute_dtype=torch.float16,
                )
                print("4-bit quantization enabled (NF4 + double quant)")
            except Exception as e:
                print(f"Warning: Could not create BitsAndBytes config: {e}")
                bnb_config = None

        for candidate_model in models_to_try:
            try:
                print(f"Loading MedGemma model: {candidate_model}")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    candidate_model, token=hf_token
                )

                load_kwargs = {
                    "token": hf_token,
                    # device_map="auto" lets accelerate shard across GPUs.
                    "device_map": "auto" if self.device == "cuda" else None,
                }

                if bnb_config is not None:
                    # 4-bit quantized loading (~2GB VRAM)
                    load_kwargs["quantization_config"] = bnb_config
                else:
                    # Standard loading with fp16/fp32
                    load_kwargs["torch_dtype"] = (
                        torch.float16 if self.device == "cuda" else torch.float32
                    )

                self.model = AutoModelForCausalLM.from_pretrained(
                    candidate_model, **load_kwargs
                )

                # Quantized models manage placement themselves; only move
                # a plain CPU load explicitly.
                if self.device == "cpu" and bnb_config is None:
                    self.model = self.model.to(self.device)

                self.model_name = candidate_model
                self._medgemma_available = True
                quant_status = "4-bit NF4" if bnb_config is not None else "fp16/fp32"
                print(f"MedGemma loaded successfully: {candidate_model} ({quant_status})")
                return

            except Exception as e:
                # Try the next candidate (e.g. gated access denied, OOM).
                print(f"Warning: Could not load {candidate_model}: {e}")
                continue

        print("Could not load any MedGemma model. Falling back to rule-based synthesis.")
        self.model = None
        self.tokenizer = None
        self.use_medgemma = False
        self._medgemma_available = False
|
| 168 |
+
|
| 169 |
+
@staticmethod
|
| 170 |
+
def _sanitize(value: object) -> str:
|
| 171 |
+
"""Sanitize a value for safe inclusion in a prompt.
|
| 172 |
+
|
| 173 |
+
Strips control characters and truncates excessively long strings to
|
| 174 |
+
prevent prompt injection via adversarial findings.
|
| 175 |
+
"""
|
| 176 |
+
text = str(value) if value is not None else "N/A"
|
| 177 |
+
# Remove characters that could break prompt structure
|
| 178 |
+
text = text.replace("\x00", "").replace("\r", "")
|
| 179 |
+
# Truncate overly long values
|
| 180 |
+
if len(text) > 500:
|
| 181 |
+
text = text[:500] + "..."
|
| 182 |
+
return text
|
| 183 |
+
|
| 184 |
+
    def _build_prompt(self, findings: Dict) -> str:
        """
        Build clinical synthesis prompt for MedGemma.

        Args:
            findings: Dictionary with anemia, jaundice, cry analysis results.
                      May include 'agent_context' and 'agent_reasoning_summary'
                      when called from the agentic workflow engine.

        Returns:
            Formatted prompt for MedGemma
        """
        # Extract findings with safe defaults; free-text fields are sanitized
        # to limit prompt injection from upstream components.
        anemia = findings.get("anemia", {})
        jaundice = findings.get("jaundice", {})
        cry = findings.get("cry", {})
        symptoms = self._sanitize(findings.get("symptoms", "None reported"))
        patient_info = findings.get("patient_info", {})
        agent_context = findings.get("agent_context", {})
        agent_reasoning = self._sanitize(findings.get("agent_reasoning_summary", ""))

        prompt = f"""You are a pediatric health assistant helping community health workers in low-resource settings.

PATIENT INFORMATION:
- Age: {patient_info.get("age", "Not specified")}
- Weight: {patient_info.get("weight", "Not specified")}
- Location: {patient_info.get("location", "Rural health post")}
- Patient Type: {patient_info.get("type", "Not specified")}

ASSESSMENT FINDINGS:

1. ANEMIA SCREENING (Conjunctiva Analysis):
- Result: {"Anemia detected" if anemia.get("is_anemic") else "No anemia detected"}
- Confidence: {anemia.get("confidence", "N/A")}
- Severity: {anemia.get("severity", anemia.get("risk_level", "N/A"))}
- Estimated Hemoglobin: {anemia.get("estimated_hemoglobin", "N/A")} g/dL

2. JAUNDICE SCREENING (Skin Analysis):
- Result: {"Jaundice detected" if jaundice.get("has_jaundice") else "No jaundice detected"}
- Confidence: {jaundice.get("confidence", "N/A")}
- Severity: {jaundice.get("severity", "N/A")}
- Estimated Bilirubin: {jaundice.get("estimated_bilirubin", "N/A")} mg/dL
- Needs Phototherapy: {jaundice.get("needs_phototherapy", "N/A")}

3. CRY ANALYSIS (Audio):
- Result: {"Abnormal cry pattern" if cry.get("is_abnormal") else "Normal cry pattern"}
- Asphyxia Risk: {cry.get("asphyxia_risk", "N/A")}
- Cry Type: {cry.get("cry_type", "N/A")}

4. REPORTED SYMPTOMS:
{symptoms}
"""

        # Add agentic workflow context if available
        if agent_context:
            prompt += f"""
5. MULTI-AGENT ASSESSMENT CONTEXT:
- Triage Score: {agent_context.get("triage_score", "N/A")} (Risk: {agent_context.get("triage_risk", "N/A")})
- Critical Danger Signs: {", ".join(agent_context.get("critical_signs", [])) or "None"}
- WHO IMNCI Classification: {agent_context.get("protocol_classification", "N/A")}
- Applicable Protocols: {", ".join(agent_context.get("applicable_protocols", [])) or "N/A"}
- Referral Decision: {"YES" if agent_context.get("referral_needed") else "NO"} (Urgency: {agent_context.get("referral_urgency", "N/A")})
"""

        if agent_reasoning:
            prompt += f"""
6. AGENT REASONING TRAIL:
{agent_reasoning}
"""

        # Closing instructions: request the structured IMNCI-style sections.
        prompt += """
Based on these findings, provide a clinical assessment following WHO IMNCI protocols:

1. ASSESSMENT SUMMARY (2-3 sentences in simple language)
2. SEVERITY LEVEL (GREEN = routine care, YELLOW = close monitoring, RED = urgent referral)
3. IMMEDIATE ACTIONS for the CHW (bullet points, simple steps)
4. REFERRAL RECOMMENDATION (Yes/No, and if yes, urgency level)
5. FOLLOW-UP PLAN (when to reassess)

Use simple language appropriate for a community health worker with basic training.
Focus on actionable steps they can take immediately.
"""
        return prompt
|
| 267 |
+
|
| 268 |
+
def synthesize(self, findings: Dict) -> Dict:
    """
    Synthesize all findings into clinical recommendations.

    Dispatches to the MedGemma-backed path when the model is enabled and
    loaded; otherwise falls back to deterministic rule-based synthesis.

    Args:
        findings: Dictionary with anemia, jaundice, cry analysis results

    Returns:
        Clinical summary and recommendations
    """
    llm_ready = self.use_medgemma and self.model is not None
    if not llm_ready:
        return self._synthesize_rule_based(findings)
    return self._synthesize_with_medgemma(findings)
|
| 282 |
+
|
| 283 |
+
def _synthesize_with_medgemma(self, findings: Dict) -> Dict:
    """Synthesize using MedGemma model.

    Falls back to rule-based synthesis if generation fails (e.g. CUDA OOM,
    device-side assertion, or any other runtime error).
    """
    try:
        prompt = self._build_prompt(findings)

        encoded = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        # Models sharded with device_map="auto" may place the embedding
        # layer on any device; route inputs there to avoid a CPU/CUDA
        # mismatch at generation time.
        try:
            target_device = self.model.get_input_embeddings().weight.device
        except Exception:
            target_device = self.device
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        n_prompt_tokens = encoded["input_ids"].shape[-1]

        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_new_tokens=500,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
            )

        # Keep only the newly generated continuation (tokens after the prompt).
        completion = self.tokenizer.decode(
            generated[0][n_prompt_tokens:], skip_special_tokens=True
        ).strip()

        # Guard against empty or degenerate (very short) responses.
        if len(completion) < 20:
            return self._synthesize_rule_based(findings)

        # Human-readable model name for the report header.
        display_name = "MedGemma 1.5 4B" if "1.5" in self.model_name else "MedGemma 4B"

        return {
            "summary": completion,
            "model": display_name,
            "model_id": self.model_name,
            "generated_at": datetime.now().isoformat(),
            "findings_used": list(findings.keys()),
        }
    except Exception as e:
        print(f"MedGemma generation failed: {e}. Falling back to rule-based synthesis.")
        # Disable MedGemma to avoid repeated CUDA errors that corrupt the
        # device context and break subsequent GPU operations.
        self.use_medgemma = False
        self._medgemma_available = False
        self.model = None
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        return self._synthesize_rule_based(findings)
|
| 345 |
+
|
| 346 |
+
def _synthesize_rule_based(self, findings: Dict) -> Dict:
|
| 347 |
+
"""
|
| 348 |
+
Rule-based clinical synthesis (fallback when MedGemma unavailable).
|
| 349 |
+
|
| 350 |
+
Follows WHO IMNCI protocols for maternal and neonatal care.
|
| 351 |
+
"""
|
| 352 |
+
# Extract findings
|
| 353 |
+
anemia = findings.get("anemia", {})
|
| 354 |
+
jaundice = findings.get("jaundice", {})
|
| 355 |
+
cry = findings.get("cry", {})
|
| 356 |
+
|
| 357 |
+
# Determine overall severity
|
| 358 |
+
severity_score = 0
|
| 359 |
+
urgent_conditions = []
|
| 360 |
+
actions = []
|
| 361 |
+
referral_needed = False
|
| 362 |
+
referral_urgency = "none"
|
| 363 |
+
|
| 364 |
+
# Assess anemia
|
| 365 |
+
if anemia.get("is_anemic"):
|
| 366 |
+
if anemia.get("risk_level") == "high":
|
| 367 |
+
severity_score += 3
|
| 368 |
+
urgent_conditions.append("Severe anemia")
|
| 369 |
+
actions.append("Refer for blood transfusion if Hb < 7 g/dL")
|
| 370 |
+
referral_needed = True
|
| 371 |
+
referral_urgency = "urgent"
|
| 372 |
+
elif anemia.get("risk_level") == "medium":
|
| 373 |
+
severity_score += 2
|
| 374 |
+
urgent_conditions.append("Moderate anemia")
|
| 375 |
+
actions.append("Start iron supplementation")
|
| 376 |
+
actions.append("Schedule blood test within 48 hours")
|
| 377 |
+
else:
|
| 378 |
+
severity_score += 1
|
| 379 |
+
actions.append("Monitor hemoglobin levels")
|
| 380 |
+
actions.append("Encourage iron-rich foods")
|
| 381 |
+
|
| 382 |
+
# Assess jaundice
|
| 383 |
+
if jaundice.get("has_jaundice"):
|
| 384 |
+
if jaundice.get("needs_phototherapy"):
|
| 385 |
+
severity_score += 3
|
| 386 |
+
urgent_conditions.append("Severe jaundice requiring phototherapy")
|
| 387 |
+
actions.append("URGENT: Start phototherapy immediately")
|
| 388 |
+
actions.append("Refer to hospital if phototherapy unavailable")
|
| 389 |
+
referral_needed = True
|
| 390 |
+
referral_urgency = "immediate"
|
| 391 |
+
elif jaundice.get("severity") in ["moderate", "severe"]:
|
| 392 |
+
severity_score += 2
|
| 393 |
+
urgent_conditions.append("Moderate jaundice")
|
| 394 |
+
actions.append("Expose baby to indirect sunlight")
|
| 395 |
+
actions.append("Ensure frequent breastfeeding")
|
| 396 |
+
actions.append("Recheck in 12-24 hours")
|
| 397 |
+
else:
|
| 398 |
+
severity_score += 1
|
| 399 |
+
actions.append("Continue breastfeeding")
|
| 400 |
+
actions.append("Monitor skin color")
|
| 401 |
+
|
| 402 |
+
# Assess cry analysis
|
| 403 |
+
if cry.get("is_abnormal"):
|
| 404 |
+
if cry.get("asphyxia_risk", 0) > 0.6:
|
| 405 |
+
severity_score += 3
|
| 406 |
+
urgent_conditions.append("Signs of birth asphyxia")
|
| 407 |
+
actions.append("URGENT: Check airway, breathing, circulation")
|
| 408 |
+
actions.append("Provide warmth and stimulation")
|
| 409 |
+
actions.append("Immediate referral for evaluation")
|
| 410 |
+
referral_needed = True
|
| 411 |
+
referral_urgency = "immediate"
|
| 412 |
+
else:
|
| 413 |
+
severity_score += 1
|
| 414 |
+
actions.append("Monitor cry patterns")
|
| 415 |
+
actions.append("Assess feeding and alertness")
|
| 416 |
+
|
| 417 |
+
# Determine overall severity level
|
| 418 |
+
if severity_score >= 5 or referral_urgency == "immediate":
|
| 419 |
+
severity_level = "RED"
|
| 420 |
+
summary = f"URGENT ATTENTION NEEDED. {', '.join(urgent_conditions)}. Immediate medical intervention required."
|
| 421 |
+
elif severity_score >= 2:
|
| 422 |
+
severity_level = "YELLOW"
|
| 423 |
+
summary = f"Close monitoring required. {', '.join(urgent_conditions) if urgent_conditions else 'Some abnormal findings detected'}. Follow recommended actions."
|
| 424 |
+
else:
|
| 425 |
+
severity_level = "GREEN"
|
| 426 |
+
summary = "Routine care. No immediate concerns detected. Continue standard monitoring."
|
| 427 |
+
|
| 428 |
+
# Default actions if none specified
|
| 429 |
+
if not actions:
|
| 430 |
+
actions = [
|
| 431 |
+
"Continue routine care",
|
| 432 |
+
"Ensure adequate nutrition",
|
| 433 |
+
"Schedule follow-up in 1 week",
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
# Follow-up plan
|
| 437 |
+
if severity_level == "RED":
|
| 438 |
+
follow_up = "Immediate referral. Follow up after hospital evaluation."
|
| 439 |
+
elif severity_level == "YELLOW":
|
| 440 |
+
follow_up = "Reassess in 24-48 hours. Refer if condition worsens."
|
| 441 |
+
else:
|
| 442 |
+
follow_up = "Routine follow-up in 1-2 weeks."
|
| 443 |
+
|
| 444 |
+
return {
|
| 445 |
+
"summary": summary,
|
| 446 |
+
"severity_level": severity_level,
|
| 447 |
+
"severity_description": self.SEVERITY_LEVELS[severity_level],
|
| 448 |
+
"immediate_actions": actions,
|
| 449 |
+
"referral_needed": referral_needed,
|
| 450 |
+
"referral_urgency": referral_urgency,
|
| 451 |
+
"follow_up": follow_up,
|
| 452 |
+
"urgent_conditions": urgent_conditions,
|
| 453 |
+
"model": "Rule-based (WHO IMNCI)",
|
| 454 |
+
"generated_at": datetime.now().isoformat(),
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
def get_who_protocol(self, condition: str) -> Dict:
    """
    Get WHO IMNCI protocol for a specific condition.

    Args:
        condition: Condition name (anemia, jaundice, asphyxia);
            matched case-insensitively.

    Returns:
        Protocol details dict, or an error payload for unknown conditions.
    """
    anemia_protocol = {
        "name": "Maternal Anemia Management",
        "source": "WHO IMNCI Guidelines",
        "steps": [
            "Assess pallor of conjunctiva, palms, and nail beds",
            "If severe pallor: Urgent referral",
            "If some pallor: Iron supplementation + folic acid",
            "Counsel on iron-rich foods",
            "Follow up in 4 weeks",
        ],
        "referral_criteria": "Hb < 7 g/dL or severe pallor with symptoms",
    }
    jaundice_protocol = {
        "name": "Neonatal Jaundice Management",
        "source": "WHO IMNCI Guidelines",
        "steps": [
            "Check for yellow skin/eyes within first 24 hours",
            "If jaundice in first 24 hours: URGENT referral",
            "If moderate jaundice: Frequent breastfeeding, sun exposure",
            "If bilirubin > 15 mg/dL: Phototherapy",
            "If bilirubin > 25 mg/dL: Exchange transfusion",
        ],
        "referral_criteria": "Jaundice < 24 hours old, bilirubin > 20 mg/dL",
    }
    asphyxia_protocol = {
        "name": "Birth Asphyxia Management",
        "source": "WHO Neonatal Resuscitation Guidelines",
        "steps": [
            "Assess APGAR score at 1 and 5 minutes",
            "Clear airway if needed",
            "Provide warmth and stimulation",
            "If not breathing: Begin resuscitation",
            "Refer for evaluation if abnormal cry or poor feeding",
        ],
        "referral_criteria": "APGAR < 7, abnormal cry, seizures, poor feeding",
    }
    lookup = {
        "anemia": anemia_protocol,
        "jaundice": jaundice_protocol,
        "asphyxia": asphyxia_protocol,
    }
    return lookup.get(condition.lower(), {"error": "Protocol not found"})
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def test_synthesizer():
    """Smoke-test the clinical synthesizer with the rule-based backend."""
    print("Testing Clinical Synthesizer...")

    # MedGemma disabled so the test is deterministic and runs offline.
    synthesizer = ClinicalSynthesizer(use_medgemma=False)  # Use rule-based for testing

    # A mixed case: moderate anemia + mild jaundice + normal cry.
    anemia_finding = {
        "is_anemic": True,
        "confidence": 0.85,
        "risk_level": "medium",
        "estimated_hemoglobin": 9.5,
    }
    jaundice_finding = {
        "has_jaundice": True,
        "confidence": 0.75,
        "severity": "mild",
        "estimated_bilirubin": 8.5,
        "needs_phototherapy": False,
    }
    cry_finding = {
        "is_abnormal": False,
        "asphyxia_risk": 0.2,
        "cry_type": "hunger",
    }
    findings = {
        "anemia": anemia_finding,
        "jaundice": jaundice_finding,
        "cry": cry_finding,
        "symptoms": "Mother reports baby seems tired after feeding",
    }

    result = synthesizer.synthesize(findings)

    print("\n=== Clinical Synthesis Result ===")
    print(f"Summary: {result['summary']}")
    print(f"Severity: {result.get('severity_level', 'N/A')}")
    print(f"Referral Needed: {result.get('referral_needed', 'N/A')}")
    print(f"Actions: {result.get('immediate_actions', [])}")
    print(f"Follow-up: {result.get('follow_up', 'N/A')}")


if __name__ == "__main__":
    test_synthesizer()
|
src/nexus/cry_analyzer.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cry Analyzer Module
|
| 3 |
+
|
| 4 |
+
Uses HeAR from Google HAI-DEF for infant cry analysis and birth asphyxia detection.
|
| 5 |
+
Implements embedding extraction + linear classifier per NEXUS_MASTER_PLAN.md.
|
| 6 |
+
|
| 7 |
+
HAI-DEF Model: HeAR (Health Acoustic Representations)
|
| 8 |
+
Source: https://github.com/Google-Health/google-health/tree/master/health_acoustic_representations
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import numpy as np
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 16 |
+
import warnings
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import librosa
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
HAS_AUDIO = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
HAS_AUDIO = False
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from sklearn.linear_model import LogisticRegression
|
| 28 |
+
import joblib
|
| 29 |
+
HAS_SKLEARN = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
HAS_SKLEARN = False
|
| 32 |
+
|
| 33 |
+
# HeAR PyTorch via HuggingFace
|
| 34 |
+
try:
|
| 35 |
+
from transformers import AutoModel as HearAutoModel
|
| 36 |
+
HAS_HEAR_PYTORCH = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
HAS_HEAR_PYTORCH = False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CryAnalyzer:
|
| 42 |
+
"""
|
| 43 |
+
Analyzes infant cry audio for birth asphyxia detection using HeAR.
|
| 44 |
+
|
| 45 |
+
HAI-DEF Model: HeAR (google/hear-pytorch)
|
| 46 |
+
Method: Embedding extraction + acoustic feature analysis
|
| 47 |
+
|
| 48 |
+
Process:
|
| 49 |
+
1. Split audio into 2-second chunks (HeAR requirement)
|
| 50 |
+
2. Extract HeAR embeddings (512-dim per chunk)
|
| 51 |
+
3. Aggregate embeddings (mean pooling)
|
| 52 |
+
4. Classify with trained linear model or rule-based fallback
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
# HeAR model configuration
|
| 56 |
+
SAMPLE_RATE = 16000 # Hz - HeAR requires 16kHz
|
| 57 |
+
CHUNK_DURATION = 2.0 # seconds - HeAR chunk size
|
| 58 |
+
CHUNK_SIZE = 32000 # samples (2 seconds at 16kHz)
|
| 59 |
+
EMBEDDING_DIM = 512 # HeAR embedding dimension
|
| 60 |
+
|
| 61 |
+
# Acoustic feature thresholds (fallback if HeAR unavailable)
|
| 62 |
+
NORMAL_F0_RANGE = (250, 450) # Hz
|
| 63 |
+
ASPHYXIA_F0_THRESHOLD = 500 # Hz - higher F0 indicates distress
|
| 64 |
+
MIN_CRY_DURATION = 0.5 # seconds
|
| 65 |
+
|
| 66 |
+
# HeAR model ID on HuggingFace (PyTorch)
|
| 67 |
+
HEAR_MODEL_ID = "google/hear-pytorch"
|
| 68 |
+
|
| 69 |
+
# Default classifier path (relative to project root)
|
| 70 |
+
DEFAULT_CLASSIFIER_PATHS = [
|
| 71 |
+
Path(__file__).parent.parent.parent / "models" / "linear_probes" / "cry_classifier.joblib",
|
| 72 |
+
Path("models/linear_probes/cry_classifier.joblib"),
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
# Cry type labels from trained classifier
|
| 76 |
+
CRY_TYPE_LABELS = {
|
| 77 |
+
0: "belly_pain",
|
| 78 |
+
1: "burping",
|
| 79 |
+
2: "discomfort",
|
| 80 |
+
3: "hungry",
|
| 81 |
+
4: "tired",
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
def __init__(
    self,
    device: Optional[str] = None,
    classifier_path: Optional[str] = None,
    use_hear: bool = True,
):
    """
    Initialize the Cry Analyzer with HeAR.

    Args:
        device: Device to run model on
        classifier_path: Path to trained linear classifier (optional, auto-detected)
        use_hear: Whether to use HeAR embeddings (True) or acoustic features (False)

    Raises:
        ImportError: If the audio stack (librosa/soundfile) is missing.
    """
    if not HAS_AUDIO:
        raise ImportError("librosa and soundfile required. Install with: pip install librosa soundfile")

    # Prefer CUDA when available unless the caller pinned a device.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device
    self.classifier_path = classifier_path
    self.classifier = None
    self.hear_model = None
    self.use_hear = use_hear
    self._hear_available = False

    # Attempt HeAR load only when requested; failures fall back silently
    # to acoustic-feature mode (flag stays False).
    if use_hear:
        self._load_hear_model()

    # Trained classifier: explicit path takes precedence, then defaults.
    self._load_classifier(classifier_path)

    mode = "HeAR" if self._hear_available else "Acoustic Features (HeAR unavailable)"
    classifier_status = "with trained classifier" if self.classifier else "heuristic scoring"
    print(f"Cry Analyzer (HAI-DEF {mode}, {classifier_status}) initialized on {self.device}")
|
| 118 |
+
|
| 119 |
+
def _load_classifier(self, classifier_path: Optional[str] = None) -> None:
    """Load trained cry classifier from file.

    Searches explicit path first, then default locations. Silently does
    nothing when scikit-learn/joblib are unavailable.
    """
    if not HAS_SKLEARN:
        return

    candidates = []
    if classifier_path:
        candidates.append(Path(classifier_path))
    candidates.extend(self.DEFAULT_CLASSIFIER_PATHS)

    for candidate in candidates:
        if not candidate.exists():
            continue
        try:
            self.classifier = joblib.load(candidate)
            self.classifier_path = str(candidate)
            print(f"Loaded cry classifier from {candidate}")
            return
        except Exception as e:
            # A corrupt file shouldn't abort the search; try the next path.
            print(f"Warning: Could not load classifier from {candidate}: {e}")
|
| 141 |
+
|
| 142 |
+
def _load_hear_model(self) -> None:
    """Load HeAR model from HuggingFace (PyTorch).

    HeAR (Health Acoustic Representations) is a Google HAI-DEF model
    for health-related audio analysis. It produces 512-dimensional
    embeddings from 2-second audio chunks at 16kHz.
    """
    if not HAS_HEAR_PYTORCH:
        print("Warning: transformers not available. Install with: pip install transformers")
        print("Falling back to acoustic feature extraction (deterministic)")
        self._hear_available = False
        return

    # Gated model: an HF token may be required for download.
    token = os.environ.get("HF_TOKEN")

    try:
        print(f"Loading HeAR model from HuggingFace: {self.HEAR_MODEL_ID}")
        model = HearAutoModel.from_pretrained(
            self.HEAR_MODEL_ID,
            token=token,
            trust_remote_code=True,
        )
        self.hear_model = model.to(self.device)
        self.hear_model.eval()
        self._hear_available = True
        print("HeAR model loaded successfully (PyTorch)")

    except Exception as e:
        # Network/auth/device failures all degrade to acoustic features.
        print(f"Warning: Could not load HeAR model: {e}")
        print("Falling back to acoustic feature extraction (deterministic)")
        self.hear_model = None
        self._hear_available = False
|
| 174 |
+
|
| 175 |
+
def _split_audio_chunks(self, audio: np.ndarray) -> List[np.ndarray]:
|
| 176 |
+
"""
|
| 177 |
+
Split audio into 2-second chunks for HeAR processing.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
audio: Audio signal array (16kHz)
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
List of audio chunks (each 2 seconds / 32000 samples)
|
| 184 |
+
"""
|
| 185 |
+
chunks = []
|
| 186 |
+
for i in range(0, len(audio), self.CHUNK_SIZE):
|
| 187 |
+
chunk = audio[i:i + self.CHUNK_SIZE]
|
| 188 |
+
if len(chunk) < self.CHUNK_SIZE:
|
| 189 |
+
# Pad with zeros if needed
|
| 190 |
+
chunk = np.pad(chunk, (0, self.CHUNK_SIZE - len(chunk)))
|
| 191 |
+
chunks.append(chunk)
|
| 192 |
+
return chunks
|
| 193 |
+
|
| 194 |
+
def extract_hear_embeddings(self, audio: np.ndarray) -> np.ndarray:
    """
    Extract HeAR embeddings from audio using PyTorch.

    HeAR is a ViT model that expects mel-PCEN spectrograms, not raw audio.
    Pipeline: raw audio (32000 samples) -> preprocess_audio() -> (1, 1, 192, 128)
    -> ViT forward pass -> pool last_hidden_state -> embedding

    Args:
        audio: Audio signal (16kHz)

    Returns:
        Aggregated embedding (HeAR hidden_size dim, or 8-dim fallback)
    """
    if not self._hear_available or self.hear_model is None:
        # Deterministic fallback: build a small pseudo-embedding from
        # classic acoustic features (same audio always gives same vector).
        acoustic = self.extract_features(audio, self.SAMPLE_RATE)
        return np.array([
            acoustic.get("f0_mean", 0),
            acoustic.get("f0_std", 0),
            acoustic.get("f0_range", 0),
            acoustic.get("voiced_ratio", 0),
            acoustic.get("spectral_centroid_mean", 0),
            acoustic.get("spectral_bandwidth_mean", 0),
            acoustic.get("zcr_mean", 0),
            acoustic.get("rms_mean", 0),
        ])

    from .hear_preprocessing import preprocess_audio

    chunk_embeddings = []
    with torch.no_grad():
        # Process each 2-second window independently through HeAR.
        for chunk in self._split_audio_chunks(audio):
            # Raw chunk -> (1, 32000) float tensor on the model device.
            wav = torch.tensor(chunk.astype(np.float32)).unsqueeze(0).to(self.device)

            # Raw audio -> mel-PCEN spectrogram (1, 1, 192, 128).
            pixels = preprocess_audio(wav)

            # HeAR is a ViT: feed the spectrogram as pixel_values.
            out = self.hear_model(pixel_values=pixels, return_dict=True)

            # Pull an embedding from whichever output the model exposes.
            if hasattr(out, 'pooler_output') and out.pooler_output is not None:
                emb = out.pooler_output
            elif hasattr(out, 'last_hidden_state'):
                # Mean pool over sequence dimension (skip CLS token)
                emb = out.last_hidden_state[:, 1:, :].mean(dim=1)
            elif isinstance(out, torch.Tensor):
                emb = out
            else:
                emb = list(out.values())[0] if hasattr(out, 'values') else out[0]

            chunk_embeddings.append(emb.cpu().numpy().squeeze())

    # Mean-pool chunk embeddings into a single clip-level vector.
    return np.mean(chunk_embeddings, axis=0)
|
| 264 |
+
|
| 265 |
+
def load_audio(
    self,
    audio_path: Union[str, Path],
    sr: Optional[int] = None,
) -> Tuple[np.ndarray, int]:
    """
    Load an audio file, resampled to the target rate.

    Args:
        audio_path: Path to audio file
        sr: Target sample rate; defaults to SAMPLE_RATE (16 kHz) when None.
            (Fix: annotation was ``int`` with a ``None`` default — now
            correctly ``Optional[int]``.)

    Returns:
        Tuple of (audio_array, sample_rate)
    """
    sr = sr or self.SAMPLE_RATE
    # librosa resamples to `sr` on load, so the effective rate is `sr`,
    # not the file's native rate.
    audio, _file_sr = librosa.load(audio_path, sr=sr)
    return audio, sr
|
| 283 |
+
|
| 284 |
+
def extract_features(self, audio: np.ndarray, sr: int) -> Dict:
    """
    Extract acoustic features from cry audio.

    Features based on cry analysis literature:
    - Fundamental frequency (F0)
    - MFCCs (mel-frequency cepstral coefficients)
    - Spectral features
    - Temporal features

    Args:
        audio: Audio signal array
        sr: Sample rate

    Returns:
        Dictionary of extracted features
    """
    features: Dict = {}

    # Pad clips shorter than the minimum analyzable cry duration.
    min_samples = int(sr * self.MIN_CRY_DURATION)
    if len(audio) < sr * self.MIN_CRY_DURATION:
        audio = np.pad(audio, (0, min_samples - len(audio)))

    features["duration"] = len(audio) / sr

    # Fundamental frequency (F0) via probabilistic YIN; librosa warns on
    # short/unvoiced signals, so suppress those warnings locally.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        f0, voiced_flag, voiced_probs = librosa.pyin(
            audio,
            fmin=librosa.note_to_hz('C2'),
            fmax=librosa.note_to_hz('C7'),
            sr=sr,
        )

    # F0 statistics over voiced frames only (pyin yields NaN when unvoiced).
    voiced_f0 = f0[~np.isnan(f0)]
    if voiced_f0.size > 0:
        features["f0_mean"] = float(np.mean(voiced_f0))
        features["f0_std"] = float(np.std(voiced_f0))
        features["f0_min"] = float(np.min(voiced_f0))
        features["f0_max"] = float(np.max(voiced_f0))
        features["f0_range"] = features["f0_max"] - features["f0_min"]
    else:
        for key in ("f0_mean", "f0_std", "f0_min", "f0_max", "f0_range"):
            features[key] = 0

    # Fraction of frames that are voiced (cry vs silence).
    features["voiced_ratio"] = float(np.mean(voiced_flag))

    # 13 MFCCs, summarized by per-coefficient mean and std.
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    for idx in range(13):
        features[f"mfcc_{idx}_mean"] = float(np.mean(mfccs[idx]))
        features[f"mfcc_{idx}_std"] = float(np.std(mfccs[idx]))

    # Spectral shape descriptors.
    centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
    bandwidth = librosa.feature.spectral_bandwidth(y=audio, sr=sr)
    rolloff = librosa.feature.spectral_rolloff(y=audio, sr=sr)
    features["spectral_centroid_mean"] = float(np.mean(centroid))
    features["spectral_bandwidth_mean"] = float(np.mean(bandwidth))
    features["spectral_rolloff_mean"] = float(np.mean(rolloff))

    # Zero crossing rate (higher in noisy/irregular cries).
    zcr = librosa.feature.zero_crossing_rate(audio)
    features["zcr_mean"] = float(np.mean(zcr))
    features["zcr_std"] = float(np.std(zcr))

    # RMS energy envelope statistics.
    rms = librosa.feature.rms(y=audio)
    features["rms_mean"] = float(np.mean(rms))
    features["rms_std"] = float(np.std(rms))

    # Tempo estimate from the onset envelope (cry rhythm).
    onset_env = librosa.onset.onset_strength(y=audio, sr=sr)
    tempo = librosa.feature.tempo(onset_envelope=onset_env, sr=sr)
    features["tempo"] = float(tempo[0]) if len(tempo) > 0 else 0

    return features
|
| 370 |
+
|
| 371 |
+
def analyze(self, audio_path: Union[str, Path]) -> Dict:
    """Analyze a cry recording for health indicators.

    Prefers HeAR embeddings + a trained classifier; falls back to
    rule-based acoustic analysis when HeAR is unavailable.

    Args:
        audio_path: Path to the cry audio file.

    Returns:
        Dictionary containing:
        - is_abnormal: Boolean indicating abnormal cry
        - asphyxia_risk: Risk score for birth asphyxia (0-1)
        - cry_type: Detected cry type
        - features: Selected acoustic features
        - risk_level: "low", "medium", "high"
        - recommendation: Clinical recommendation
    """
    audio, sr = self.load_audio(audio_path)

    # Acoustic features are always needed (cry_type fallback + reporting).
    features = self.extract_features(audio, sr)
    cry_type = self._classify_cry_type(features)

    # Prefer the HeAR pathway whenever it can run; its classifier label
    # (when present) overrides the rule-based cry type.
    hear_ready = self._hear_available or (self.classifier is not None and HAS_SKLEARN)
    if hear_ready:
        asphyxia_risk, model_used, classified_cry_type = self._analyze_with_hear(audio)
        if classified_cry_type is not None:
            cry_type = classified_cry_type
    else:
        asphyxia_risk, model_used = self._analyze_with_rules(features)

    # Map the continuous risk score onto a triage level + recommendation.
    if asphyxia_risk > 0.6:
        risk_level, is_abnormal = "high", True
        recommendation = "URGENT: High-pitched abnormal cry detected. Assess for birth asphyxia immediately. Check APGAR score and vital signs."
    elif asphyxia_risk > 0.3:
        risk_level, is_abnormal = "medium", True
        recommendation = "CAUTION: Some abnormal cry characteristics. Monitor closely and reassess in 30 minutes."
    else:
        risk_level, is_abnormal = "low", False
        recommendation = "Normal cry pattern. Continue routine care."

    return {
        "is_abnormal": is_abnormal,
        "asphyxia_risk": round(asphyxia_risk, 3),
        "cry_type": cry_type,
        "risk_level": risk_level,
        "recommendation": recommendation,
        "features": {
            "f0_mean": round(features["f0_mean"], 1),
            "f0_std": round(features["f0_std"], 1),
            "duration": round(features["duration"], 2),
            "voiced_ratio": round(features["voiced_ratio"], 2),
        },
        "model": model_used,
        "model_note": self._get_model_note(model_used),
    }
|
| 439 |
+
|
| 440 |
+
def _analyze_with_hear(self, audio: np.ndarray) -> Tuple[float, str, Optional[str]]:
    """
    Analyze cry using HeAR embeddings.

    Args:
        audio: Audio signal array (16kHz).

    Returns:
        Tuple of (asphyxia_risk, model_name, predicted_cry_type);
        predicted_cry_type is None when no trained classifier is available.
    """
    # Extract HeAR embeddings
    embeddings = self.extract_hear_embeddings(audio)

    # Use trained classifier if available
    if self.classifier is not None and HAS_SKLEARN:
        embeddings_2d = embeddings.reshape(1, -1)

        # Multi-class cry type classification
        prediction = int(self.classifier.predict(embeddings_2d)[0])
        predicted_type = self.CRY_TYPE_LABELS.get(prediction, "unknown")

        if hasattr(self.classifier, 'predict_proba'):
            proba = self.classifier.predict_proba(embeddings_2d)[0]

            # Derive asphyxia risk from cry type probabilities.
            # NOTE(review): the original comment said "pain and belly_pain"
            # but the hard-coded indices map belly_pain -> 0 and
            # discomfort -> 2; confirm they match CRY_TYPE_LABELS ordering.
            pain_classes = {"belly_pain": 0, "discomfort": 2}
            distress_prob = sum(
                proba[idx] for name, idx in pain_classes.items()
                if idx < len(proba)
            )
            # Scale distress probability to asphyxia risk
            asphyxia_risk = min(1.0, distress_prob * 0.8)
        else:
            # No probability support: coarse risk from the hard label only.
            # (Removed the unused `confidence` locals that were assigned in
            # both branches but never read or returned.)
            asphyxia_risk = 0.5 if predicted_type in ("belly_pain", "discomfort") else 0.2

        return asphyxia_risk, "HeAR + Classifier", predicted_type

    # No classifier: use embedding-based heuristic on gross embedding stats.
    embedding_mean = float(np.mean(embeddings))
    embedding_std = float(np.std(embeddings))
    embedding_max = float(np.max(np.abs(embeddings)))

    risk_score = 0.0
    if embedding_std > 0.5:
        risk_score += 0.3
    if embedding_max > 2.0:
        risk_score += 0.2
    if abs(embedding_mean) > 0.3:
        risk_score += 0.2

    return min(risk_score, 1.0), "HeAR (uncalibrated)", None
|
| 495 |
+
|
| 496 |
+
def _analyze_with_rules(self, features: Dict) -> Tuple[float, str]:
|
| 497 |
+
"""
|
| 498 |
+
Analyze cry using rule-based acoustic features.
|
| 499 |
+
|
| 500 |
+
Fallback when HeAR is unavailable.
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
features: Extracted acoustic features
|
| 504 |
+
|
| 505 |
+
Returns:
|
| 506 |
+
Tuple of (asphyxia_risk, model_name)
|
| 507 |
+
"""
|
| 508 |
+
# Rule-based asphyxia risk assessment
|
| 509 |
+
# Based on medical literature on cry acoustics
|
| 510 |
+
asphyxia_indicators = 0
|
| 511 |
+
max_indicators = 5
|
| 512 |
+
|
| 513 |
+
# High F0 (> 500 Hz) is associated with asphyxia
|
| 514 |
+
if features["f0_mean"] > self.ASPHYXIA_F0_THRESHOLD:
|
| 515 |
+
asphyxia_indicators += 1
|
| 516 |
+
|
| 517 |
+
# High F0 variability
|
| 518 |
+
if features["f0_std"] > 100:
|
| 519 |
+
asphyxia_indicators += 1
|
| 520 |
+
|
| 521 |
+
# Wide F0 range
|
| 522 |
+
if features["f0_range"] > 300:
|
| 523 |
+
asphyxia_indicators += 1
|
| 524 |
+
|
| 525 |
+
# Low voiced ratio (fragmented cry)
|
| 526 |
+
if features["voiced_ratio"] < 0.3:
|
| 527 |
+
asphyxia_indicators += 1
|
| 528 |
+
|
| 529 |
+
# High zero crossing rate (irregular)
|
| 530 |
+
if features["zcr_mean"] > 0.15:
|
| 531 |
+
asphyxia_indicators += 1
|
| 532 |
+
|
| 533 |
+
asphyxia_risk = asphyxia_indicators / max_indicators
|
| 534 |
+
return asphyxia_risk, "Acoustic Features"
|
| 535 |
+
|
| 536 |
+
def _get_model_note(self, model_used: str) -> str:
|
| 537 |
+
"""Get descriptive note for the model used."""
|
| 538 |
+
notes = {
|
| 539 |
+
"HeAR + Classifier": "HAI-DEF HeAR embeddings with trained linear classifier",
|
| 540 |
+
"HeAR (uncalibrated)": "HAI-DEF HeAR embeddings with heuristic scoring (no trained classifier)",
|
| 541 |
+
"Acoustic Features": "Deterministic acoustic feature extraction (HeAR unavailable)",
|
| 542 |
+
}
|
| 543 |
+
return notes.get(model_used, model_used)
|
| 544 |
+
|
| 545 |
+
def _classify_cry_type(self, features: Dict) -> str:
|
| 546 |
+
"""
|
| 547 |
+
Classify cry type based on acoustic features.
|
| 548 |
+
|
| 549 |
+
Categories based on donate-a-cry corpus:
|
| 550 |
+
- hunger: Regular rhythm, moderate pitch
|
| 551 |
+
- pain: High pitch, irregular
|
| 552 |
+
- discomfort: Variable pitch, whimpering
|
| 553 |
+
- tired: Low energy, fragmented
|
| 554 |
+
- belly_pain: High pitch, straining patterns
|
| 555 |
+
"""
|
| 556 |
+
f0_mean = features["f0_mean"]
|
| 557 |
+
f0_std = features["f0_std"]
|
| 558 |
+
rms_mean = features["rms_mean"]
|
| 559 |
+
voiced_ratio = features["voiced_ratio"]
|
| 560 |
+
|
| 561 |
+
# Simple rule-based classification
|
| 562 |
+
if f0_mean > 500 and f0_std > 80:
|
| 563 |
+
return "pain"
|
| 564 |
+
elif f0_mean > 450 and rms_mean > 0.1:
|
| 565 |
+
return "belly_pain"
|
| 566 |
+
elif voiced_ratio < 0.4 and rms_mean < 0.05:
|
| 567 |
+
return "tired"
|
| 568 |
+
elif f0_std < 50 and voiced_ratio > 0.5:
|
| 569 |
+
return "hunger"
|
| 570 |
+
else:
|
| 571 |
+
return "discomfort"
|
| 572 |
+
|
| 573 |
+
def analyze_batch(
    self,
    audio_paths: List[Union[str, Path]],
) -> List[Dict]:
    """
    Analyze several cry audio files, producing one result per input path.

    Files that fail to analyze yield an error entry instead of aborting
    the whole batch.

    Args:
        audio_paths: List of paths to audio files.

    Returns:
        List of analysis results, each tagged with its "file".
    """
    outcomes: List[Dict] = []
    for audio_path in audio_paths:
        try:
            analysis = self.analyze(audio_path)
        except Exception as exc:
            analysis = {
                "file": str(audio_path),
                "error": str(exc),
                "is_abnormal": None,
            }
        else:
            analysis["file"] = str(audio_path)
        outcomes.append(analysis)
    return outcomes
|
| 599 |
+
|
| 600 |
+
def get_spectrogram(
    self,
    audio_path: Union[str, Path],
    n_mels: int = 128,
) -> np.ndarray:
    """
    Compute a mel spectrogram (dB scale) for visualization.

    Args:
        audio_path: Path to the audio file.
        n_mels: Number of mel frequency bands.

    Returns:
        Mel spectrogram array in decibels, referenced to its peak power.
    """
    waveform, sample_rate = self.load_audio(audio_path)

    power_spec = librosa.feature.melspectrogram(
        y=waveform,
        sr=sample_rate,
        n_mels=n_mels,
    )
    # Convert power to dB relative to the loudest bin.
    return librosa.power_to_db(power_spec, ref=np.max)
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def test_analyzer():
    """Smoke-test the cry analyzer against locally available datasets."""
    print("Testing Cry Analyzer...")

    analyzer = CryAnalyzer()

    # Candidate dataset locations relative to the repository root.
    repo_root = Path(__file__).parent.parent.parent
    candidate_dirs = [
        repo_root / "data" / "raw" / "cryceleb" / "audio",
        repo_root / "data" / "raw" / "donate-a-cry",
        repo_root / "data" / "raw" / "infant-cry-dataset" / "cry",
    ]

    # Take up to two WAV files from each directory that exists.
    sample_files = []
    for candidate in candidate_dirs:
        if candidate.exists():
            sample_files.extend(list(candidate.rglob("*.wav"))[:2])

    if not sample_files:
        print("No audio files found. Please download datasets first.")
        return

    for sample in sample_files[:5]:
        print(f"\nAnalyzing: {sample.name}")
        try:
            result = analyzer.analyze(sample)
            print(f" Abnormal cry: {result['is_abnormal']}")
            print(f" Asphyxia risk: {result['asphyxia_risk']:.1%}")
            print(f" Cry type: {result['cry_type']}")
            print(f" Risk level: {result['risk_level']}")
            print(f" F0 mean: {result['features']['f0_mean']} Hz")
        except Exception as e:
            print(f" Error: {e}")
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# Allow running this module directly as a smoke test.
if __name__ == "__main__":
    test_analyzer()
|
src/nexus/hear_preprocessing.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HeAR Audio Preprocessing Module
|
| 3 |
+
|
| 4 |
+
Converts raw audio waveforms into mel-PCEN spectrograms required by the
|
| 5 |
+
HeAR (Health Acoustic Representations) ViT model.
|
| 6 |
+
|
| 7 |
+
Pipeline: raw audio (batch, 32000) → normalize → STFT → power spectrogram
|
| 8 |
+
→ mel filterbank (128 bins) → PCEN → resize → (batch, 1, 192, 128)
|
| 9 |
+
|
| 10 |
+
Adapted from Google's official HeAR preprocessing:
|
| 11 |
+
https://github.com/Google-Health/google-health/tree/master/health_acoustic_representations
|
| 12 |
+
|
| 13 |
+
Copyright 2025 Google LLC (original implementation)
|
| 14 |
+
Licensed under the Apache License, Version 2.0
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Callable, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _enclosing_power_of_two(value: int) -> int:
|
| 25 |
+
"""Smallest power of 2 >= value."""
|
| 26 |
+
return int(2 ** math.ceil(math.log2(value))) if value > 0 else 1
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _compute_stft(
    signals: torch.Tensor,
    frame_length: int,
    frame_step: int,
    fft_length: Optional[int] = None,
    window_fn: Optional[Callable[[int], torch.Tensor]] = torch.hann_window,
    pad_end: bool = True,
) -> torch.Tensor:
    """Short-time Fourier Transform.

    Args:
        signals: [..., samples] real-valued tensor.
        frame_length: Window length in samples.
        frame_step: Step size in samples.
        fft_length: FFT size (defaults to smallest power of 2 >= frame_length).
        window_fn: Window function (default: Hann).
        pad_end: Pad signal end with zeros.

    Returns:
        [..., frames, fft_length//2 + 1] complex64 tensor.
    """
    if signals.ndim < 1:
        raise ValueError(f"Input signals must have rank >= 1, got {signals.ndim}")

    if fft_length is None:
        fft_length = _enclosing_power_of_two(frame_length)

    if pad_end:
        # Mirror TF's pad_end=True semantics: append just enough zeros so
        # every frame start (multiples of frame_step) has a full window.
        n_frames = (
            math.ceil(signals.shape[-1] / frame_step)
            if signals.shape[-1] > 0
            else 0
        )
        padded_length = (
            max(0, (n_frames - 1) * frame_step + frame_length)
            if n_frames > 0
            else frame_length
        )
        padding_needed = max(0, padded_length - signals.shape[-1])
        if padding_needed > 0:
            signals = F.pad(signals, (0, padding_needed))

    # Sliding windows over the sample axis: [..., frames, frame_length].
    framed_signals = signals.unfold(-1, frame_length, frame_step)

    if framed_signals.shape[-2] == 0:
        # No complete frames: return an empty spectrogram of the right shape.
        return torch.empty(
            *signals.shape[:-1],
            0,
            fft_length // 2 + 1,
            dtype=torch.complex64,
            device=signals.device,
        )

    if window_fn is not None:
        # Apply the analysis window per frame (broadcast over frames).
        window = (
            window_fn(frame_length)
            .to(framed_signals.device)
            .to(framed_signals.dtype)
        )
        framed_signals = framed_signals * window

    # One-sided FFT over the last axis gives fft_length//2 + 1 bins.
    return torch.fft.rfft(framed_signals, n=fft_length, dim=-1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _ema(
|
| 94 |
+
inputs: torch.Tensor,
|
| 95 |
+
num_channels: int,
|
| 96 |
+
smooth_coef: float,
|
| 97 |
+
initial_state: Optional[torch.Tensor] = None,
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""Exponential Moving Average for PCEN smoothing.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
inputs: (batch, timesteps, channels) tensor.
|
| 103 |
+
num_channels: Number of channels.
|
| 104 |
+
smooth_coef: EMA smoothing coefficient.
|
| 105 |
+
initial_state: Optional (batch, channels) initial state.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
(batch, timesteps, channels) EMA output.
|
| 109 |
+
"""
|
| 110 |
+
batch_size, timesteps, _ = inputs.shape
|
| 111 |
+
|
| 112 |
+
if initial_state is None:
|
| 113 |
+
ema_state = torch.zeros(
|
| 114 |
+
(batch_size, num_channels), dtype=torch.float32, device=inputs.device
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
ema_state = initial_state
|
| 118 |
+
|
| 119 |
+
identity_kernel = (
|
| 120 |
+
torch.eye(num_channels, dtype=torch.float32, device=inputs.device)
|
| 121 |
+
* smooth_coef
|
| 122 |
+
)
|
| 123 |
+
identity_recurrent_kernel = (
|
| 124 |
+
torch.eye(num_channels, dtype=torch.float32, device=inputs.device)
|
| 125 |
+
* (1.0 - smooth_coef)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
output_sequence = []
|
| 129 |
+
start = initial_state is not None
|
| 130 |
+
if start:
|
| 131 |
+
output_sequence.append(ema_state)
|
| 132 |
+
|
| 133 |
+
for t in range(start, timesteps):
|
| 134 |
+
current_input = inputs[:, t, :]
|
| 135 |
+
output = torch.matmul(current_input, identity_kernel) + torch.matmul(
|
| 136 |
+
ema_state, identity_recurrent_kernel
|
| 137 |
+
)
|
| 138 |
+
ema_state = output
|
| 139 |
+
output_sequence.append(output)
|
| 140 |
+
|
| 141 |
+
return torch.stack(output_sequence, dim=1)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _pcen_function(
    inputs: torch.Tensor,
    num_channels: int = 128,
    alpha: float = 0.8,
    smooth_coef: float = 0.04,
    delta: float = 2.0,
    root: float = 2.0,
    floor: float = 1e-8,
) -> torch.Tensor:
    """Per-Channel Energy Normalization.

    See https://arxiv.org/abs/1607.05666
    """
    # Per-channel parameter vectors; alpha is clamped to <= 1 and root to
    # >= 1, as in the reference implementation.
    ones = torch.ones(num_channels).to(inputs.device).to(inputs.dtype)
    alpha_param = torch.minimum(ones * alpha, ones)
    delta_param = ones * delta
    root_param = torch.maximum(ones * root, ones)

    # Smoothed energy estimate, seeded with the first frame when available.
    ema_smoother = _ema(
        inputs,
        num_channels=num_channels,
        smooth_coef=smooth_coef,
        initial_state=inputs[:, 0] if inputs.ndim > 1 else None,
    ).to(inputs.device)

    one_over_root = 1.0 / root_param
    normalized = inputs / (floor + ema_smoother) ** alpha_param + delta_param
    return normalized ** one_over_root - delta_param ** one_over_root
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _hertz_to_mel(frequencies_hertz: torch.Tensor) -> torch.Tensor:
|
| 182 |
+
"""Convert Hz to mel scale."""
|
| 183 |
+
return 2595.0 * torch.log10(1.0 + frequencies_hertz / 700.0)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _linear_to_mel_weight_matrix(
    device: torch.device,
    num_mel_bins: int = 128,
    num_spectrogram_bins: int = 201,
    sample_rate: float = 16000,
    lower_edge_hertz: float = 0.0,
    upper_edge_hertz: float = 8000.0,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build a triangular mel filterbank matrix.

    Returns:
        [num_spectrogram_bins, num_mel_bins] weight matrix; multiplying a
        linear power spectrogram by it yields mel-band energies.
    """
    zero = torch.tensor(0.0, dtype=dtype, device=device)
    # Fix: create the Nyquist scalar on the target device like every other
    # scalar here (previously it defaulted to CPU, which breaks linspace
    # when a CUDA/MPS device is requested).
    nyquist_hertz = torch.tensor(sample_rate, dtype=dtype, device=device) / 2.0
    lower_edge = torch.tensor(lower_edge_hertz, dtype=dtype, device=device)
    upper_edge = torch.tensor(upper_edge_hertz, dtype=dtype, device=device)

    # Drop the DC bin; it is re-added as an all-zero row at the end.
    bands_to_zero = 1
    linear_frequencies = torch.linspace(
        zero, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device
    )[bands_to_zero:]
    spectrogram_bins_mel = _hertz_to_mel(linear_frequencies).unsqueeze(1)

    # num_mel_bins + 2 edges -> overlapping (lower, center, upper) triples.
    band_edges_mel = torch.linspace(
        _hertz_to_mel(lower_edge),
        _hertz_to_mel(upper_edge),
        num_mel_bins + 2,
        dtype=dtype,
        device=device,
    )
    band_edges_mel = band_edges_mel.unfold(0, 3, 1)

    lower_edge_mel = band_edges_mel[:, 0].unsqueeze(0)
    center_mel = band_edges_mel[:, 1].unsqueeze(0)
    upper_edge_mel = band_edges_mel[:, 2].unsqueeze(0)

    # Rising and falling slopes of each triangular filter.
    lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
        center_mel - lower_edge_mel
    )
    upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
        upper_edge_mel - center_mel
    )

    mel_weights_matrix = torch.maximum(
        zero, torch.minimum(lower_slopes, upper_slopes)
    )

    # Re-insert the zeroed DC row.
    return F.pad(
        mel_weights_matrix, (0, 0, bands_to_zero, 0), mode="constant", value=0.0
    )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _torch_resize_bilinear_tf_compat(
|
| 237 |
+
images: torch.Tensor,
|
| 238 |
+
size: tuple,
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""Bilinear resize matching TF's tf.image.resize behavior.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
images: [C, H, W] or [B, C, H, W] float tensor.
|
| 244 |
+
size: (new_height, new_width).
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Resized tensor with same rank as input.
|
| 248 |
+
"""
|
| 249 |
+
new_height, new_width = size
|
| 250 |
+
images = images.to(torch.float32)
|
| 251 |
+
|
| 252 |
+
was_3d = False
|
| 253 |
+
if images.dim() == 3:
|
| 254 |
+
images = images.unsqueeze(0)
|
| 255 |
+
was_3d = True
|
| 256 |
+
|
| 257 |
+
resized = F.interpolate(
|
| 258 |
+
images,
|
| 259 |
+
size=(new_height, new_width),
|
| 260 |
+
mode="bilinear",
|
| 261 |
+
align_corners=False,
|
| 262 |
+
antialias=False,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if was_3d:
|
| 266 |
+
resized = resized.squeeze(0)
|
| 267 |
+
|
| 268 |
+
return resized
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _mel_pcen(x: torch.Tensor) -> torch.Tensor:
    """Mel spectrogram + PCEN normalization.

    Normalizes the waveform to [-1, 1], computes a power STFT with a 25 ms
    Hann window and 10 ms hop (at 16 kHz), projects onto 128 mel bands, and
    applies PCEN.
    """
    x = x.float()
    # Scale to [-1, 1] using out-of-place ops. Fix: ``.float()`` returns the
    # SAME tensor when the input is already float32, so the previous
    # in-place ``x -= torch.min(x)`` silently mutated the caller's waveform.
    x = x - torch.min(x)
    x = x / (torch.max(x) + 1e-8)
    x = (x * 2) - 1

    frame_length = 16 * 25  # 400 samples = 25 ms at 16 kHz
    frame_step = 160  # 10 ms hop

    stft = _compute_stft(
        x,
        frame_length=frame_length,
        fft_length=frame_length,
        frame_step=frame_step,
        window_fn=torch.hann_window,
        pad_end=True,
    )
    spectrograms = torch.square(torch.abs(stft))

    mel_transform = _linear_to_mel_weight_matrix(x.device)
    mel_spectrograms = torch.matmul(spectrograms, mel_transform)
    return _pcen_function(mel_spectrograms)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def preprocess_audio(audio: torch.Tensor) -> torch.Tensor:
    """Convert raw audio waveform to mel-PCEN spectrogram for HeAR.

    Args:
        audio: [batch, samples] tensor. 2-second clips at 16kHz (32000
            samples). Shorter clips are zero-padded; longer clips are
            rejected.

    Returns:
        [batch, 1, 192, 128] mel-PCEN spectrogram tensor.

    Raises:
        ValueError: If the input is not rank 2 or exceeds 32000 samples.
    """
    if audio.ndim != 2:
        raise ValueError(f"Input audio must have rank 2, got rank {audio.ndim}")

    num_samples = audio.shape[1]
    if num_samples > 32000:
        raise ValueError(
            f"Input audio must have <= 32000 samples, got {audio.shape[1]}"
        )
    if num_samples < 32000:
        # Zero-pad short clips out to exactly 2 seconds.
        audio = F.pad(audio, pad=(0, 32000 - num_samples), mode="constant", value=0)

    spectrogram = _mel_pcen(audio)
    # Add channel dimension: [B, H, W] → [B, 1, H, W]
    spectrogram = torch.unsqueeze(spectrogram, dim=1)
    return _torch_resize_bilinear_tf_compat(spectrogram, size=(192, 128))
|
src/nexus/jaundice_detector.py
ADDED
|
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Jaundice Detector Module
|
| 3 |
+
|
| 4 |
+
Uses MedSigLIP from Google HAI-DEF for jaundice detection from neonatal skin images.
|
| 5 |
+
Implements zero-shot classification with medical text prompts per NEXUS_MASTER_PLAN.md.
|
| 6 |
+
|
| 7 |
+
HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
|
| 8 |
+
Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# Optional dependency: transformers is only required for the MedSigLIP
# zero-shot path; the flag lets the detector degrade gracefully without it.
try:
    from transformers import AutoProcessor, AutoModel
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# HAI-DEF MedSigLIP model IDs to try in order of preference
MEDSIGLIP_MODEL_IDS = [
    "google/medsiglip-448",  # MedSigLIP - official HAI-DEF model
    "google/siglip-base-patch16-224",  # SigLIP 224 - fallback
]
|
| 31 |
+
|
| 32 |
+
class _BilirubinRegressor(nn.Module):
|
| 33 |
+
"""3-layer MLP regression head with BatchNorm for bilirubin prediction (mg/dL).
|
| 34 |
+
|
| 35 |
+
Must match the architecture in scripts/training/finetune_bilirubin_regression.py
|
| 36 |
+
so that saved state_dict keys align.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, input_dim: int = 1152, hidden_dim: int = 256):
|
| 40 |
+
super().__init__()
|
| 41 |
+
mid_dim = hidden_dim * 2 # 512
|
| 42 |
+
self.net = nn.Sequential(
|
| 43 |
+
nn.Linear(input_dim, mid_dim),
|
| 44 |
+
nn.BatchNorm1d(mid_dim),
|
| 45 |
+
nn.ReLU(),
|
| 46 |
+
nn.Dropout(0.3),
|
| 47 |
+
nn.Linear(mid_dim, hidden_dim),
|
| 48 |
+
nn.BatchNorm1d(hidden_dim),
|
| 49 |
+
nn.ReLU(),
|
| 50 |
+
nn.Dropout(0.15),
|
| 51 |
+
nn.Linear(hidden_dim, 1),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
return self.net(x).squeeze(-1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class _BilirubinRegressorV1(nn.Module):
|
| 59 |
+
"""Original 2-layer MLP for backwards compatibility with older checkpoints."""
|
| 60 |
+
|
| 61 |
+
def __init__(self, input_dim: int = 1152, hidden_dim: int = 256):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.net = nn.Sequential(
|
| 64 |
+
nn.Linear(input_dim, hidden_dim),
|
| 65 |
+
nn.ReLU(),
|
| 66 |
+
nn.Dropout(0.3),
|
| 67 |
+
nn.Linear(hidden_dim, 1),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
return self.net(x).squeeze(-1)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class JaundiceDetector:
    """
    Detects neonatal jaundice from skin/sclera images using MedSigLIP.

    Uses zero-shot classification with medical prompts and
    color analysis for bilirubin estimation.

    HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
    Fallback: siglip-base-patch16-224
    """

    # Medical text prompts for zero-shot classification (optimized for MedSigLIP)
    # Expanded with Kramer zone references, skin-tone context, severity gradation.
    # Positive ("jaundice") class; embedded once by _precompute_text_embeddings().
    JAUNDICE_PROMPTS = [
        "newborn with visible yellow discoloration of skin indicating jaundice",
        "neonatal skin showing yellow-orange pigmentation from hyperbilirubinemia",
        "jaundiced infant with icteric sclera and yellow skin tone",
        "baby with yellow skin extending to trunk and limbs Kramer zone 3",
        "neonatal jaundice with deep yellow skin requiring phototherapy",
        "newborn showing yellow staining of skin and conjunctiva from bilirubin",
        "infant with moderate to severe jaundice visible on face and chest",
        "yellow discoloration of neonatal skin consistent with elevated bilirubin",
    ]

    # Negative ("normal") class prompts describing healthy, non-icteric skin.
    NORMAL_PROMPTS = [
        "healthy newborn with normal pink skin color without jaundice",
        "infant with normal skin pigmentation and no yellow discoloration",
        "newborn baby with clear healthy skin and no icterus",
        "normal neonatal skin showing pink to brown coloration without yellowing",
        "healthy baby skin with no signs of hyperbilirubinemia",
        "newborn with well-perfused normal colored skin and clear sclera",
        "infant with healthy natural skin tone and no bilirubin staining",
        "normal newborn skin without yellow or orange discoloration",
    ]

    # Bilirubin risk thresholds (mg/dL). detect()/detect_batch() bucket the
    # estimate with strict "<" comparisons: none < low, mild < moderate,
    # moderate < high, severe < critical, else critical.
    # NOTE(review): the "exchange" threshold is defined but not used by the
    # severity logic in this file — confirm whether it is consumed elsewhere.
    BILIRUBIN_THRESHOLDS = {
        "low": 5.0,  # Normal range
        "moderate": 12.0,  # Monitor closely
        "high": 15.0,  # Consider phototherapy
        "critical": 20.0,  # Urgent phototherapy
        "exchange": 25.0,  # Exchange transfusion territory
    }
|
| 117 |
+
|
| 118 |
+
    def __init__(
        self,
        model_name: Optional[str] = None,  # Auto-select MedSigLIP
        device: Optional[str] = None,
        threshold: float = 0.5,
    ):
        """
        Initialize the Jaundice Detector with MedSigLIP.

        Args:
            model_name: HuggingFace model name (auto-selects HAI-DEF MedSigLIP if None)
            device: Device to run model on (auto-detected if None)
            threshold: Classification threshold for jaundice detection

        Raises:
            ImportError: If the transformers library is not installed.
            RuntimeError: If none of the candidate models could be loaded.
        """
        if not HAS_TRANSFORMERS:
            raise ImportError("transformers library required. Install with: pip install transformers")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self._model_loaded = False
        self.classifier = None  # Can be set by pipeline for trained classification
        self.regressor = None  # Bilirubin regression head (MedSigLIP embeddings -> mg/dL)

        # Determine which models to try: an explicit model_name wins,
        # otherwise fall back to the preference-ordered HAI-DEF list.
        models_to_try = [model_name] if model_name else MEDSIGLIP_MODEL_IDS

        # HuggingFace token for gated models (MedSigLIP requires access approval)
        hf_token = os.environ.get("HF_TOKEN")

        # Try loading models in order of preference; first success wins.
        for candidate_model in models_to_try:
            print(f"Loading HAI-DEF model: {candidate_model}")
            try:
                self.processor = AutoProcessor.from_pretrained(
                    candidate_model, token=hf_token
                )
                self.model = AutoModel.from_pretrained(
                    candidate_model, token=hf_token
                ).to(self.device)
                self.model_name = candidate_model
                self._model_loaded = True
                print(f"Successfully loaded: {candidate_model}")
                break
            except Exception as e:
                # Gated/missing/offline models fall through to the next candidate.
                print(f"Warning: Could not load {candidate_model}: {e}")
                continue

        if not self._model_loaded:
            raise RuntimeError(
                f"Could not load any MedSigLIP model. Tried: {models_to_try}. "
                "Install transformers and ensure internet access."
            )

        self.model.eval()

        # Pre-compute text embeddings (one-time cost; reused for every detect())
        self._precompute_text_embeddings()

        # Try to auto-load trained classifier from the default probe paths
        self._auto_load_classifier()

        # Try to load bilirubin regression model (optional; None when absent)
        self._load_regressor()

        # Indicate which model variant is being used
        is_medsiglip = "medsiglip" in self.model_name
        model_type = "MedSigLIP" if is_medsiglip else "SigLIP (fallback)"
        classifier_status = "trained classifier" if self.classifier else "zero-shot"
        regressor_status = "with regressor" if self.regressor else "color-based only"
        print(f"Jaundice Detector (HAI-DEF {model_type}, {classifier_status}, {regressor_status}) initialized on {self.device}")
|
| 188 |
+
|
| 189 |
+
def _auto_load_classifier(self) -> None:
|
| 190 |
+
"""Auto-load trained jaundice classifier if available."""
|
| 191 |
+
if self.classifier is not None:
|
| 192 |
+
return
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
import joblib
|
| 196 |
+
except ImportError:
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
default_paths = [
|
| 200 |
+
Path(__file__).parent.parent.parent / "models" / "linear_probes" / "jaundice_classifier.joblib",
|
| 201 |
+
Path("models/linear_probes/jaundice_classifier.joblib"),
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
for path in default_paths:
|
| 205 |
+
if path.exists():
|
| 206 |
+
try:
|
| 207 |
+
self.classifier = joblib.load(path)
|
| 208 |
+
print(f"Auto-loaded jaundice classifier from {path}")
|
| 209 |
+
return
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print(f"Warning: Could not load classifier from {path}: {e}")
|
| 212 |
+
|
| 213 |
+
    # Logit temperature for softmax conversion in _classify_zero_shot():
    # cosine similarities (~[-1, 1]) are multiplied by this before softmax.
    # NOTE(review): class attribute defined mid-class, between methods — legal
    # but unconventional; consider grouping with BILIRUBIN_THRESHOLDS above.
    LOGIT_SCALE = 30.0
|
| 215 |
+
|
| 216 |
+
    def _precompute_text_embeddings(self) -> None:
        """Pre-compute text embeddings for zero-shot classification using SigLIP.

        Stores individual prompt embeddings for max-similarity scoring, plus
        re-normalized per-class mean embeddings as a fallback representation.
        """
        # Jaundice prompts first, then normal prompts; split again below.
        all_prompts = self.JAUNDICE_PROMPTS + self.NORMAL_PROMPTS

        with torch.no_grad():
            # padding="max_length" — presumably matches the text encoder's
            # fixed-length training convention; confirm for non-SigLIP models.
            inputs = self.processor(
                text=all_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
            ).to(self.device)

            # Get text embeddings - support multiple output APIs:
            # CLIP-style get_text_features, *ModelOutput attributes, or the
            # raw text submodule's pooler output, in that order.
            if hasattr(self.model, 'get_text_features'):
                text_embeddings = self.model.get_text_features(**inputs)
            else:
                outputs = self.model(**inputs)
                if hasattr(outputs, 'text_embeds'):
                    text_embeddings = outputs.text_embeds
                elif hasattr(outputs, 'text_model_output'):
                    text_embeddings = outputs.text_model_output.pooler_output
                else:
                    text_outputs = self.model.text_model(**inputs)
                    text_embeddings = text_outputs.pooler_output

            # L2-normalize so dot products downstream are cosine similarities.
            text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

            # Store individual embeddings for max-similarity scoring
            n_jaundice = len(self.JAUNDICE_PROMPTS)
            self.jaundice_embeddings_all = text_embeddings[:n_jaundice]  # (N, D)
            self.normal_embeddings_all = text_embeddings[n_jaundice:]  # (M, D)

            # Also keep mean embeddings as fallback; re-normalized because the
            # mean of unit vectors is generally not itself unit-length.
            self.jaundice_embeddings = self.jaundice_embeddings_all.mean(dim=0, keepdim=True)
            self.normal_embeddings = self.normal_embeddings_all.mean(dim=0, keepdim=True)
            self.jaundice_embeddings = self.jaundice_embeddings / self.jaundice_embeddings.norm(dim=-1, keepdim=True)
            self.normal_embeddings = self.normal_embeddings / self.normal_embeddings.norm(dim=-1, keepdim=True)
|
| 256 |
+
|
| 257 |
+
    def _load_regressor(self) -> None:
        """Load trained bilirubin regression head if available.

        Tries the new 3-layer architecture first, falls back to V1 (2-layer).
        Leaves self.regressor as None when no compatible checkpoint is found.
        """
        model_paths = [
            Path(__file__).parent.parent.parent / "models" / "linear_probes" / "bilirubin_regressor.pt",
            Path("models/linear_probes/bilirubin_regressor.pt"),
        ]

        for model_path in model_paths:
            if model_path.exists():
                try:
                    # weights_only=True avoids unpickling arbitrary objects
                    # from an untrusted checkpoint file.
                    checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
                    input_dim = checkpoint.get("input_dim", 1152)
                    hidden_dim = checkpoint.get("hidden_dim", 256)

                    # Try new 3-layer architecture first, then fall back to V1.
                    # An architecture mismatch surfaces as RuntimeError (shape/key
                    # mismatch) or KeyError (missing "model_state_dict").
                    for RegClass in [_BilirubinRegressor, _BilirubinRegressorV1]:
                        try:
                            regressor = RegClass(input_dim, hidden_dim)
                            regressor.load_state_dict(checkpoint["model_state_dict"])
                            regressor.to(self.device)
                            regressor.eval()
                            self.regressor = regressor
                            arch = "v2 (3-layer)" if RegClass is _BilirubinRegressor else "v1 (2-layer)"
                            print(f"Bilirubin regressor ({arch}) loaded from {model_path}")
                            return
                        except (RuntimeError, KeyError):
                            continue

                    print(f"Warning: Regressor checkpoint incompatible at {model_path}")
                except Exception as e:
                    print(f"Warning: Could not load regressor from {model_path}: {e}")
        # Redundant with __init__'s default, but makes the no-checkpoint path explicit.
        self.regressor = None
|
| 292 |
+
|
| 293 |
+
def preprocess_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
|
| 294 |
+
"""Preprocess image for analysis.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
image: Path to image file or PIL Image object.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
PIL Image in RGB mode.
|
| 301 |
+
|
| 302 |
+
Raises:
|
| 303 |
+
ValueError: If the input type is unsupported.
|
| 304 |
+
FileNotFoundError: If the image file does not exist.
|
| 305 |
+
"""
|
| 306 |
+
if isinstance(image, (str, Path)):
|
| 307 |
+
path = Path(image)
|
| 308 |
+
if not path.exists():
|
| 309 |
+
raise FileNotFoundError(f"Image file not found: {path}")
|
| 310 |
+
image = Image.open(path).convert("RGB")
|
| 311 |
+
elif isinstance(image, Image.Image):
|
| 312 |
+
image = image.convert("RGB")
|
| 313 |
+
else:
|
| 314 |
+
raise ValueError(f"Expected str, Path, or PIL Image, got {type(image)}")
|
| 315 |
+
return image
|
| 316 |
+
|
| 317 |
+
def estimate_bilirubin(self, image: Union[str, Path, Image.Image]) -> float:
|
| 318 |
+
"""
|
| 319 |
+
Estimate bilirubin level from image color analysis.
|
| 320 |
+
|
| 321 |
+
This uses the yellow-blue ratio which correlates with
|
| 322 |
+
transcutaneous bilirubin measurements.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
image: Neonatal skin/sclera image
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
Estimated bilirubin in mg/dL
|
| 329 |
+
"""
|
| 330 |
+
pil_image = self.preprocess_image(image)
|
| 331 |
+
img_array = np.array(pil_image).astype(float)
|
| 332 |
+
|
| 333 |
+
# Ensure 3-channel RGB
|
| 334 |
+
if img_array.ndim == 2:
|
| 335 |
+
img_array = np.stack([img_array, img_array, img_array], axis=-1)
|
| 336 |
+
elif img_array.shape[-1] == 1:
|
| 337 |
+
img_array = np.concatenate([img_array] * 3, axis=-1)
|
| 338 |
+
|
| 339 |
+
# Extract color channels
|
| 340 |
+
r = img_array[:, :, 0]
|
| 341 |
+
g = img_array[:, :, 1]
|
| 342 |
+
b = img_array[:, :, 2]
|
| 343 |
+
|
| 344 |
+
# Calculate yellow index (R+G-B correlation with bilirubin)
|
| 345 |
+
# Higher values indicate more yellow (jaundiced)
|
| 346 |
+
yellow_index = (r + g - b) / (r + g + b + 1e-6)
|
| 347 |
+
mean_yellow = np.mean(yellow_index)
|
| 348 |
+
|
| 349 |
+
# Convert to bilirubin estimate
|
| 350 |
+
# Calibrated based on medical literature
|
| 351 |
+
# Normal yellow_index ~ 0.2-0.3, jaundiced ~ 0.4-0.6
|
| 352 |
+
bilirubin_estimate = max(0, (mean_yellow - 0.2) * 50)
|
| 353 |
+
|
| 354 |
+
return round(bilirubin_estimate, 1)
|
| 355 |
+
|
| 356 |
+
    def detect(self, image: Union[str, Path, Image.Image]) -> Dict:
        """
        Detect jaundice from neonatal image.

        Uses trained classifier if available, otherwise falls back to
        zero-shot classification with MedSigLIP.

        Args:
            image: Neonatal skin/sclera image

        Returns:
            Dictionary containing:
            - has_jaundice: Boolean indicating jaundice detection
            - confidence: Confidence score
            - jaundice_score: Raw jaundice probability
            - estimated_bilirubin: Color-analysis estimate (mg/dL)
            - severity: "none", "mild", "moderate", "severe", "critical"
            - needs_phototherapy: Boolean
            - recommendation: Clinical recommendation
            - model / model_type / bilirubin_method: provenance metadata
            - estimated_bilirubin_ml: only present when a regressor is loaded
        """
        pil_image = self.preprocess_image(image)

        # Get image embedding using SigLIP
        with torch.no_grad():
            inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)

            # Get image embeddings - support multiple output APIs:
            # CLIP-style get_image_features, *ModelOutput attributes, or the
            # raw vision submodule's pooler output, in that order.
            if hasattr(self.model, 'get_image_features'):
                image_embedding = self.model.get_image_features(**inputs)
            else:
                outputs = self.model(**inputs)
                if hasattr(outputs, 'image_embeds'):
                    image_embedding = outputs.image_embeds
                elif hasattr(outputs, 'vision_model_output'):
                    image_embedding = outputs.vision_model_output.pooler_output
                else:
                    vision_outputs = self.model.vision_model(**inputs)
                    image_embedding = vision_outputs.pooler_output

            # L2-normalize so downstream dot products are cosine similarities.
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

        # Use trained classifier if available, otherwise zero-shot
        if self.classifier is not None:
            jaundice_prob, model_method = self._classify_with_trained_model(image_embedding)
        else:
            jaundice_prob, model_method = self._classify_zero_shot(image_embedding)

        # Color-based bilirubin estimate (always available)
        estimated_bilirubin = self.estimate_bilirubin(pil_image)

        # ML-based bilirubin estimate from trained regressor on MedSigLIP embeddings
        estimated_bilirubin_ml = None
        if self.regressor is not None:
            with torch.no_grad():
                bilirubin_pred = self.regressor(image_embedding)
                raw_value = float(bilirubin_pred.item())
                # Clamp to physiologically valid range (0-35 mg/dL)
                clamped_value = max(0.0, min(35.0, raw_value))
                estimated_bilirubin_ml = round(clamped_value, 1)

        # Use ML estimate for severity when available, otherwise color-based
        bilirubin_for_severity = estimated_bilirubin_ml if estimated_bilirubin_ml is not None else estimated_bilirubin

        # Determine severity based on bilirubin level; buckets come from
        # BILIRUBIN_THRESHOLDS (note: the "exchange" threshold is not used here).
        if bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["low"]:
            severity = "none"
            needs_phototherapy = False
            recommendation = "No jaundice detected. Continue routine care."
        elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["moderate"]:
            severity = "mild"
            needs_phototherapy = False
            recommendation = "Mild jaundice. Monitor closely and ensure adequate feeding."
        elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["high"]:
            severity = "moderate"
            needs_phototherapy = False
            recommendation = "Moderate jaundice. Recheck in 12-24 hours. Consider phototherapy if rising."
        elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["critical"]:
            severity = "severe"
            needs_phototherapy = True
            recommendation = "URGENT: Start phototherapy. Refer for serum bilirubin confirmation."
        else:
            severity = "critical"
            needs_phototherapy = True
            recommendation = "CRITICAL: Immediate phototherapy required. Consider exchange transfusion."

        is_medsiglip = "medsiglip" in self.model_name
        base_model = "MedSigLIP (HAI-DEF)" if is_medsiglip else "SigLIP (fallback)"

        result = {
            "has_jaundice": jaundice_prob > self.threshold,
            "confidence": max(jaundice_prob, 1 - jaundice_prob),
            "jaundice_score": jaundice_prob,
            "estimated_bilirubin": estimated_bilirubin,
            "severity": severity,
            "needs_phototherapy": needs_phototherapy,
            "recommendation": recommendation,
            "model": self.model_name,
            "model_type": f"{base_model} + {model_method}",
        }

        if estimated_bilirubin_ml is not None:
            result["estimated_bilirubin_ml"] = estimated_bilirubin_ml
            result["bilirubin_method"] = "MedSigLIP Regressor"
        else:
            result["bilirubin_method"] = "Color Analysis"

        return result
|
| 463 |
+
|
| 464 |
+
    def _classify_with_trained_model(self, image_embedding: torch.Tensor) -> Tuple[float, str]:
        """
        Classify using trained classifier on embeddings.

        Supports three classifier flavors, checked in order: sklearn-style
        with predict_proba, sklearn-style with predict only, and a PyTorch
        nn.Module. Anything else falls back to zero-shot classification.

        Args:
            image_embedding: Normalized image embedding from MedSigLIP

        Returns:
            Tuple of (jaundice_prob, method_name)
        """
        # Convert embedding to numpy for sklearn classifiers
        embedding_np = image_embedding.cpu().numpy().reshape(1, -1)

        # Handle different classifier types
        if hasattr(self.classifier, 'predict_proba'):
            # Sklearn classifier with probability support
            proba = self.classifier.predict_proba(embedding_np)
            # Assume binary: [normal, jaundice] or [jaundice, normal]
            if proba.shape[1] >= 2:
                # Check classifier classes to determine column order; the
                # column for label 1 is treated as jaundice-positive.
                if hasattr(self.classifier, 'classes_'):
                    classes = list(self.classifier.classes_)
                    if 1 in classes:
                        jaundice_idx = classes.index(1)
                    else:
                        jaundice_idx = 1  # Default assumption
                else:
                    jaundice_idx = 1
                jaundice_prob = float(proba[0, jaundice_idx])
            else:
                # Single-column probability output
                jaundice_prob = float(proba[0, 0])
            return jaundice_prob, "Trained Classifier"

        elif hasattr(self.classifier, 'predict'):
            # Classifier without probability - use binary prediction.
            # NOTE(review): the hard 0/1 label is reused directly as a
            # pseudo-probability, so "confidence" downstream is always 1.0.
            prediction = self.classifier.predict(embedding_np)
            jaundice_prob = float(prediction[0])
            return jaundice_prob, "Trained Classifier (binary)"

        elif isinstance(self.classifier, nn.Module):
            # PyTorch classifier: softmax over logits, column 1 = jaundice
            # when two or more outputs, otherwise the single column.
            self.classifier.eval()
            with torch.no_grad():
                logits = self.classifier(image_embedding)
                probs = torch.softmax(logits, dim=-1)
                if probs.shape[-1] >= 2:
                    jaundice_prob = probs[0, 1].item()
                else:
                    jaundice_prob = probs[0, 0].item()
            return jaundice_prob, "Trained Classifier (PyTorch)"

        else:
            # Unknown classifier type - fall back to zero-shot
            print(f"Warning: Unknown classifier type {type(self.classifier)}, using zero-shot")
            return self._classify_zero_shot(image_embedding)
|
| 519 |
+
|
| 520 |
+
def _classify_zero_shot(self, image_embedding: torch.Tensor) -> Tuple[float, str]:
|
| 521 |
+
"""
|
| 522 |
+
Classify using zero-shot with max-similarity scoring.
|
| 523 |
+
|
| 524 |
+
Uses the maximum cosine similarity across all prompts per class
|
| 525 |
+
for better discrimination.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
image_embedding: Normalized image embedding from MedSigLIP
|
| 529 |
+
|
| 530 |
+
Returns:
|
| 531 |
+
Tuple of (jaundice_prob, method_name)
|
| 532 |
+
"""
|
| 533 |
+
# Max-similarity: best-matching prompt per class
|
| 534 |
+
jaundice_sims = (image_embedding @ self.jaundice_embeddings_all.T).squeeze(0)
|
| 535 |
+
normal_sims = (image_embedding @ self.normal_embeddings_all.T).squeeze(0)
|
| 536 |
+
|
| 537 |
+
# Ensure at least 1-D for .max() to work on single-image inputs
|
| 538 |
+
if jaundice_sims.dim() == 0:
|
| 539 |
+
jaundice_sims = jaundice_sims.unsqueeze(0)
|
| 540 |
+
if normal_sims.dim() == 0:
|
| 541 |
+
normal_sims = normal_sims.unsqueeze(0)
|
| 542 |
+
|
| 543 |
+
jaundice_sim = jaundice_sims.max().item()
|
| 544 |
+
normal_sim = normal_sims.max().item()
|
| 545 |
+
|
| 546 |
+
# Convert to probabilities with tuned temperature
|
| 547 |
+
logits = torch.tensor([jaundice_sim, normal_sim]) * self.LOGIT_SCALE
|
| 548 |
+
probs = torch.softmax(logits, dim=0)
|
| 549 |
+
jaundice_prob = probs[0].item()
|
| 550 |
+
|
| 551 |
+
return jaundice_prob, "Zero-Shot"
|
| 552 |
+
|
| 553 |
+
    def detect_batch(
        self,
        images: List[Union[str, Path, Image.Image]],
        batch_size: int = 8,
    ) -> List[Dict]:
        """Detect jaundice from multiple images.

        Mirrors detect(), embedding images batch_size at a time; the returned
        per-image dicts omit the recommendation/model-metadata fields that
        detect() includes.

        Args:
            images: Sequence of image paths or PIL Images.
            batch_size: Number of images embedded per forward pass.

        Returns:
            One result dict per input image, in input order.
        """
        results = []

        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            pil_images = [self.preprocess_image(img) for img in batch]

            with torch.no_grad():
                inputs = self.processor(images=pil_images, return_tensors="pt", padding=True).to(self.device)

                # Get image embeddings from SigLIP vision encoder.
                # NOTE(review): this fallback chain is shorter than detect()'s
                # (no image_embeds / vision_model_output branches) — confirm
                # both paths produce the same embeddings for the loaded model.
                if hasattr(self.model, 'get_image_features'):
                    image_embeddings = self.model.get_image_features(**inputs)
                else:
                    vision_outputs = self.model.vision_model(**inputs)
                    image_embeddings = vision_outputs.pooler_output

                # L2-normalize for cosine-similarity scoring downstream.
                image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

            for j, (img_emb, pil_img) in enumerate(zip(image_embeddings, pil_images)):
                # Restore the batch dimension the classifiers expect.
                img_emb = img_emb.unsqueeze(0)

                # Use trained classifier if available, otherwise zero-shot
                if self.classifier is not None:
                    jaundice_prob, model_method = self._classify_with_trained_model(img_emb)
                else:
                    jaundice_prob, model_method = self._classify_zero_shot(img_emb)

                # Color-based bilirubin
                estimated_bilirubin = self.estimate_bilirubin(pil_img)

                # ML bilirubin from regressor (consistent with detect())
                estimated_bilirubin_ml = None
                if self.regressor is not None:
                    with torch.no_grad():
                        bilirubin_pred = self.regressor(img_emb)
                        raw_value = float(bilirubin_pred.item())
                        estimated_bilirubin_ml = round(max(0.0, min(35.0, raw_value)), 1)

                bilirubin_for_severity = estimated_bilirubin_ml if estimated_bilirubin_ml is not None else estimated_bilirubin

                # Same severity buckets as detect(), without recommendation text.
                if bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["low"]:
                    severity, needs_phototherapy = "none", False
                elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["moderate"]:
                    severity, needs_phototherapy = "mild", False
                elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["high"]:
                    severity, needs_phototherapy = "moderate", False
                elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["critical"]:
                    severity, needs_phototherapy = "severe", True
                else:
                    severity, needs_phototherapy = "critical", True

                result_item = {
                    "has_jaundice": jaundice_prob > self.threshold,
                    "confidence": max(jaundice_prob, 1 - jaundice_prob),
                    "jaundice_score": jaundice_prob,
                    "estimated_bilirubin": estimated_bilirubin,
                    "severity": severity,
                    "needs_phototherapy": needs_phototherapy,
                }
                if estimated_bilirubin_ml is not None:
                    result_item["estimated_bilirubin_ml"] = estimated_bilirubin_ml
                results.append(result_item)

        return results
|
| 623 |
+
|
| 624 |
+
def analyze_kramer_zones(self, image: Union[str, Path, Image.Image]) -> Dict:
|
| 625 |
+
"""
|
| 626 |
+
Analyze jaundice using Kramer's zones concept.
|
| 627 |
+
|
| 628 |
+
Kramer's zones estimate bilirubin based on cephalocaudal progression:
|
| 629 |
+
- Zone 1 (face): ~5-6 mg/dL
|
| 630 |
+
- Zone 2 (chest): ~9 mg/dL
|
| 631 |
+
- Zone 3 (abdomen): ~12 mg/dL
|
| 632 |
+
- Zone 4 (arms/legs): ~15 mg/dL
|
| 633 |
+
- Zone 5 (hands/feet): ~20+ mg/dL
|
| 634 |
+
|
| 635 |
+
Args:
|
| 636 |
+
image: Full body or partial neonatal image
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
Dictionary with zone analysis
|
| 640 |
+
"""
|
| 641 |
+
pil_image = self.preprocess_image(image)
|
| 642 |
+
img_array = np.array(pil_image).astype(float)
|
| 643 |
+
|
| 644 |
+
# Simple color-based zone estimation
|
| 645 |
+
r = img_array[:, :, 0]
|
| 646 |
+
g = img_array[:, :, 1]
|
| 647 |
+
b = img_array[:, :, 2]
|
| 648 |
+
|
| 649 |
+
yellow_index = np.mean((r + g - b) / (r + g + b + 1e-6))
|
| 650 |
+
|
| 651 |
+
# Map yellow index to Kramer zone
|
| 652 |
+
if yellow_index < 0.25:
|
| 653 |
+
zone = 0
|
| 654 |
+
zone_bilirubin = 3
|
| 655 |
+
elif yellow_index < 0.30:
|
| 656 |
+
zone = 1
|
| 657 |
+
zone_bilirubin = 6
|
| 658 |
+
elif yellow_index < 0.35:
|
| 659 |
+
zone = 2
|
| 660 |
+
zone_bilirubin = 9
|
| 661 |
+
elif yellow_index < 0.40:
|
| 662 |
+
zone = 3
|
| 663 |
+
zone_bilirubin = 12
|
| 664 |
+
elif yellow_index < 0.45:
|
| 665 |
+
zone = 4
|
| 666 |
+
zone_bilirubin = 15
|
| 667 |
+
else:
|
| 668 |
+
zone = 5
|
| 669 |
+
zone_bilirubin = 20
|
| 670 |
+
|
| 671 |
+
return {
|
| 672 |
+
"kramer_zone": zone,
|
| 673 |
+
"zone_description": self._get_zone_description(zone),
|
| 674 |
+
"estimated_bilirubin_by_zone": zone_bilirubin,
|
| 675 |
+
"yellow_index": round(yellow_index, 3),
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
def _get_zone_description(self, zone: int) -> str:
|
| 679 |
+
"""Get description for Kramer zone."""
|
| 680 |
+
descriptions = {
|
| 681 |
+
0: "No visible jaundice",
|
| 682 |
+
1: "Face and neck (Zone 1)",
|
| 683 |
+
2: "Upper trunk (Zone 2)",
|
| 684 |
+
3: "Lower trunk and thighs (Zone 3)",
|
| 685 |
+
4: "Arms and lower legs (Zone 4)",
|
| 686 |
+
5: "Hands and feet (Zone 5) - Severe",
|
| 687 |
+
}
|
| 688 |
+
return descriptions.get(zone, "Unknown")
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def test_detector():
    """Test the jaundice detector with sample images."""
    print("Testing Jaundice Detector...")

    detector = JaundiceDetector()

    data_dir = Path(__file__).parent.parent.parent / "data" / "raw" / "neojaundice" / "images"

    # Guard clause: bail out early when the sample dataset is absent.
    if not data_dir.exists():
        print(f"Dataset not found at {data_dir}")
        return

    # Analyze up to three sample images and print each report.
    for img_path in list(data_dir.glob("*.jpg"))[:3]:
        print(f"\nAnalyzing: {img_path.name}")
        report = detector.detect(img_path)
        print(f"  Jaundice detected: {report['has_jaundice']}")
        print(f"  Confidence: {report['confidence']:.2%}")
        print(f"  Estimated bilirubin: {report['estimated_bilirubin']} mg/dL")
        print(f"  Severity: {report['severity']}")
        print(f"  Needs phototherapy: {report['needs_phototherapy']}")
        print(f"  Recommendation: {report['recommendation']}")


if __name__ == "__main__":
    test_detector()
|
src/nexus/pipeline.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NEXUS Pipeline Module
|
| 3 |
+
|
| 4 |
+
Integrates all detection modules into a unified diagnostic pipeline
|
| 5 |
+
for maternal-neonatal care.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class PatientInfo:
    """Patient demographics and context for one assessment.

    Used for both mothers and neonates; the neonate-specific fields
    (age_days, gestational_age, birth_weight) are left as None for
    maternal patients.
    """
    patient_id: str
    age_days: Optional[int] = None  # Age in days (neonates only)
    gestational_age: Optional[int] = None  # Gestational age in weeks
    birth_weight: Optional[int] = None  # Birth weight in grams
    gender: Optional[str] = None
    is_maternal: bool = False  # True for mother, False for neonate
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class AssessmentResult:
    """Complete assessment result for a single patient encounter.

    The per-modality results hold the raw detector output dicts and
    stay None when that modality was not run.
    """
    patient: PatientInfo
    timestamp: str  # ISO-8601 time the assessment was produced
    anemia_result: Optional[Dict] = None
    jaundice_result: Optional[Dict] = None
    cry_result: Optional[Dict] = None
    overall_risk: str = "unknown"  # "low" | "medium" | "high" | "unknown"
    # BUG FIX: the default was None, contradicting the List[str] annotation
    # and crashing any `.append()` on a default-constructed instance; a
    # default_factory gives each instance its own fresh list (never use a
    # shared mutable default).
    priority_actions: List[str] = field(default_factory=list)
    referral_needed: bool = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class NEXUSPipeline:
    """
    NEXUS Integrated Diagnostic Pipeline

    Combines anemia, jaundice, and cry analysis into a unified
    assessment workflow for maternal-neonatal care.
    """

    # Default paths for trained model checkpoints, resolved relative to
    # this file (i.e. <repo_root>/models/...).
    DEFAULT_CHECKPOINT_DIR = Path(__file__).parent.parent.parent / "models" / "checkpoints"
    DEFAULT_LINEAR_PROBE_DIR = Path(__file__).parent.parent.parent / "models" / "linear_probes"

    def __init__(
        self,
        device: Optional[str] = None,
        lazy_load: bool = True,
        anemia_checkpoint: Optional[Union[str, Path]] = None,
        jaundice_checkpoint: Optional[Union[str, Path]] = None,
        cry_checkpoint: Optional[Union[str, Path]] = None,
        use_linear_probes: bool = True,
    ):
        """
        Initialize NEXUS Pipeline.

        Args:
            device: Device for model inference
            lazy_load: If True, load models only when needed (detectors are
                created on first use instead of here)
            anemia_checkpoint: Path to trained anemia classifier checkpoint
            jaundice_checkpoint: Path to trained jaundice classifier checkpoint
            cry_checkpoint: Path to trained cry classifier checkpoint
            use_linear_probes: If True, auto-load linear probes from default dir
        """
        self.device = device
        self.lazy_load = lazy_load

        # Store checkpoint paths. Explicit arguments take precedence: the
        # auto-detection below only fills in paths that are still None.
        self.anemia_checkpoint = anemia_checkpoint
        self.jaundice_checkpoint = jaundice_checkpoint
        self.cry_checkpoint = cry_checkpoint

        # Auto-detect checkpoints from default locations
        if use_linear_probes:
            self._auto_detect_checkpoints()

        # Detector instances, created lazily and cached by the _get_* helpers.
        self._anemia_detector = None
        self._jaundice_detector = None
        self._cry_analyzer = None

        if not lazy_load:
            self._load_all_models()

        print("NEXUS Pipeline initialized")
|
| 91 |
+
|
| 92 |
+
def verify_hai_def_compliance(self) -> Dict:
|
| 93 |
+
"""
|
| 94 |
+
Verify which HAI-DEF models are loaded and report compliance.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Dictionary with model status and compliance flag.
|
| 98 |
+
"""
|
| 99 |
+
from .anemia_detector import MEDSIGLIP_MODEL_IDS
|
| 100 |
+
from .cry_analyzer import CryAnalyzer
|
| 101 |
+
|
| 102 |
+
status = {
|
| 103 |
+
"medsiglip": {
|
| 104 |
+
"expected": "google/medsiglip-448",
|
| 105 |
+
"configured_models": MEDSIGLIP_MODEL_IDS,
|
| 106 |
+
"anemia_loaded": self._anemia_detector is not None,
|
| 107 |
+
"jaundice_loaded": self._jaundice_detector is not None,
|
| 108 |
+
},
|
| 109 |
+
"hear": {
|
| 110 |
+
"expected": CryAnalyzer.HEAR_MODEL_ID,
|
| 111 |
+
"cry_loaded": self._cry_analyzer is not None,
|
| 112 |
+
"hear_active": getattr(self._cry_analyzer, '_hear_available', False) if self._cry_analyzer else False,
|
| 113 |
+
},
|
| 114 |
+
"medgemma": {
|
| 115 |
+
"expected": "google/medgemma-4b-it",
|
| 116 |
+
},
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
# Check loaded model names
|
| 120 |
+
if self._anemia_detector:
|
| 121 |
+
status["medsiglip"]["anemia_model"] = getattr(self._anemia_detector, 'model_name', 'unknown')
|
| 122 |
+
if self._jaundice_detector:
|
| 123 |
+
status["medsiglip"]["jaundice_model"] = getattr(self._jaundice_detector, 'model_name', 'unknown')
|
| 124 |
+
|
| 125 |
+
# Overall compliance
|
| 126 |
+
anemia_ok = "medsiglip" in status["medsiglip"].get("anemia_model", "")
|
| 127 |
+
jaundice_ok = "medsiglip" in status["medsiglip"].get("jaundice_model", "")
|
| 128 |
+
hear_ok = status["hear"]["hear_active"]
|
| 129 |
+
|
| 130 |
+
status["compliant"] = anemia_ok or jaundice_ok or hear_ok
|
| 131 |
+
status["all_hai_def"] = anemia_ok and jaundice_ok and hear_ok
|
| 132 |
+
|
| 133 |
+
return status
|
| 134 |
+
|
| 135 |
+
def _auto_detect_checkpoints(self) -> None:
|
| 136 |
+
"""Auto-detect trained checkpoints from default directories."""
|
| 137 |
+
# Check for linear probes (.joblib sklearn models)
|
| 138 |
+
if self.anemia_checkpoint is None:
|
| 139 |
+
anemia_probe = self.DEFAULT_LINEAR_PROBE_DIR / "anemia_linear_probe.joblib"
|
| 140 |
+
if anemia_probe.exists():
|
| 141 |
+
self.anemia_checkpoint = anemia_probe
|
| 142 |
+
print(f"Auto-detected anemia probe: {anemia_probe}")
|
| 143 |
+
|
| 144 |
+
if self.jaundice_checkpoint is None:
|
| 145 |
+
jaundice_probe = self.DEFAULT_LINEAR_PROBE_DIR / "jaundice_linear_probe.joblib"
|
| 146 |
+
if jaundice_probe.exists():
|
| 147 |
+
self.jaundice_checkpoint = jaundice_probe
|
| 148 |
+
print(f"Auto-detected jaundice probe: {jaundice_probe}")
|
| 149 |
+
|
| 150 |
+
if self.cry_checkpoint is None:
|
| 151 |
+
cry_probe = self.DEFAULT_LINEAR_PROBE_DIR / "cry_linear_probe.joblib"
|
| 152 |
+
if cry_probe.exists():
|
| 153 |
+
self.cry_checkpoint = cry_probe
|
| 154 |
+
print(f"Auto-detected cry probe: {cry_probe}")
|
| 155 |
+
|
| 156 |
+
# Also check checkpoint dir for full fine-tuned models
|
| 157 |
+
if self.anemia_checkpoint is None:
|
| 158 |
+
anemia_best = self.DEFAULT_CHECKPOINT_DIR / "anemia_best.pt"
|
| 159 |
+
if anemia_best.exists():
|
| 160 |
+
self.anemia_checkpoint = anemia_best
|
| 161 |
+
print(f"Auto-detected anemia checkpoint: {anemia_best}")
|
| 162 |
+
|
| 163 |
+
def _load_all_models(self) -> None:
|
| 164 |
+
"""Load all detection models."""
|
| 165 |
+
self._get_anemia_detector()
|
| 166 |
+
self._get_jaundice_detector()
|
| 167 |
+
self._get_cry_analyzer()
|
| 168 |
+
|
| 169 |
+
def _get_anemia_detector(self):
|
| 170 |
+
"""Get or create anemia detector with optional trained classifier."""
|
| 171 |
+
if self._anemia_detector is None:
|
| 172 |
+
from .anemia_detector import AnemiaDetector
|
| 173 |
+
|
| 174 |
+
# Initialize detector
|
| 175 |
+
self._anemia_detector = AnemiaDetector(device=self.device)
|
| 176 |
+
|
| 177 |
+
# Load trained classifier if available
|
| 178 |
+
if self.anemia_checkpoint:
|
| 179 |
+
self._load_classifier_checkpoint(
|
| 180 |
+
self._anemia_detector,
|
| 181 |
+
self.anemia_checkpoint,
|
| 182 |
+
"anemia"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return self._anemia_detector
|
| 186 |
+
|
| 187 |
+
def _get_jaundice_detector(self):
|
| 188 |
+
"""Get or create jaundice detector with optional trained classifier."""
|
| 189 |
+
if self._jaundice_detector is None:
|
| 190 |
+
from .jaundice_detector import JaundiceDetector
|
| 191 |
+
|
| 192 |
+
self._jaundice_detector = JaundiceDetector(device=self.device)
|
| 193 |
+
|
| 194 |
+
# Load trained classifier if available
|
| 195 |
+
if self.jaundice_checkpoint:
|
| 196 |
+
self._load_classifier_checkpoint(
|
| 197 |
+
self._jaundice_detector,
|
| 198 |
+
self.jaundice_checkpoint,
|
| 199 |
+
"jaundice"
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return self._jaundice_detector
|
| 203 |
+
|
| 204 |
+
def _get_cry_analyzer(self):
|
| 205 |
+
"""Get or create cry analyzer with optional trained classifier."""
|
| 206 |
+
if self._cry_analyzer is None:
|
| 207 |
+
from .cry_analyzer import CryAnalyzer
|
| 208 |
+
|
| 209 |
+
# Cry analyzer supports classifier_path directly
|
| 210 |
+
classifier_path = str(self.cry_checkpoint) if self.cry_checkpoint else None
|
| 211 |
+
self._cry_analyzer = CryAnalyzer(
|
| 212 |
+
device=self.device,
|
| 213 |
+
classifier_path=classifier_path
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
return self._cry_analyzer
|
| 217 |
+
|
| 218 |
+
def _load_classifier_checkpoint(
|
| 219 |
+
self,
|
| 220 |
+
detector,
|
| 221 |
+
checkpoint_path: Union[str, Path],
|
| 222 |
+
model_type: str
|
| 223 |
+
) -> None:
|
| 224 |
+
"""
|
| 225 |
+
Load a trained classifier checkpoint into a detector.
|
| 226 |
+
|
| 227 |
+
Supports both linear probes (sklearn) and PyTorch checkpoints.
|
| 228 |
+
"""
|
| 229 |
+
import torch
|
| 230 |
+
|
| 231 |
+
checkpoint_path = Path(checkpoint_path)
|
| 232 |
+
if not checkpoint_path.exists():
|
| 233 |
+
print(f"Warning: {model_type} checkpoint not found: {checkpoint_path}")
|
| 234 |
+
return
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
# Check if it's a sklearn model (joblib)
|
| 238 |
+
if checkpoint_path.suffix in ['.pkl', '.joblib']:
|
| 239 |
+
import joblib
|
| 240 |
+
classifier = joblib.load(checkpoint_path)
|
| 241 |
+
detector.classifier = classifier
|
| 242 |
+
print(f"Loaded sklearn classifier for {model_type}")
|
| 243 |
+
|
| 244 |
+
# Check if it's a PyTorch model
|
| 245 |
+
elif checkpoint_path.suffix == '.pt':
|
| 246 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device or 'cpu')
|
| 247 |
+
|
| 248 |
+
# Handle different checkpoint formats
|
| 249 |
+
if 'classifier' in checkpoint:
|
| 250 |
+
# Linear probe format
|
| 251 |
+
detector.classifier = checkpoint['classifier']
|
| 252 |
+
print(f"Loaded linear probe for {model_type}")
|
| 253 |
+
elif 'model_state_dict' in checkpoint:
|
| 254 |
+
# Full model checkpoint - would need separate handling
|
| 255 |
+
print(f"Note: Full model checkpoint for {model_type} - using zero-shot")
|
| 256 |
+
else:
|
| 257 |
+
print(f"Unknown checkpoint format for {model_type}")
|
| 258 |
+
|
| 259 |
+
except Exception as e:
|
| 260 |
+
print(f"Warning: Could not load {model_type} checkpoint: {e}")
|
| 261 |
+
|
| 262 |
+
def assess_maternal(
|
| 263 |
+
self,
|
| 264 |
+
patient: PatientInfo,
|
| 265 |
+
conjunctiva_image: Optional[Union[str, Path]] = None,
|
| 266 |
+
) -> AssessmentResult:
|
| 267 |
+
"""
|
| 268 |
+
Perform maternal health assessment.
|
| 269 |
+
|
| 270 |
+
Currently focuses on anemia detection via conjunctiva imaging.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
patient: Patient information
|
| 274 |
+
conjunctiva_image: Path to conjunctiva image
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
AssessmentResult with findings
|
| 278 |
+
"""
|
| 279 |
+
result = AssessmentResult(
|
| 280 |
+
patient=patient,
|
| 281 |
+
timestamp=datetime.now().isoformat(),
|
| 282 |
+
priority_actions=[],
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Anemia detection
|
| 286 |
+
if conjunctiva_image:
|
| 287 |
+
detector = self._get_anemia_detector()
|
| 288 |
+
result.anemia_result = detector.detect(conjunctiva_image)
|
| 289 |
+
|
| 290 |
+
# Add color analysis
|
| 291 |
+
color_info = detector.analyze_color_features(conjunctiva_image)
|
| 292 |
+
result.anemia_result["color_analysis"] = color_info
|
| 293 |
+
|
| 294 |
+
# Determine actions
|
| 295 |
+
if result.anemia_result["risk_level"] == "high":
|
| 296 |
+
result.priority_actions.append("URGENT: Refer for blood test - suspected severe anemia")
|
| 297 |
+
result.referral_needed = True
|
| 298 |
+
result.overall_risk = "high"
|
| 299 |
+
elif result.anemia_result["risk_level"] == "medium":
|
| 300 |
+
result.priority_actions.append("Schedule blood test within 48 hours")
|
| 301 |
+
result.overall_risk = "medium"
|
| 302 |
+
else:
|
| 303 |
+
result.overall_risk = "low"
|
| 304 |
+
|
| 305 |
+
return result
|
| 306 |
+
|
| 307 |
+
    def assess_neonate(
        self,
        patient: PatientInfo,
        skin_image: Optional[Union[str, Path]] = None,
        cry_audio: Optional[Union[str, Path]] = None,
    ) -> AssessmentResult:
        """
        Perform neonatal health assessment.

        Includes jaundice detection and cry analysis. Overall risk is the
        worst (maximum) of the per-modality risk scores; it stays
        "unknown" when neither input is provided.

        Args:
            patient: Patient information
            skin_image: Path to skin/sclera image for jaundice
            cry_audio: Path to cry audio file

        Returns:
            AssessmentResult with findings
        """
        result = AssessmentResult(
            patient=patient,
            timestamp=datetime.now().isoformat(),
            priority_actions=[],
        )

        # One score per modality in [0, 1]; aggregated with max() below.
        risk_scores = []

        # Jaundice detection
        if skin_image:
            detector = self._get_jaundice_detector()
            result.jaundice_result = detector.detect(skin_image)

            # Add Kramer zone analysis to the raw detection output
            zone_info = detector.analyze_kramer_zones(skin_image)
            result.jaundice_result["zone_analysis"] = zone_info

            # Life-threatening findings are insert(0)'d so they appear
            # first in the priority-action list; lesser findings append.
            if result.jaundice_result["severity"] == "critical":
                result.priority_actions.insert(0, "CRITICAL: Immediate phototherapy required")
                result.referral_needed = True
                risk_scores.append(1.0)
            elif result.jaundice_result["severity"] == "severe":
                result.priority_actions.append("URGENT: Start phototherapy")
                result.referral_needed = True
                risk_scores.append(0.8)
            elif result.jaundice_result["severity"] == "moderate":
                result.priority_actions.append("Monitor closely, recheck in 12-24 hours")
                risk_scores.append(0.5)
            else:
                risk_scores.append(0.2)

        # Cry analysis
        if cry_audio:
            analyzer = self._get_cry_analyzer()
            result.cry_result = analyzer.analyze(cry_audio)

            if result.cry_result["risk_level"] == "high":
                result.priority_actions.insert(0, "URGENT: Abnormal cry - assess for birth asphyxia")
                result.referral_needed = True
                risk_scores.append(1.0)
            elif result.cry_result["risk_level"] == "medium":
                result.priority_actions.append("Monitor cry patterns, reassess in 30 minutes")
                risk_scores.append(0.5)
            else:
                risk_scores.append(0.2)

        # Determine overall risk from the worst modality score.
        if risk_scores:
            max_risk = max(risk_scores)
            if max_risk >= 0.8:
                result.overall_risk = "high"
            elif max_risk >= 0.5:
                result.overall_risk = "medium"
            else:
                result.overall_risk = "low"

        return result
|
| 383 |
+
|
| 384 |
+
    def agentic_assessment(
        self,
        patient_type: str = "newborn",
        conjunctiva_image: Optional[Union[str, Path]] = None,
        skin_image: Optional[Union[str, Path]] = None,
        cry_audio: Optional[Union[str, Path]] = None,
        danger_signs: Optional[List[Dict]] = None,
        patient_info: Optional[Dict] = None,
    ) -> Dict:
        """
        Run the full agentic clinical workflow with 6 specialized agents.

        This provides richer output than full_assessment() — each agent emits
        step-by-step reasoning traces forming a complete audit trail.

        Args:
            patient_type: "pregnant" or "newborn"
            conjunctiva_image: Path to conjunctiva image for anemia screening
            skin_image: Path to skin image for jaundice detection
            cry_audio: Path to cry audio for asphyxia detection
            danger_signs: List of danger sign dicts with keys: id, label, severity, present
            patient_info: Patient information dict

        Returns:
            Dict with workflow result including agent_traces list
        """
        from .agentic_workflow import (
            AgenticWorkflowEngine,
            AgentPatientInfo,
            DangerSign,
            WorkflowInput,
        )

        # Build patient info (missing keys simply stay at their defaults)
        info = AgentPatientInfo(patient_type=patient_type)
        if patient_info:
            info.patient_id = patient_info.get("patient_id", "")
            info.gestational_weeks = patient_info.get("gestational_weeks")
            info.birth_weight = patient_info.get("birth_weight")
            info.apgar_score = patient_info.get("apgar_score")
            info.age_hours = patient_info.get("age_hours")

        # Build danger signs; entries default to present=True, severity="medium"
        signs = []
        if danger_signs:
            for s in danger_signs:
                signs.append(DangerSign(
                    id=s.get("id", ""),
                    label=s.get("label", ""),
                    severity=s.get("severity", "medium"),
                    present=s.get("present", True),
                ))

        workflow_input = WorkflowInput(
            patient_type=patient_type,
            patient_info=info,
            danger_signs=signs,
            conjunctiva_image=conjunctiva_image,
            skin_image=skin_image,
            cry_audio=cry_audio,
        )

        # Create engine with existing model instances to avoid reloading.
        # NOTE(review): with lazy_load these may still be None here —
        # presumably the engine creates its own detectors in that case;
        # confirm against agentic_workflow.
        engine = AgenticWorkflowEngine(
            anemia_detector=self._anemia_detector,
            jaundice_detector=self._jaundice_detector,
            cry_analyzer=self._cry_analyzer,
        )

        result = engine.execute(workflow_input)

        # Serialize to dict; optional sub-results map to None when absent
        # so the payload is JSON-friendly end to end.
        return {
            "success": result.success,
            "patient_type": result.patient_type,
            "who_classification": result.who_classification,
            "clinical_synthesis": result.clinical_synthesis,
            "recommendation": result.recommendation,
            "immediate_actions": result.immediate_actions,
            "processing_time_ms": result.processing_time_ms,
            "timestamp": result.timestamp,
            "triage": {
                "risk_level": result.triage_result.risk_level,
                "score": result.triage_result.score,
                "critical_signs": result.triage_result.critical_signs,
                "immediate_referral": result.triage_result.immediate_referral_needed,
            } if result.triage_result else None,
            "referral": {
                "referral_needed": result.referral_result.referral_needed,
                "urgency": result.referral_result.urgency,
                "facility_level": result.referral_result.facility_level,
                "reason": result.referral_result.reason,
                "timeframe": result.referral_result.timeframe,
            } if result.referral_result else None,
            "protocol": {
                "classification": result.protocol_result.classification,
                "applicable_protocols": result.protocol_result.applicable_protocols,
                "treatment_recommendations": result.protocol_result.treatment_recommendations,
                "follow_up_schedule": result.protocol_result.follow_up_schedule,
            } if result.protocol_result else None,
            "agent_traces": [
                {
                    "agent_name": t.agent_name,
                    "status": t.status,
                    "reasoning": t.reasoning,
                    "findings": t.findings,
                    "confidence": t.confidence,
                    "processing_time_ms": t.processing_time_ms,
                }
                for t in result.agent_traces
            ],
        }
|
| 496 |
+
|
| 497 |
+
def full_assessment(
|
| 498 |
+
self,
|
| 499 |
+
patient: PatientInfo,
|
| 500 |
+
conjunctiva_image: Optional[Union[str, Path]] = None,
|
| 501 |
+
skin_image: Optional[Union[str, Path]] = None,
|
| 502 |
+
cry_audio: Optional[Union[str, Path]] = None,
|
| 503 |
+
) -> AssessmentResult:
|
| 504 |
+
"""
|
| 505 |
+
Perform full assessment (maternal or neonatal based on patient info).
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
patient: Patient information
|
| 509 |
+
conjunctiva_image: For maternal anemia screening
|
| 510 |
+
skin_image: For neonatal jaundice detection
|
| 511 |
+
cry_audio: For neonatal cry analysis
|
| 512 |
+
|
| 513 |
+
Returns:
|
| 514 |
+
Complete AssessmentResult
|
| 515 |
+
"""
|
| 516 |
+
if patient.is_maternal:
|
| 517 |
+
return self.assess_maternal(patient, conjunctiva_image)
|
| 518 |
+
else:
|
| 519 |
+
return self.assess_neonate(patient, skin_image, cry_audio)
|
| 520 |
+
|
| 521 |
+
def generate_report(self, result: AssessmentResult) -> str:
|
| 522 |
+
"""
|
| 523 |
+
Generate a text report from assessment result.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
result: AssessmentResult from assessment
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
Formatted report string
|
| 530 |
+
"""
|
| 531 |
+
lines = [
|
| 532 |
+
"=" * 60,
|
| 533 |
+
"NEXUS HEALTH ASSESSMENT REPORT",
|
| 534 |
+
"=" * 60,
|
| 535 |
+
"",
|
| 536 |
+
f"Patient ID: {result.patient.patient_id}",
|
| 537 |
+
f"Assessment Time: {result.timestamp}",
|
| 538 |
+
f"Patient Type: {'Maternal' if result.patient.is_maternal else 'Neonatal'}",
|
| 539 |
+
"",
|
| 540 |
+
]
|
| 541 |
+
|
| 542 |
+
if result.patient.age_days is not None:
|
| 543 |
+
lines.append(f"Age: {result.patient.age_days} days")
|
| 544 |
+
if result.patient.gestational_age is not None:
|
| 545 |
+
lines.append(f"Gestational Age: {result.patient.gestational_age} weeks")
|
| 546 |
+
if result.patient.birth_weight is not None:
|
| 547 |
+
lines.append(f"Birth Weight: {result.patient.birth_weight} grams")
|
| 548 |
+
|
| 549 |
+
lines.extend(["", "-" * 60, "FINDINGS", "-" * 60, ""])
|
| 550 |
+
|
| 551 |
+
# Anemia findings
|
| 552 |
+
if result.anemia_result:
|
| 553 |
+
lines.extend([
|
| 554 |
+
"ANEMIA SCREENING:",
|
| 555 |
+
f" Status: {'ANEMIC' if result.anemia_result['is_anemic'] else 'Normal'}",
|
| 556 |
+
f" Confidence: {result.anemia_result['confidence']:.1%}",
|
| 557 |
+
f" Risk Level: {result.anemia_result['risk_level'].upper()}",
|
| 558 |
+
"",
|
| 559 |
+
])
|
| 560 |
+
|
| 561 |
+
# Jaundice findings
|
| 562 |
+
if result.jaundice_result:
|
| 563 |
+
lines.extend([
|
| 564 |
+
"JAUNDICE ASSESSMENT:",
|
| 565 |
+
f" Status: {'JAUNDICE DETECTED' if result.jaundice_result['has_jaundice'] else 'Normal'}",
|
| 566 |
+
f" Estimated Bilirubin: {result.jaundice_result['estimated_bilirubin']} mg/dL",
|
| 567 |
+
f" Severity: {result.jaundice_result['severity'].upper()}",
|
| 568 |
+
f" Phototherapy Needed: {'YES' if result.jaundice_result['needs_phototherapy'] else 'No'}",
|
| 569 |
+
"",
|
| 570 |
+
])
|
| 571 |
+
|
| 572 |
+
# Cry analysis findings
|
| 573 |
+
if result.cry_result:
|
| 574 |
+
lines.extend([
|
| 575 |
+
"CRY ANALYSIS:",
|
| 576 |
+
f" Status: {'ABNORMAL' if result.cry_result['is_abnormal'] else 'Normal'}",
|
| 577 |
+
f" Asphyxia Risk: {result.cry_result['asphyxia_risk']:.1%}",
|
| 578 |
+
f" Cry Type: {result.cry_result['cry_type']}",
|
| 579 |
+
f" Risk Level: {result.cry_result['risk_level'].upper()}",
|
| 580 |
+
"",
|
| 581 |
+
])
|
| 582 |
+
|
| 583 |
+
lines.extend(["-" * 60, "OVERALL ASSESSMENT", "-" * 60, ""])
|
| 584 |
+
lines.append(f"Overall Risk Level: {result.overall_risk.upper()}")
|
| 585 |
+
lines.append(f"Referral Needed: {'YES' if result.referral_needed else 'No'}")
|
| 586 |
+
|
| 587 |
+
if result.priority_actions:
|
| 588 |
+
lines.extend(["", "PRIORITY ACTIONS:"])
|
| 589 |
+
for i, action in enumerate(result.priority_actions, 1):
|
| 590 |
+
lines.append(f" {i}. {action}")
|
| 591 |
+
|
| 592 |
+
lines.extend(["", "=" * 60])
|
| 593 |
+
|
| 594 |
+
return "\n".join(lines)
|
| 595 |
+
|
| 596 |
+
def to_json(self, result: AssessmentResult) -> str:
|
| 597 |
+
"""Convert assessment result to JSON string."""
|
| 598 |
+
data = {
|
| 599 |
+
"patient": {
|
| 600 |
+
"patient_id": result.patient.patient_id,
|
| 601 |
+
"age_days": result.patient.age_days,
|
| 602 |
+
"gestational_age": result.patient.gestational_age,
|
| 603 |
+
"birth_weight": result.patient.birth_weight,
|
| 604 |
+
"gender": result.patient.gender,
|
| 605 |
+
"is_maternal": result.patient.is_maternal,
|
| 606 |
+
},
|
| 607 |
+
"timestamp": result.timestamp,
|
| 608 |
+
"anemia_result": result.anemia_result,
|
| 609 |
+
"jaundice_result": result.jaundice_result,
|
| 610 |
+
"cry_result": result.cry_result,
|
| 611 |
+
"overall_risk": result.overall_risk,
|
| 612 |
+
"priority_actions": result.priority_actions,
|
| 613 |
+
"referral_needed": result.referral_needed,
|
| 614 |
+
}
|
| 615 |
+
return json.dumps(data, indent=2)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def demo():
    """Exercise the NEXUS pipeline end-to-end on sample data, if present."""
    print("NEXUS Pipeline Demo")
    print("=" * 60)

    pipeline = NEXUSPipeline(lazy_load=True)

    # --- Maternal assessment on the first available conjunctiva image ---
    print("\n--- Maternal Assessment Demo ---")
    mother = PatientInfo(
        patient_id="M001",
        is_maternal=True,
    )

    data_dir = Path(__file__).parent.parent.parent / "data" / "raw"
    conjunctiva_samples = list((data_dir / "eyes-defy-anemia").rglob("*.jpg"))[:1]
    if conjunctiva_samples:
        maternal_result = pipeline.assess_maternal(mother, conjunctiva_samples[0])
        print(pipeline.generate_report(maternal_result))

    # --- Neonatal assessment on sample jaundice image and cry audio ---
    print("\n--- Neonatal Assessment Demo ---")
    newborn = PatientInfo(
        patient_id="N001",
        age_days=3,
        gestational_age=38,
        birth_weight=3200,
        gender="M",
        is_maternal=False,
    )

    skin_samples = list((data_dir / "neojaundice" / "images").glob("*.jpg"))[:1]
    cry_samples = list((data_dir / "donate-a-cry").rglob("*.wav"))[:1]
    skin_image = skin_samples[0] if skin_samples else None
    cry_audio = cry_samples[0] if cry_samples else None

    if skin_image or cry_audio:
        neonatal_result = pipeline.assess_neonate(newborn, skin_image, cry_audio)
        print(pipeline.generate_report(neonatal_result))
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
if __name__ == "__main__":
    # Allow running this module directly as a smoke test.
    demo()
|