Md Shahabul Alam committed
Commit 29db30b · 1 Parent(s): 865fb68

Deploy NEXUS Streamlit demo to HuggingFace Spaces
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ # HuggingFace Spaces Docker SDK — NEXUS Streamlit Demo
+ # Docs: https://huggingface.co/docs/hub/spaces-sdks-docker
+
+ FROM python:3.12-slim
+
+ # Create non-root user (required by HF Spaces)
+ RUN useradd -m -u 1000 user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Install system dependencies for audio processing
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     libsndfile1 \
+     ffmpeg \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # Copy requirements and install as user
+ COPY --chown=user ./requirements_spaces.txt requirements_spaces.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements_spaces.txt
+
+ # Switch to non-root user
+ USER user
+
+ # Copy source code
+ COPY --chown=user ./src/ src/
+ COPY --chown=user ./models/ models/
+ COPY --chown=user ./app.py .
+
+ # Set environment
+ ENV PYTHONPATH=/app/src
+ ENV STREAMLIT_SERVER_PORT=7860
+ ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
+ ENV STREAMLIT_SERVER_HEADLESS=true
+ ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
+
+ EXPOSE 7860
+
+ CMD ["python", "-m", "streamlit", "run", "src/demo/streamlit_app.py", \
+      "--server.port=7860", \
+      "--server.address=0.0.0.0", \
+      "--server.headless=true", \
+      "--browser.gatherUsageStats=false"]
README.md CHANGED
@@ -1,12 +1,194 @@
  ---
- title: Nexus
- emoji: 🐠
  colorFrom: blue
- colorTo: gray
  sdk: docker
- pinned: false
- license: mit
- short_description: NEXUS is an AI-powered platform to detect birth asphyxia
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: NEXUS
+ emoji: 🩺
  colorFrom: blue
+ colorTo: green
  sdk: docker
+ app_port: 7860
+ pinned: true
+ license: cc-by-4.0
+ tags:
+   - medgemma
+   - medical-ai
+   - hai-def
+   - maternal-health
+   - neonatal-care
  ---

+ # NEXUS - AI-Powered Maternal-Neonatal Assessment Platform
+
+ > Non-invasive screening for maternal anemia, neonatal jaundice, and birth asphyxia using Google HAI-DEF models
+
+ [![License: CC BY 4.0](https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by/4.0/)
+ [![MedGemma Impact Challenge](https://img.shields.io/badge/Kaggle-MedGemma%20Impact%20Challenge-20BEFF)](https://www.kaggle.com/competitions/med-gemma-impact-challenge)
+
+ ## Overview
+
+ NEXUS transforms smartphones into diagnostic screening tools for Community Health Workers in low-resource settings. Using 3 Google HAI-DEF models in a 6-agent clinical workflow, it provides non-invasive assessment for:
+
+ - **Maternal anemia** from conjunctiva images (MedSigLIP)
+ - **Neonatal jaundice** from skin images with bilirubin regression (MedSigLIP)
+ - **Birth asphyxia** from cry audio analysis (HeAR)
+ - **Clinical synthesis** with WHO IMNCI protocol alignment (MedGemma)
+
+ ## HAI-DEF Models
+
+ | Model | HuggingFace ID | Purpose |
+ |-------|----------------|---------|
+ | **MedSigLIP** | `google/medsiglip-448` | Anemia + jaundice detection, bilirubin regression |
+ | **HeAR** | `google/hear-pytorch` | Cry audio analysis for birth asphyxia |
+ | **MedGemma 4B** | `google/medgemma-4b-it` | Clinical reasoning and synthesis |
+
+ ## Architecture
+
+ ```
+ 6-Agent Clinical Workflow:
+ Triage -> Image Analysis (MedSigLIP) -> Audio Analysis (HeAR)
+   -> WHO Protocol -> Referral Decision -> Clinical Synthesis (MedGemma)
+
+ Each agent produces structured reasoning traces for a full audit trail.
+ ```
+
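The sequential chaining and audit trail described above can be sketched in a few lines. The agent names, outputs, and the `run_workflow` helper below are illustrative stand-ins, not the actual NEXUS `agentic_workflow` API:

```python
from dataclasses import dataclass


@dataclass
class AgentStep:
    """One entry in the audit trail: which agent ran, what it produced, and why."""
    agent: str
    output: dict
    reasoning: str


def run_workflow(case: dict, agents: list) -> list:
    """Run agents in sequence; each agent sees the case plus all earlier outputs."""
    trace = []
    for name, fn in agents:
        output, reasoning = fn(case)
        trace.append(AgentStep(agent=name, output=output, reasoning=reasoning))
        case = {**case, name: output}  # downstream agents can read upstream results
    return trace


# Toy stand-ins for two of the six NEXUS stages
agents = [
    ("triage", lambda c: ({"priority": "routine"}, "No danger signs reported")),
    ("referral", lambda c: ({"refer": False}, "Routine priority, no referral needed")),
]
trace = run_workflow({"age_days": 3}, agents)
```

Keeping a per-agent reasoning string alongside the structured output is what makes the trace auditable after the fact.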
+ ## Quick Start
+
+ ### Prerequisites
+ - Python 3.10+
+ - HuggingFace token (for gated HAI-DEF models)
+
+ ### Setup
+
+ ```bash
+ # Clone and install
+ git clone <repo-url>
+ cd nexus
+ pip install -r requirements.txt
+
+ # Set HuggingFace token (required for MedSigLIP, MedGemma)
+ export HF_TOKEN=hf_your_token_here
+ ```
+
+ ### Run the Demo
+
+ ```bash
+ # Streamlit interactive demo
+ PYTHONPATH=src streamlit run src/demo/streamlit_app.py
+
+ # FastAPI backend
+ PYTHONPATH=src uvicorn api.main:app --reload
+
+ # Run tests
+ PYTHONPATH=src python -m pytest tests/ -v
+ ```
+
+ ### Train Models
+
+ ```bash
+ # Train linear probes (anemia + jaundice classifiers)
+ PYTHONPATH=src python scripts/training/train_linear_probes.py
+
+ # Train bilirubin regression head
+ PYTHONPATH=src python scripts/training/finetune_bilirubin_regression.py
+ ```
+
+ ### HuggingFace Spaces
+
+ ```bash
+ # Local test of HF Spaces entry point
+ python app.py
+ ```
+
+ ## Project Structure
+
+ ```
+ nexus/
+ ├── src/nexus/                        # Core platform
+ │   ├── anemia_detector.py            # MedSigLIP anemia detection
+ │   ├── jaundice_detector.py          # MedSigLIP jaundice + bilirubin regression
+ │   ├── cry_analyzer.py               # HeAR cry analysis
+ │   ├── clinical_synthesizer.py       # MedGemma clinical synthesis
+ │   ├── agentic_workflow.py           # 6-agent workflow engine
+ │   └── pipeline.py                   # Unified assessment pipeline
+ ├── src/demo/streamlit_app.py         # Interactive Streamlit demo
+ ├── api/main.py                       # FastAPI backend
+ ├── scripts/
+ │   ├── training/
+ │   │   ├── train_linear_probes.py    # MedSigLIP embedding classifiers
+ │   │   ├── finetune_bilirubin_regression.py  # Novel bilirubin regression
+ │   │   ├── train_anemia.py           # Anemia-specific training
+ │   │   ├── train_jaundice.py         # Jaundice-specific training
+ │   │   └── train_cry.py              # Cry classifier training
+ │   └── edge/
+ │       ├── quantize_models.py        # INT8 quantization
+ │       └── export_embeddings.py      # Pre-computed text embeddings
+ ├── notebooks/
+ │   ├── 01_anemia_detection.ipynb
+ │   ├── 02_jaundice_detection.ipynb
+ │   ├── 03_cry_analysis.ipynb
+ │   └── 04_bilirubin_regression.ipynb # Novel task reproducibility
+ ├── tests/
+ │   ├── test_pipeline.py              # Pipeline tests
+ │   ├── test_agentic_workflow.py      # Agentic workflow tests (41 tests)
+ │   └── test_hai_def_integration.py   # HAI-DEF model compliance
+ ├── models/
+ │   ├── linear_probes/                # Trained classifiers + regressor
+ │   └── edge/                         # Quantized models + embeddings
+ ├── data/
+ │   ├── raw/                          # Raw datasets (Eyes-Defy-Anemia, NeoJaundice, CryCeleb)
+ │   └── protocols/                    # WHO IMNCI protocols
+ ├── submission/
+ │   ├── writeup.md                    # Competition writeup (3 pages)
+ │   └── video/                        # Demo video script and assets
+ ├── app.py                            # HuggingFace Spaces entry point
+ ├── requirements.txt                  # Full dependencies
+ └── requirements_spaces.txt           # HF Spaces minimal dependencies
+ ```
+
+ ## Key Results
+
+ | Task | Method | Performance |
+ |------|--------|-------------|
+ | Anemia zero-shot | MedSigLIP (max-similarity, 8 prompts/class) | Screening capability |
+ | Jaundice classification | MedSigLIP linear probe | 68.9% accuracy |
+ | **Bilirubin regression** | **MedSigLIP + MLP head** | **MAE: 2.667 mg/dL, r=0.77** |
+ | Cry analysis | HeAR + acoustic features | Qualitative assessment |
+ | Clinical synthesis | MedGemma + WHO IMNCI | Protocol-aligned recommendations |
+
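The "max-similarity, 8 prompts/class" zero-shot scoring in the table can be sketched as follows. The random vectors stand in for real MedSigLIP image and text embeddings, and `max_similarity_score` is a hypothetical helper, not the project's API:

```python
import numpy as np


def max_similarity_score(image_emb: np.ndarray, prompt_embs: dict) -> dict:
    """Score each class by the *maximum* cosine similarity over its prompt embeddings."""
    v = image_emb / np.linalg.norm(image_emb)
    scores = {}
    for label, embs in prompt_embs.items():
        e = embs / np.linalg.norm(embs, axis=1, keepdims=True)
        scores[label] = float((e @ v).max())  # best-matching prompt wins
    return scores


rng = np.random.default_rng(0)
prompt_embs = {  # 8 prompts per class; random stand-ins for MedSigLIP text embeddings
    "anemic": rng.normal(size=(8, 1152)),
    "healthy": rng.normal(size=(8, 1152)),
}
image_emb = rng.normal(size=1152)
scores = max_similarity_score(image_emb, prompt_embs)
pred = max(scores, key=scores.get)
```

Taking the max over several prompt phrasings (rather than the mean) lets any one well-matched prompt carry the class, which tends to help when prompt wording is noisy.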
+ ### Novel Task: Bilirubin Regression
+ Frozen MedSigLIP embeddings -> 2-layer MLP -> continuous bilirubin (mg/dL) prediction.
+ Trained on 2,235 NeoJaundice images with ground-truth serum bilirubin.
+ **MAE: 2.667 mg/dL, Pearson r: 0.7725 (p < 1e-67)**
+
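A minimal sketch of such a regression head: input width 1152 to match MedSigLIP embeddings and hidden width 256 as reported in the released results JSON; the exact layer details (activation, dropout) are assumptions here, not the project's implementation:

```python
import torch
import torch.nn as nn


class BilirubinHead(nn.Module):
    """2-layer MLP regressing continuous bilirubin (mg/dL) from frozen embeddings."""

    def __init__(self, in_dim: int = 1152, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # single continuous output
        )

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        return self.net(emb).squeeze(-1)


head = BilirubinHead()
emb = torch.randn(4, 1152)  # stand-in for a batch of frozen MedSigLIP embeddings
pred_mgdl = head(emb)       # shape: (4,)
# L1 loss directly optimizes the MAE metric reported above
loss = nn.functional.l1_loss(pred_mgdl, torch.tensor([11.0, 5.0, 14.2, 2.1]))
```

Because the backbone stays frozen, only the small head is trained, which is why a few thousand labeled images suffice.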
+ ### Edge AI
+ - INT8 dynamic quantization: 812.6 MB -> 111.2 MB (7.31x compression)
+ - Pre-computed text embeddings: 12 KB (no text encoder on device)
+ - Total on-device: ~289 MB
+
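The INT8 step above can be reproduced with PyTorch's dynamic quantization, which converts `nn.Linear` weights to int8 and quantizes activations on the fly at inference. The two-layer model below is a toy stand-in for the MedSigLIP vision tower, not the real network:

```python
import torch
from torch import nn

# Toy stand-in for the vision tower; only the nn.Linear modules are converted.
model = nn.Sequential(nn.Linear(1152, 1152), nn.GELU(), nn.Linear(1152, 1152))

quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 1152)
out = quantized(x)  # CPU inference with int8 weights
```

Dynamic quantization needs no calibration data, which makes it a low-effort first pass for CPU/edge deployment; the reported 7.31x (beyond the ~4x from int8 weights alone) presumably also reflects which parts of the model ship to the device.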
+ ## Competition Tracks
+
+ - **Main Track**: Comprehensive maternal-neonatal assessment platform
+ - **Agentic Workflow Prize**: 6-agent pipeline with reasoning traces and audit trail
+
+ ## Tests
+
+ ```bash
+ # All tests
+ PYTHONPATH=src python -m pytest tests/ -v
+
+ # Agentic workflow only (41 tests)
+ PYTHONPATH=src python -m pytest tests/test_agentic_workflow.py -v
+ ```
+
+ ## License
+
+ [CC BY 4.0](LICENSE)
+
+ ## Acknowledgments
+
+ - Google Health AI Developer Foundations team
+ - NeoJaundice dataset (Figshare)
+ - Eyes-Defy-Anemia dataset (Kaggle)
+ - WHO IMNCI protocol guidelines
+
+ ---
+
+ Built with Google HAI-DEF for the MedGemma Impact Challenge 2026
app.py ADDED
@@ -0,0 +1,57 @@
+ """
+ NEXUS - HuggingFace Spaces Entry Point
+
+ Launches the Streamlit demo for the NEXUS Maternal-Neonatal Care Platform.
+ Built with Google HAI-DEF models for the MedGemma Impact Challenge 2026.
+ """
+
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ # Ensure src/ is on the Python path for imports
+ ROOT = Path(__file__).parent
+ SRC_DIR = ROOT / "src"
+ if str(SRC_DIR) not in sys.path:
+     sys.path.insert(0, str(SRC_DIR))
+
+ # Set environment defaults for HF Spaces
+ os.environ.setdefault("STREAMLIT_SERVER_PORT", "7860")
+ os.environ.setdefault("STREAMLIT_SERVER_ADDRESS", "0.0.0.0")
+ os.environ.setdefault("STREAMLIT_SERVER_HEADLESS", "true")
+ os.environ.setdefault("STREAMLIT_BROWSER_GATHER_USAGE_STATS", "false")
+
+
+ def main():
+     app_path = SRC_DIR / "demo" / "streamlit_app.py"
+     if not app_path.exists():
+         print(f"ERROR: Streamlit app not found at {app_path}")
+         sys.exit(1)
+
+     port = os.environ.get("PORT", os.environ["STREAMLIT_SERVER_PORT"])
+     os.environ["STREAMLIT_SERVER_PORT"] = str(port)
+
+     try:
+         subprocess.run(
+             [
+                 sys.executable, "-m", "streamlit", "run",
+                 str(app_path),
+                 f"--server.port={port}",
+                 f"--server.address={os.environ['STREAMLIT_SERVER_ADDRESS']}",
+                 f"--server.headless={os.environ['STREAMLIT_SERVER_HEADLESS']}",
+                 f"--browser.gatherUsageStats={os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS']}",
+             ],
+             check=True,
+             env={**os.environ, "PYTHONPATH": str(SRC_DIR)},
+         )
+     except subprocess.CalledProcessError as e:
+         print(f"ERROR: Streamlit process exited with code {e.returncode}")
+         sys.exit(e.returncode)
+     except FileNotFoundError:
+         print("ERROR: Streamlit not installed. Run: pip install streamlit")
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
models/linear_probes/anemia_classifier_metadata.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "model_type": "SVM_RBF",
+   "embedding_source": "MedSigLIP (google/medsiglip-448)",
+   "embedding_dim": 1152,
+   "num_classes": 2,
+   "classes": {
+     "healthy": 0,
+     "anemic": 1
+   },
+   "cv_accuracy_mean": 0.9994269340974211,
+   "cv_accuracy_std": 0.0011461318051575909,
+   "num_original_samples": 218,
+   "num_augmented_samples": 1744,
+   "augmentations_per_image": 7,
+   "all_results": {
+     "LogisticRegression": {
+       "mean_accuracy": 0.8985096993050752,
+       "std_accuracy": 0.008415256920621202
+     },
+     "SVM_RBF": {
+       "mean_accuracy": 0.9994269340974211,
+       "std_accuracy": 0.0011461318051575909
+     },
+     "SVM_Linear": {
+       "mean_accuracy": 0.8899186509896915,
+       "std_accuracy": 0.011746435929843532
+     }
+   },
+   "seed": 42
+ }
models/linear_probes/bilirubin_regression_results.json ADDED
@@ -0,0 +1,207 @@
+ {
+   "mae": 2.564,
+   "rmse": 3.416,
+   "pearson_r": 0.7783,
+   "pearson_p": 1.7171921789198235e-69,
+   "bland_altman": {
+     "mean_diff": -0.506,
+     "std_diff": 3.379,
+     "loa_upper": 6.116,
+     "loa_lower": -7.129
+   },
+   "test_size": 336,
+   "train_size": 1563,
+   "val_size": 336,
+   "input_dim": 1152,
+   "hidden_dim": 256,
+   "epochs_trained": 58,
+   "best_val_loss": 3.7143,
+   "bilirubin_range": {
+     "min": 0.0,
+     "max": 25.7,
+     "mean": 11.2,
+     "std": 5.2
+   },
+   "history": {
+     "train_loss": [17.543, 12.0442, 6.9412, 4.2175, 3.5428, 3.4781, 3.0782, 2.8347, 2.7914, 2.5293,
+                    2.393, 2.2627, 2.1357, 2.1498, 1.875, 2.0569, 1.843, 1.7077, 1.7084, 1.6893,
+                    1.7543, 2.0793, 2.1218, 2.1285, 2.0992, 1.9611, 1.93, 1.8854, 1.9694, 1.6901,
+                    1.699, 1.7061, 1.5767, 1.6265, 1.5394, 1.4675, 1.3684, 1.4486, 1.2866, 1.3152,
+                    1.2613, 1.1721, 1.1946, 1.2039, 1.1949, 1.129, 1.0557, 1.0699, 1.0325, 1.0427,
+                    1.0431, 1.0722, 1.0071, 1.0187, 0.8847, 0.9988, 0.942, 0.9464],
+     "val_loss": [18.4316, 13.9118, 6.9486, 4.5588, 5.5443, 4.184, 4.8748, 4.0967, 4.0286, 4.1705,
+                  4.0592, 3.921, 4.1161, 4.0279, 3.9931, 3.8783, 3.8742, 3.8394, 3.949, 3.8805,
+                  3.8673, 3.9437, 4.1339, 4.3688, 4.5384, 4.0601, 3.9022, 3.7252, 3.9551, 3.9791,
+                  3.7946, 4.0627, 3.815, 4.0698, 4.0345, 3.9504, 3.8177, 3.8626, 3.8044, 3.7743,
+                  3.8432, 3.8456, 3.7143, 3.8196, 3.8955, 3.7218, 3.7605, 3.7768, 3.7581, 3.7667,
+                  3.7499, 3.7481, 3.7286, 3.7502, 3.7814, 3.734, 3.7887, 3.7414],
+     "val_mae": [10.19, 7.908, 4.388, 3.118, 3.652, 2.965, 3.299, 2.901, 2.884, 2.947,
+                 2.876, 2.814, 2.93, 2.866, 2.854, 2.792, 2.794, 2.77, 2.836, 2.798,
+                 2.787, 2.814, 2.931, 3.052, 3.148, 2.89, 2.803, 2.691, 2.83, 2.837,
+                 2.737, 2.884, 2.749, 2.901, 2.874, 2.829, 2.761, 2.778, 2.734, 2.721,
+                 2.761, 2.774, 2.692, 2.749, 2.803, 2.699, 2.714, 2.719, 2.704, 2.716,
+                 2.717, 2.704, 2.699, 2.711, 2.731, 2.701, 2.736, 2.706]
+   }
+ }
models/linear_probes/cry_classifier_metadata.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "model_type": "SVM_RBF",
+   "embedding_source": "HeAR (google/hear-pytorch)",
+   "embedding_dim": 512,
+   "num_classes": 5,
+   "classes": {
+     "belly_pain": 0,
+     "burping": 1,
+     "discomfort": 2,
+     "hungry": 3,
+     "tired": 4
+   },
+   "cv_accuracy_mean": 0.8380793119923554,
+   "cv_accuracy_std": 0.008077431438521396,
+   "num_samples": 457,
+   "all_results": {
+     "LogisticRegression": {
+       "mean_accuracy": 0.7985905398948876,
+       "std_accuracy": 0.028055714127978745
+     },
+     "SVM_RBF": {
+       "mean_accuracy": 0.8380793119923554,
+       "std_accuracy": 0.008077431438521396
+     },
+     "SVM_Linear": {
+       "mean_accuracy": 0.765862398471094,
+       "std_accuracy": 0.013071624843302853
+     }
+   },
+   "seed": 42
+ }
models/linear_probes/jaundice_classifier_metadata.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "model_type": "SVM_RBF",
+   "embedding_source": "MedSigLIP (google/medsiglip-448)",
+   "embedding_dim": 1152,
+   "num_classes": 2,
+   "classes": {
+     "normal": 0,
+     "jaundice": 1
+   },
+   "bilirubin_threshold": 5.0,
+   "cv_accuracy_mean": 0.967337807606264,
+   "cv_accuracy_std": 0.002197637886396911,
+   "num_original_samples": 2235,
+   "num_augmented_samples": 8940,
+   "augmentations_per_image": 3,
+   "all_results": {
+     "LogisticRegression": {
+       "mean_accuracy": 0.9422818791946309,
+       "std_accuracy": 0.004750953150245027
+     },
+     "SVM_RBF": {
+       "mean_accuracy": 0.967337807606264,
+       "std_accuracy": 0.002197637886396911
+     },
+     "SVM_Linear": {
+       "mean_accuracy": 0.9322147651006712,
+       "std_accuracy": 0.006743027683714353
+     }
+   },
+   "seed": 42
+ }
models/linear_probes/linear_probe_results.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "anemia": {
+     "accuracy": 0.5227272727272727,
+     "precision": 0.5185185185185185,
+     "recall": 0.6363636363636364,
+     "f1": 0.5714285714285714,
+     "train_size": 174,
+     "test_size": 44
+   },
+   "jaundice": {
+     "accuracy": 0.6957494407158836,
+     "precision": 0.6854460093896714,
+     "recall": 0.6790697674418604,
+     "f1": 0.6822429906542056,
+     "train_size": 1788,
+     "test_size": 447
+   }
+ }
requirements_spaces.txt ADDED
@@ -0,0 +1,31 @@
+ # NEXUS - HuggingFace Spaces Dependencies
+ # Minimal set for Streamlit demo deployment (CPU)
+
+ torch>=2.1.0
+ transformers>=4.44.0
+ accelerate>=0.25.0
+ safetensors>=0.4.0
+ sentencepiece>=0.1.99
+ huggingface_hub>=0.20.0
+
+ # Audio
+ librosa>=0.10.0
+ soundfile>=0.12.0
+
+ # Image
+ Pillow>=10.0.0
+
+ # Data
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scipy>=1.11.0
+ scikit-learn>=1.3.0
+
+ # Demo
+ streamlit>=1.28.0
+ plotly>=5.18.0
+
+ # Utilities
+ pyyaml>=6.0.0
+ tqdm>=4.66.0
+ joblib>=1.3.0
src/demo/__init__.py ADDED
File without changes
src/demo/streamlit_app.py ADDED
@@ -0,0 +1,1189 @@
+ """
+ NEXUS Streamlit Demo Application
+
+ Interactive demo for the NEXUS Maternal-Neonatal Care Platform.
+ Built with Google HAI-DEF models for the MedGemma Impact Challenge.
+
+ HAI-DEF Models Used:
+ - MedSigLIP: Medical image analysis (anemia, jaundice detection)
+ - HeAR: Health acoustic representations (cry analysis)
+ - MedGemma: Clinical reasoning and synthesis
+ """
+
+ import streamlit as st
+ from pathlib import Path
+ import sys
+ import os
+ import tempfile
+ import json
+
+ # Add parent directory to path for imports
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ # Page configuration
+ st.set_page_config(
+     page_title="NEXUS - Maternal-Neonatal Care",
+     page_icon="👶",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+ # Custom CSS
+ st.markdown("""
+ <style>
+     .main-header {
+         font-size: 2.5rem;
+         font-weight: bold;
+         color: #1f77b4;
+         text-align: center;
+         margin-bottom: 1rem;
+     }
+     .sub-header {
+         font-size: 1.2rem;
+         color: #666;
+         text-align: center;
+         margin-bottom: 2rem;
+     }
+     .risk-high {
+         background-color: #ffcccc;
+         border: 2px solid #ff0000;
+         padding: 1rem;
+         border-radius: 10px;
+     }
+     .risk-medium {
+         background-color: #fff3cd;
+         border: 2px solid #ffc107;
+         padding: 1rem;
+         border-radius: 10px;
+     }
+     .risk-low {
+         background-color: #d4edda;
+         border: 2px solid #28a745;
+         padding: 1rem;
+         border-radius: 10px;
+     }
+     .metric-card {
+         background-color: #f8f9fa;
+         padding: 1rem;
+         border-radius: 10px;
+         text-align: center;
+     }
+     .model-badge {
+         display: inline-block;
+         padding: 2px 10px;
+         border-radius: 12px;
+         font-size: 0.78rem;
+         font-weight: 600;
+         color: white;
+         letter-spacing: 0.3px;
+     }
+     .stMetric > div {
+         background-color: #f8f9fa;
+         padding: 0.5rem;
+         border-radius: 8px;
+     }
+ </style>
+ """, unsafe_allow_html=True)
+
+
+ @st.cache_resource
+ def load_anemia_detector():
+     """Load anemia detector model with error handling."""
+     try:
+         from nexus.anemia_detector import AnemiaDetector
+         detector = AnemiaDetector()
+         return detector, None
+     except Exception as e:
+         return None, str(e)
+
+
+ @st.cache_resource
+ def load_jaundice_detector():
+     """Load jaundice detector model with error handling."""
+     try:
+         from nexus.jaundice_detector import JaundiceDetector
+         detector = JaundiceDetector()
+         return detector, None
+     except Exception as e:
+         return None, str(e)
+
+
+ @st.cache_resource
+ def load_cry_analyzer():
+     """Load cry analyzer with error handling."""
+     try:
+         from nexus.cry_analyzer import CryAnalyzer
+         analyzer = CryAnalyzer()
+         return analyzer, None
+     except Exception as e:
+         return None, str(e)
+
+
+ @st.cache_resource
+ def load_clinical_synthesizer():
+     """Load clinical synthesizer (MedGemma) with error handling."""
+     try:
+         import os
+         from nexus.clinical_synthesizer import ClinicalSynthesizer
+         use_medgemma = os.environ.get("NEXUS_USE_MEDGEMMA", "true").lower() != "false"
+         synthesizer = ClinicalSynthesizer(use_medgemma=use_medgemma)
+         return synthesizer, None
+     except Exception as e:
+         return None, str(e)
+
+
+ def get_hai_def_info():
+     """Get HAI-DEF models information with validated accuracy numbers."""
+     return {
+         "MedSigLIP": {
+             "name": "MedSigLIP (google/medsiglip-448)",
+             "use": "Image analysis for anemia and jaundice detection + bilirubin regression",
+             "method": "Zero-shot classification (max-similarity, 8 prompts/class) + trained SVM/LR classifiers on embeddings",
+             "accuracy": "Anemia: trained classifier on augmented data, Jaundice: trained classifier on 2,235 images, Bilirubin: MAE 2.67 mg/dL (r=0.77)",
+             "badge": "Vision",
+             "badge_color": "#388e3c",
+         },
+         "HeAR": {
+             "name": "HeAR (google/hear-pytorch)",
+             "use": "Infant cry analysis for asphyxia and cry type classification",
+             "method": "512-dim health acoustic embeddings + trained linear classifier on donate-a-cry dataset (5-class: hungry, belly_pain, burping, discomfort, tired)",
+             "accuracy": "Trained cry type classifier with asphyxia risk derivation from distress patterns",
+             "badge": "Audio",
+             "badge_color": "#f57c00",
+         },
+         "MedGemma": {
+             "name": "MedGemma 1.5 4B (google/medgemma-1.5-4b-it)",
+             "use": "Clinical reasoning and recommendation synthesis",
+             "method": "4-bit NF4 quantized inference with WHO IMNCI protocol-aligned synthesis and 6-agent reasoning traces",
+             "accuracy": "Protocol-aligned clinical recommendations with structured reasoning chains",
+             "badge": "Language",
+             "badge_color": "#1976d2",
+         },
+     }
+
+
+ def main():
+     """Main application."""
+
+     # Header
+     st.markdown('<div class="main-header">NEXUS</div>', unsafe_allow_html=True)
+     st.markdown(
+         '<div class="sub-header">AI-Powered Maternal-Neonatal Care Platform</div>',
+         unsafe_allow_html=True
+     )
+
+     # Sidebar
+     with st.sidebar:
+         st.markdown("## 🏥 NEXUS")
+         st.markdown("---")
+
+         assessment_type = st.radio(
+             "Select Assessment Type",
+             [
+                 "Maternal Anemia Screening",
+                 "Neonatal Jaundice Detection",
+                 "Cry Analysis",
+                 "Combined Assessment",
+                 "Agentic Workflow",
+                 "HAI-DEF Models Info"
+             ],
+             index=0,
+         )
+
+         st.markdown("---")
+         st.markdown("### About NEXUS")
+         st.markdown("""
+         NEXUS uses AI to provide non-invasive screening for:
+         - **Maternal Anemia** via conjunctiva imaging
+         - **Neonatal Jaundice** via skin color analysis
+         - **Birth Asphyxia** via cry pattern analysis
+
+         Built with **Google HAI-DEF models** for the MedGemma Impact Challenge 2026.
+         """)
+
+         st.markdown("---")
+         st.markdown("### Edge AI Mode")
+         edge_mode = st.toggle("Enable Edge AI Mode", value=False, key="edge_mode")
+         if edge_mode:
+             st.success("Edge AI: INT8 quantized models + offline inference")
+         else:
+             st.info("Cloud mode: Full-precision HAI-DEF models")
+
+         st.markdown("---")
+         st.markdown("### HAI-DEF Models")
+         st.markdown("""
+         - **MedSigLIP**: Vision (trained classifiers)
+         - **HeAR**: Audio (trained cry classifier)
+         - **MedGemma 1.5**: Clinical AI (4-bit NF4)
+         """)
+
+     # Show Edge AI banner when enabled
+     if edge_mode:
+         render_edge_ai_banner()
+
+     # Main content based on selection
+     if assessment_type == "Maternal Anemia Screening":
+         render_anemia_screening()
+     elif assessment_type == "Neonatal Jaundice Detection":
+         render_jaundice_detection()
+     elif assessment_type == "Cry Analysis":
+         render_cry_analysis()
+     elif assessment_type == "Combined Assessment":
+         render_combined_assessment()
+     elif assessment_type == "Agentic Workflow":
+         render_agentic_workflow()
+     else:
+         render_hai_def_info()
+
+
+ def render_edge_ai_banner():
+     """Show Edge AI mode status and model metrics."""
+     st.markdown("""
+     <div style="background: linear-gradient(135deg, #1a237e 0%, #0d47a1 100%);
+                 color: white; padding: 1rem 1.5rem; border-radius: 10px; margin-bottom: 1rem;">
+         <h4 style="margin:0; color: white;">Edge AI Mode Active</h4>
+         <p style="margin: 0.3rem 0 0 0; opacity: 0.9; font-size: 0.9rem;">
+             Running INT8 quantized models for offline-capable inference on low-resource devices.
+         </p>
+     </div>
+     """, unsafe_allow_html=True)
+
+     col1, col2, col3, col4 = st.columns(4)
+     with col1:
+         st.metric("MedSigLIP INT8", "111.2 MB", "-86% memory")
+     with col2:
+         st.metric("Acoustic Model", "0.6 MB", "INT8 quantized")
+     with col3:
+         st.metric("Text Embeddings", "12 KB", "Pre-computed")
+     with col4:
+         st.metric("Total Edge Size", "~289 MB", "Offline-ready")
+
+     with st.expander("Edge AI Details"):
+         st.markdown("""
+         **Quantization**: Dynamic INT8 (PyTorch `quantize_dynamic`, qnnpack backend)
+
+         | Component | Cloud (FP32) | Edge (INT8) | Compression |
+         |-----------|-------------|-------------|-------------|
+         | MedSigLIP Vision | 812.6 MB | 111.2 MB | **7.31x** |
+         | Acoustic Model | 0.665 MB | 0.599 MB | 1.11x |
+         | CPU Latency | 97.7 ms | ~65 ms (ARM est.) | ~1.5x faster |
+
+         **Target Devices**: Android 8.0+, ARM Cortex-A53, 2GB RAM
+
+         **Offline Capabilities**:
+         - Image analysis via INT8 MedSigLIP + pre-computed binary text embeddings
+         - Audio analysis via INT8 acoustic feature extractor
+         - Clinical reasoning via rule-based WHO IMNCI protocols (no MedGemma required)
+         """)
+
+
+ def _cleanup_temp(path: str) -> None:
+     """Safely remove a temporary file."""
+     try:
+         if path and os.path.exists(path):
+             os.unlink(path)
+     except OSError:
+         pass
+
+
+ def _save_upload_to_temp(uploaded_file, suffix: str) -> str:
+     """Save an uploaded file to a temporary path and return the path."""
+     tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
+     try:
+         tmp.write(uploaded_file.getvalue())
+         tmp.close()
+         return tmp.name
+     except Exception:
+         tmp.close()
+         _cleanup_temp(tmp.name)
+         raise
+
+
+ def _model_badge(name: str, color: str) -> str:
+     """Return an HTML badge for displaying which HAI-DEF model is active."""
+     return (
+         f'<span style="background:{color}; color:white; padding:2px 10px; '
+         f'border-radius:12px; font-size:0.78rem; font-weight:600; '
+         f'letter-spacing:0.3px;">{name}</span>'
+     )
+
+
+ def render_anemia_screening():
+     """Render anemia screening interface."""
+     st.header("Maternal Anemia Screening")
+     st.markdown(
+         f"Upload a clear image of the inner eyelid (conjunctiva) for anemia screening. "
+         f'{_model_badge("MedSigLIP", "#388e3c")}',
+         unsafe_allow_html=True,
+     )
+
+     col1, col2 = st.columns([1, 1])
+
+     with col1:
+         st.subheader("Upload Image")
+         uploaded_file = st.file_uploader(
+             "Choose a conjunctiva image",
+             type=["jpg", "jpeg", "png"],
+             key="anemia_upload"
+         )
+
+         if uploaded_file:
+             st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
+
+     with col2:
+         st.subheader("Analysis Results")
+
+         if uploaded_file:
+             with st.spinner("Analyzing image..."):
+                 tmp_path = None
+                 try:
+                     detector, load_err = load_anemia_detector()
+                     if detector is None:
+                         st.error(f"Could not load model: {load_err}")
+                         return
+
+                     tmp_path = _save_upload_to_temp(uploaded_file, ".jpg")
+
+                     result = detector.detect(tmp_path)
+                     color_info = detector.analyze_color_features(tmp_path)
+
+                     # Display results
+                     risk_class = f"risk-{result['risk_level']}"
+                     st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
+
+                     if result["is_anemic"]:
+                         st.error("⚠️ ANEMIA DETECTED")
+                     else:
+                         st.success("✅ No Anemia Detected")
+
+                     st.markdown("</div>", unsafe_allow_html=True)
+
+                     # Metrics
+                     col_a, col_b, col_c = st.columns(3)
+                     with col_a:
+                         st.metric("Confidence", f"{result['confidence']:.1%}")
+                     with col_b:
+                         st.metric("Risk Level", result['risk_level'].upper())
+                     with col_c:
+                         st.metric("Est. Hemoglobin", f"{color_info['estimated_hemoglobin']} g/dL")
+
+                     # Recommendation
+                     st.markdown("### Recommendation")
+                     st.info(result["recommendation"])
+
+                     # Color analysis
+                     with st.expander("Technical Details"):
+                         st.json({
+                             "anemia_score": round(result["anemia_score"], 3),
+                             "healthy_score": round(result["healthy_score"], 3),
+                             "red_ratio": round(color_info["red_ratio"], 3),
+                             "pallor_index": round(color_info["pallor_index"], 3),
+                         })
+
+                 except Exception as e:
+                     st.error(f"Error analyzing image: {e}")
+                 finally:
+                     _cleanup_temp(tmp_path)
+         else:
+             st.info("👆 Upload an image to begin analysis")
+
+
+ def render_jaundice_detection():
+     """Render jaundice detection interface."""
+     st.header("Neonatal Jaundice Detection")
+     st.markdown(
+         f"Upload an image of the newborn's skin or sclera for jaundice assessment. "
+         f'{_model_badge("MedSigLIP", "#388e3c")}',
397
+ unsafe_allow_html=True,
398
+ )
399
+
400
+ col1, col2 = st.columns([1, 1])
401
+
402
+ with col1:
403
+ st.subheader("Upload Image")
404
+ uploaded_file = st.file_uploader(
405
+ "Choose a neonatal image",
406
+ type=["jpg", "jpeg", "png"],
407
+ key="jaundice_upload"
408
+ )
409
+
410
+ if uploaded_file:
411
+ st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
412
+
413
+ # Patient info
414
+ st.subheader("Patient Information (Optional)")
415
+ age_days = st.number_input("Age (days)", min_value=0, max_value=28, value=3)
416
+ birth_weight = st.number_input("Birth weight (grams)", min_value=500, max_value=5000, value=3000)
417
+
418
+ with col2:
419
+ st.subheader("Analysis Results")
420
+
421
+ if uploaded_file:
422
+ with st.spinner("Analyzing image..."):
423
+ tmp_path = None
424
+ try:
425
+ detector, load_err = load_jaundice_detector()
426
+ if detector is None:
427
+ st.error(f"Could not load model: {load_err}")
428
+ return
429
+
430
+ tmp_path = _save_upload_to_temp(uploaded_file, ".jpg")
431
+
432
+ result = detector.detect(tmp_path)
433
+ zone_info = detector.analyze_kramer_zones(tmp_path)
434
+
435
+ # Display results
436
+ risk_class = "risk-high" if result["needs_phototherapy"] else (
437
+ "risk-medium" if result["severity"] in ["moderate", "mild"] else "risk-low"
438
+ )
439
+ st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
440
+
441
+ if result["has_jaundice"]:
442
+ st.warning(f"⚠️ JAUNDICE DETECTED - {result['severity'].upper()}")
443
+ else:
444
+ st.success("✅ No Significant Jaundice")
445
+
446
+ st.markdown("</div>", unsafe_allow_html=True)
447
+
448
+ # Metrics - show ML bilirubin if available
449
+ col_a, col_b, col_c = st.columns(3)
450
+ with col_a:
451
+ bili_value = result.get('estimated_bilirubin_ml', result.get('estimated_bilirubin', 0))
452
+ bili_method = result.get('bilirubin_method', 'Color Analysis')
453
+ st.metric("Est. Bilirubin", f"{bili_value} mg/dL")
454
+ st.caption(f"Method: {bili_method}")
455
+ with col_b:
456
+ st.metric("Severity", result['severity'].upper())
457
+ with col_c:
458
+ st.metric("Kramer Zone", zone_info['kramer_zone'])
459
+
460
+ # Phototherapy indicator
461
+ if result["needs_phototherapy"]:
462
+ st.error("🔆 PHOTOTHERAPY RECOMMENDED")
463
+
464
+ # Recommendation
465
+ st.markdown("### Recommendation")
466
+ st.info(result["recommendation"])
467
+
468
+ # Zone analysis
469
+ with st.expander("Kramer Zone Analysis"):
470
+ st.write(f"**Zone**: {zone_info['kramer_zone']} - {zone_info['zone_description']}")
471
+ st.write(f"**Yellow Index**: {zone_info['yellow_index']}")
472
+ st.progress(min(zone_info['yellow_index'] * 2, 1.0))
473
+
474
+ # Technical details
475
+ with st.expander("Technical Details"):
476
+ details = {
477
+ "jaundice_score": round(result["jaundice_score"], 3),
478
+ "confidence": round(result["confidence"], 3),
479
+ "model": result.get("model", "unknown"),
480
+ "model_type": result.get("model_type", "unknown"),
481
+ "bilirubin_method": result.get("bilirubin_method", "Color Analysis"),
482
+ }
483
+ if result.get("estimated_bilirubin_ml") is not None:
484
+ details["bilirubin_ml"] = result["estimated_bilirubin_ml"]
485
+ details["bilirubin_color"] = result["estimated_bilirubin"]
486
+ st.json(details)
487
+
488
+ except Exception as e:
489
+ st.error(f"Error analyzing image: {e}")
490
+ finally:
491
+ _cleanup_temp(tmp_path)
492
+ else:
493
+ st.info("👆 Upload an image to begin analysis")
494
+
495
+
496
+ def render_cry_analysis():
497
+ """Render cry analysis interface."""
498
+ st.header("Infant Cry Analysis")
499
+ st.markdown(
500
+ f"Upload an audio recording of the infant's cry for analysis. "
501
+ f'{_model_badge("HeAR", "#f57c00")}',
502
+ unsafe_allow_html=True,
503
+ )
504
+
505
+ col1, col2 = st.columns([1, 1])
506
+
507
+ with col1:
508
+ st.subheader("Upload Audio")
509
+ uploaded_file = st.file_uploader(
510
+ "Choose a cry audio file",
511
+ type=["wav", "mp3", "ogg"],
512
+ key="cry_upload"
513
+ )
514
+
515
+ if uploaded_file:
516
+ st.audio(uploaded_file)
517
+
518
+ with col2:
519
+ st.subheader("Analysis Results")
520
+
521
+ if uploaded_file:
522
+ with st.spinner("Analyzing cry..."):
523
+ tmp_path = None
524
+ try:
525
+ analyzer, load_err = load_cry_analyzer()
526
+ if analyzer is None:
527
+ st.error(f"Could not load model: {load_err}")
528
+ return
529
+
530
+ tmp_path = _save_upload_to_temp(uploaded_file, ".wav")
531
+
532
+ result = analyzer.analyze(tmp_path)
533
+
534
+ # Display results
535
+ risk_class = f"risk-{result['risk_level']}"
536
+ st.markdown(f'<div class="{risk_class}">', unsafe_allow_html=True)
537
+
538
+ if result["is_abnormal"]:
539
+ st.error("⚠️ ABNORMAL CRY PATTERN DETECTED")
540
+ else:
541
+ st.success("✅ Normal Cry Pattern")
542
+
543
+ st.markdown("</div>", unsafe_allow_html=True)
544
+
545
+ # Metrics
546
+ col_a, col_b, col_c = st.columns(3)
547
+ with col_a:
548
+ st.metric("Asphyxia Risk", f"{result['asphyxia_risk']:.1%}")
549
+ with col_b:
550
+ st.metric("Cry Type", result['cry_type'].title())
551
+ with col_c:
552
+ st.metric("F0 (Pitch)", f"{result['features']['f0_mean']:.0f} Hz")
553
+
554
+ # Recommendation
555
+ st.markdown("### Recommendation")
556
+ st.info(result["recommendation"])
557
+
558
+ # Acoustic features
559
+ with st.expander("Acoustic Features"):
560
+ st.json(result["features"])
561
+
562
+ except Exception as e:
563
+ st.error(f"Error analyzing audio: {e}")
564
+ finally:
565
+ _cleanup_temp(tmp_path)
566
+ else:
567
+ st.info("👆 Upload an audio file to begin analysis")
568
+
569
+
570
+ def render_combined_assessment():
571
+ """Render combined assessment interface using Clinical Synthesizer."""
572
+ st.header("Combined Clinical Assessment")
573
+ st.markdown(
574
+ f"Upload multiple inputs for a comprehensive assessment using **MedGemma Clinical Synthesizer**. "
575
+ f"This combines findings from all HAI-DEF models to provide integrated clinical recommendations. "
576
+ f'{_model_badge("MedSigLIP", "#388e3c")} '
577
+ f'{_model_badge("HeAR", "#f57c00")} '
578
+ f'{_model_badge("MedGemma", "#1976d2")}',
579
+ unsafe_allow_html=True,
580
+ )
581
+
582
+ # Reset findings each time this tab is rendered to prevent
583
+ # stale data from previous patients contaminating results
584
+ st.session_state.findings = {
585
+ "anemia": None,
586
+ "jaundice": None,
587
+ "cry": None
588
+ }
589
+
590
+ col1, col2, col3 = st.columns(3)
591
+
592
+ with col1:
593
+ st.subheader("🩸 Anemia Screening")
594
+ anemia_file = st.file_uploader(
595
+ "Conjunctiva image",
596
+ type=["jpg", "jpeg", "png"],
597
+ key="combined_anemia"
598
+ )
599
+ if anemia_file:
600
+ st.image(anemia_file, use_container_width=True)
601
+ with st.spinner("Analyzing..."):
602
+ try:
603
+ detector, load_err = load_anemia_detector()
604
+ if detector is None:
605
+ st.error(f"Model error: {load_err}")
606
+ else:
607
+ tmp_path = _save_upload_to_temp(anemia_file, ".jpg")
608
+ result = detector.detect(tmp_path)
609
+ _cleanup_temp(tmp_path)
610
+ st.session_state.findings["anemia"] = result
611
+ if result["is_anemic"]:
612
+ st.error(f"Anemia: {result['risk_level'].upper()}")
613
+ else:
614
+ st.success("No Anemia")
615
+ except Exception as e:
616
+ st.error(f"Error: {e}")
617
+
618
+ with col2:
619
+ st.subheader("👶 Jaundice Detection")
620
+ jaundice_file = st.file_uploader(
621
+ "Neonatal skin image",
622
+ type=["jpg", "jpeg", "png"],
623
+ key="combined_jaundice"
624
+ )
625
+ if jaundice_file:
626
+ st.image(jaundice_file, use_container_width=True)
627
+ with st.spinner("Analyzing..."):
628
+ try:
629
+ detector, load_err = load_jaundice_detector()
630
+ if detector is None:
631
+ st.error(f"Model error: {load_err}")
632
+ else:
633
+ tmp_path = _save_upload_to_temp(jaundice_file, ".jpg")
634
+ result = detector.detect(tmp_path)
635
+ _cleanup_temp(tmp_path)
636
+ st.session_state.findings["jaundice"] = result
637
+ if result["has_jaundice"]:
638
+ st.warning(f"Jaundice: {result['severity'].upper()}")
639
+ else:
640
+ st.success("No Jaundice")
641
+ except Exception as e:
642
+ st.error(f"Error: {e}")
643
+
644
+ with col3:
645
+ st.subheader("🔊 Cry Analysis")
646
+ cry_file = st.file_uploader(
647
+ "Cry audio",
648
+ type=["wav", "mp3", "ogg"],
649
+ key="combined_cry"
650
+ )
651
+ if cry_file:
652
+ st.audio(cry_file)
653
+ with st.spinner("Analyzing..."):
654
+ try:
655
+ analyzer, load_err = load_cry_analyzer()
656
+ if analyzer is None:
657
+ st.error(f"Model error: {load_err}")
658
+ else:
659
+ tmp_path = _save_upload_to_temp(cry_file, ".wav")
660
+ result = analyzer.analyze(tmp_path)
661
+ _cleanup_temp(tmp_path)
662
+ st.session_state.findings["cry"] = result
663
+ if result["is_abnormal"]:
664
+ st.error(f"Abnormal Cry: {result['risk_level'].upper()}")
665
+ else:
666
+ st.success("Normal Cry")
667
+ except Exception as e:
668
+ st.error(f"Error: {e}")
669
+
670
+ # Clinical Synthesis Section
671
+ st.markdown("---")
672
+ st.subheader("🏥 Clinical Synthesis (MedGemma)")
673
+
674
+ # Check if any findings are available
675
+ has_findings = any(v is not None for v in st.session_state.findings.values())
676
+
677
+ if has_findings:
678
+ if st.button("Generate Clinical Synthesis", type="primary"):
679
+ with st.spinner("Synthesizing findings with MedGemma..."):
680
+ try:
681
+ synthesizer, load_err = load_clinical_synthesizer()
682
+ if synthesizer is None:
683
+ st.error(f"Could not load synthesizer: {load_err}")
684
+ return
685
+
686
+ # Prepare findings dict
687
+ findings = {}
688
+ if st.session_state.findings["anemia"]:
689
+ findings["anemia"] = st.session_state.findings["anemia"]
690
+ if st.session_state.findings["jaundice"]:
691
+ findings["jaundice"] = st.session_state.findings["jaundice"]
692
+ if st.session_state.findings["cry"]:
693
+ findings["cry"] = st.session_state.findings["cry"]
694
+
695
+ synthesis = synthesizer.synthesize(findings)
696
+
697
+ # Display synthesis results
698
+ severity_level = synthesis.get("severity_level", "GREEN")
699
+ severity_colors = {
700
+ "GREEN": ("🟢", "#d4edda", "#155724"),
701
+ "YELLOW": ("🟡", "#fff3cd", "#856404"),
702
+ "RED": ("🔴", "#f8d7da", "#721c24")
703
+ }
704
+ emoji, bg_color, text_color = severity_colors.get(severity_level, ("⚪", "#f8f9fa", "#000"))
705
+
706
+ st.markdown(f"""
707
+ <div style="background-color: {bg_color}; padding: 1.5rem; border-radius: 10px; margin: 1rem 0;">
708
+ <h3 style="color: {text_color}; margin: 0;">{emoji} Severity: {severity_level}</h3>
709
+ <p style="color: {text_color}; font-size: 1.1rem; margin-top: 0.5rem;">{synthesis.get('severity_description', '')}</p>
710
+ </div>
711
+ """, unsafe_allow_html=True)
712
+
713
+ # Summary
714
+ st.markdown("### Summary")
715
+ st.info(synthesis.get("summary", "No summary available"))
716
+
717
+ # Actions
718
+ if synthesis.get("immediate_actions"):
719
+ st.markdown("### Immediate Actions")
720
+ for action in synthesis["immediate_actions"]:
721
+ st.markdown(f"- {action}")
722
+
723
+ # Referral
724
+ col_a, col_b = st.columns(2)
725
+ with col_a:
726
+ st.markdown("### Referral Status")
727
+ if synthesis.get("referral_needed"):
728
+ st.error(f"⚠️ REFERRAL NEEDED: {synthesis.get('referral_urgency', 'standard').upper()}")
729
+ else:
730
+ st.success("✅ No referral needed")
731
+
732
+ with col_b:
733
+ st.markdown("### Follow-up")
734
+ st.info(synthesis.get("follow_up", "Schedule routine follow-up"))
735
+
736
+ # Technical details
737
+ with st.expander("Technical Details"):
738
+ model_name = synthesis.get("model", "unknown")
739
+ st.json({
740
+ "model": model_name,
741
+ "model_id": synthesis.get("model_id", ""),
742
+ "generated_at": synthesis.get("generated_at"),
743
+ "urgent_conditions": synthesis.get("urgent_conditions", []),
744
+ })
745
+ if model_name and "Fallback" not in str(model_name):
746
+ st.success(f"Synthesis powered by {model_name}")
747
+ elif "Fallback" in str(model_name):
748
+ st.warning("Using rule-based fallback (MedGemma unavailable)")
749
+
750
+ except Exception as e:
751
+ st.error(f"Error generating synthesis: {e}")
752
+ else:
753
+ st.info("👆 Upload at least one input (image or audio) to generate clinical synthesis")
754
+
755
+
756
+ def render_hai_def_info():
757
+ """Render HAI-DEF models information."""
758
+ st.header("Google HAI-DEF Models")
759
+ st.markdown("""
760
+ NEXUS is built using **Google Health AI Developer Foundations (HAI-DEF)** models,
761
+ designed specifically for healthcare applications in resource-limited settings.
762
+ """)
763
+
764
+ hai_def = get_hai_def_info()
765
+
766
+ # MedSigLIP
767
+ st.markdown("---")
768
+ col1, col2 = st.columns([1, 2])
769
+ with col1:
770
+ st.markdown("### 🖼️ MedSigLIP")
771
+ st.info("google/medsiglip-448\n\nHAI-DEF Vision Model")
772
+ with col2:
773
+ info = hai_def["MedSigLIP"]
774
+ st.markdown(f"**Model**: {info['name']}")
775
+ st.markdown(f"**Use Case**: {info['use']}")
776
+ st.markdown(f"**Method**: {info['method']}")
777
+ st.markdown(f"**Validated Performance**: {info['accuracy']}")
778
+ st.markdown("""
779
+ MedSigLIP enables zero-shot medical image classification using
780
+ text prompts. NEXUS extends this with trained SVM/LR classifiers
781
+ on MedSigLIP embeddings (with data augmentation) for improved
782
+ accuracy, plus a novel 3-layer MLP regression head for continuous
783
+ bilirubin prediction from frozen embeddings.
784
+ """)
785
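The trained-classifier extension described above is a linear probe: the encoder stays frozen and only a lightweight classifier is fit on its embeddings. A sketch with synthetic embeddings (the real pipeline would use MedSigLIP vectors from augmented clinical images; the dimensions and data here are placeholders):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
dim = 32  # stand-in for the real embedding width

# Synthetic "frozen embeddings": two classes separated along every axis.
X_anemic = rng.normal(loc=1.0, size=(40, dim))
X_healthy = rng.normal(loc=-1.0, size=(40, dim))
X = np.vstack([X_anemic, X_healthy])
y = np.array([1] * 40 + [0] * 40)

# Linear probe: only this classifier is trained; the encoder is untouched.
probe = LogisticRegression(max_iter=1000).fit(X, y)
acc = probe.score(X, y)
```

Swapping `LogisticRegression` for `sklearn.svm.SVC` gives the SVM variant with the same fit/score interface.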
+
786
+ # HeAR
787
+ st.markdown("---")
788
+ col1, col2 = st.columns([1, 2])
789
+ with col1:
790
+ st.markdown("### 🔊 HeAR")
791
+ st.info("google/hear-pytorch\n\nHAI-DEF Audio Model")
792
+ with col2:
793
+ info = hai_def["HeAR"]
794
+ st.markdown(f"**Model**: {info['name']}")
795
+ st.markdown(f"**Use Case**: {info['use']}")
796
+ st.markdown(f"**Method**: {info['method']}")
797
+ st.markdown(f"**Validated Performance**: {info['accuracy']}")
798
+ st.markdown("""
799
+ HeAR (Health Acoustic Representations) produces 512-dim embeddings
800
+ from 2-second audio clips at 16kHz. NEXUS trains a linear classifier
801
+ on HeAR embeddings for 5-class cry type classification (hungry,
802
+ belly_pain, burping, discomfort, tired) and derives asphyxia risk
803
+ from distress patterns.
804
+ """)
805
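HeAR's input contract (2-second windows at 16 kHz, 512-dim output) means longer recordings have to be chunked and the per-window embeddings pooled. A sketch with a dummy embedding function in place of the real HeAR encoder (`fake_hear_embed` is a stand-in, not the actual model):

```python
import numpy as np

SR = 16_000   # HeAR expects 16 kHz audio...
WIN = 2 * SR  # ...in 2-second windows
EMB_DIM = 512 # HeAR embedding width

def fake_hear_embed(window: np.ndarray) -> np.ndarray:
    """Placeholder for the real HeAR encoder: one 512-dim vector per window."""
    rms = np.sqrt(np.mean(window ** 2))
    return np.full(EMB_DIM, rms, dtype=np.float32)

def embed_recording(audio: np.ndarray) -> np.ndarray:
    """Chunk into 2-s windows (zero-padding the tail) and mean-pool."""
    n_windows = int(np.ceil(len(audio) / WIN))
    padded = np.zeros(n_windows * WIN, dtype=np.float32)
    padded[: len(audio)] = audio
    embs = [fake_hear_embed(padded[i * WIN:(i + 1) * WIN]) for i in range(n_windows)]
    return np.mean(embs, axis=0)

clip = np.ones(5 * SR, dtype=np.float32)  # 5-s recording -> 3 windows
emb = embed_recording(clip)
```

The pooled clip-level vector is what the linear cry-type classifier described above would consume.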
+
806
+ # MedGemma
807
+ st.markdown("---")
808
+ col1, col2 = st.columns([1, 2])
809
+ with col1:
810
+ st.markdown("### 🧠 MedGemma")
811
+ st.info("google/medgemma-1.5-4b-it\n\nHAI-DEF Language Model")
812
+ with col2:
813
+ info = hai_def["MedGemma"]
814
+ st.markdown(f"**Model**: {info['name']}")
815
+ st.markdown(f"**Use Case**: {info['use']}")
816
+ st.markdown(f"**Method**: {info['method']}")
817
+ st.markdown(f"**Validated Performance**: {info['accuracy']}")
818
+ st.markdown("""
819
+ MedGemma 1.5 provides clinical reasoning capabilities via 4-bit NF4
820
+ quantized inference (~2 GB VRAM). It synthesizes multi-modal findings
821
+ into actionable recommendations following WHO IMNCI protocols,
822
+ producing structured reasoning chains within the 6-agent pipeline.
823
+ """)
824
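The "~2 GB VRAM" figure quoted above is consistent with back-of-the-envelope NF4 arithmetic: 4-bit weights cost roughly 0.5 bytes per parameter plus a small overhead for block-wise quantization scales. Assuming a ~4B-parameter model (inferred from the model name, not verified):

```python
def quantized_weight_gb(n_params: float, bits: int, overhead: float = 0.05) -> float:
    """Approximate weight memory in GB for a quantized model.

    overhead covers per-block scale/zero-point metadata; 5% is a rough guess.
    """
    bytes_total = n_params * bits / 8 * (1 + overhead)
    return bytes_total / 1e9

approx_gb = quantized_weight_gb(4e9, bits=4)  # roughly 2.1 GB of weights
```

Activations and KV cache add on top of this, so actual peak VRAM during inference will be somewhat higher.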
+
825
+ # Competition Info
826
+ st.markdown("---")
827
+ st.subheader("🏆 MedGemma Impact Challenge 2026")
828
+ st.markdown("""
829
+ NEXUS is being developed for the [MedGemma Impact Challenge](https://www.kaggle.com/competitions/medgemma-impact-challenge-2026)
830
+ on Kaggle.
831
+
832
+ **Competition Focus**: Solutions for resource-limited healthcare settings using HAI-DEF models.
833
+
834
+ **NEXUS Impact**:
835
+ - 📍 Target: Sub-Saharan Africa and South Asia
836
+ - 👩‍⚕️ Users: Community Health Workers
837
+ - 🎯 Goals: Reduce maternal/neonatal mortality
838
+ - 📱 Deployment: Offline-capable mobile app
839
+ """)
840
+
841
+
842
+ def render_agentic_workflow():
843
+ """Render the agentic workflow interface with reasoning traces."""
844
+ st.header("Agentic Clinical Workflow")
845
+ st.markdown(
846
+ f"**6-Agent Pipeline** with step-by-step reasoning traces. "
847
+ f"Each agent explains its clinical decision process, providing a full audit trail. "
848
+ f'{_model_badge("MedSigLIP", "#388e3c")} '
849
+ f'{_model_badge("HeAR", "#f57c00")} '
850
+ f'{_model_badge("MedGemma", "#1976d2")}',
851
+ unsafe_allow_html=True,
852
+ )
853
+
854
+ # Pipeline diagram
855
+ st.markdown("""
856
+ <div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; flex-wrap: wrap; margin: 1rem 0;">
857
+ <div style="background: #e3f2fd; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #1976d2;">Triage</div>
858
+ <span style="font-size: 1.5rem;">&#8594;</span>
859
+ <div style="background: #e8f5e9; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #388e3c;">Image (MedSigLIP)</div>
860
+ <span style="font-size: 1.5rem;">&#8594;</span>
861
+ <div style="background: #fff3e0; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #f57c00;">Audio (HeAR)</div>
862
+ <span style="font-size: 1.5rem;">&#8594;</span>
863
+ <div style="background: #f3e5f5; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #7b1fa2;">Protocol (WHO)</div>
864
+ <span style="font-size: 1.5rem;">&#8594;</span>
865
+ <div style="background: #fce4ec; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #c62828;">Referral</div>
866
+ <span style="font-size: 1.5rem;">&#8594;</span>
867
+ <div style="background: #e0f7fa; padding: 0.5rem 1rem; border-radius: 8px; font-weight: bold; border: 2px solid #00838f;">Synthesis (MedGemma)</div>
868
+ </div>
869
+ """, unsafe_allow_html=True)
870
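The reasoning traces rendered below assume each agent emits a structured record. A hypothetical sketch of such a trace object, mirroring the fields this UI reads (`agent_name`, `status`, `confidence`, `processing_time_ms`, `reasoning`, `findings`) but not the real `nexus.agentic_workflow` definitions:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class AgentTrace:
    """Hypothetical shape of one agent's audit-trail entry (field names are
    assumed from how the UI consumes traces, not copied from the package)."""
    agent_name: str
    status: str                 # "success" | "skipped" | "error"
    confidence: float           # 0.0 - 1.0
    processing_time_ms: float
    reasoning: List[str] = field(default_factory=list)
    findings: Dict[str, Any] = field(default_factory=dict)

trace = AgentTrace(
    agent_name="TriageAgent",
    status="success",
    confidence=0.9,
    processing_time_ms=4.2,
    reasoning=["2 danger signs present", "highest severity: critical"],
)
```

Keeping the trace immutable and per-agent is what makes the step-by-step audit trail and per-agent timing chart below possible.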
+
871
+ st.markdown("---")
872
+
873
+ # Input section
874
+ col_left, col_right = st.columns([1, 1])
875
+
876
+ with col_left:
877
+ st.subheader("Patient & Inputs")
878
+ patient_type = st.selectbox("Patient Type", ["newborn", "pregnant"], key="agentic_patient")
879
+
880
+ # Danger signs
881
+ st.markdown("**Danger Signs**")
882
+ danger_signs = []
883
+ if patient_type == "pregnant":
884
+ sign_options = [
885
+ ("Severe headache", "high"),
886
+ ("Blurred vision", "high"),
887
+ ("Convulsions", "critical"),
888
+ ("Severe abdominal pain", "high"),
889
+ ("Vaginal bleeding", "critical"),
890
+ ("High fever", "high"),
891
+ ("Severe pallor", "medium"),
892
+ ]
893
+ else:
894
+ sign_options = [
895
+ ("Not breathing at birth", "critical"),
896
+ ("Convulsions", "critical"),
897
+ ("Severe chest indrawing", "high"),
898
+ ("Not feeding", "high"),
899
+ ("High fever (>38°C)", "high"),
900
+ ("Hypothermia (<35.5°C)", "high"),
901
+ ("Lethargy / unconscious", "critical"),
902
+ ("Umbilical redness", "medium"),
903
+ ]
904
+
905
+ selected_signs = st.multiselect(
906
+ "Select present danger signs",
907
+ [s[0] for s in sign_options],
908
+ key="agentic_signs"
909
+ )
910
+ for label, severity in sign_options:
911
+ if label in selected_signs:
912
+ danger_signs.append({
913
+ "id": label.lower().replace(" ", "_"),
914
+ "label": label,
915
+ "severity": severity,
916
+ "present": True,
917
+ })
918
+
919
+ # Image uploads
920
+ st.markdown("**Clinical Images**")
921
+ conjunctiva_file = st.file_uploader(
922
+ "Conjunctiva image (anemia)", type=["jpg", "jpeg", "png"],
923
+ key="agentic_conjunctiva"
924
+ )
925
+ skin_file = st.file_uploader(
926
+ "Skin image (jaundice)", type=["jpg", "jpeg", "png"],
927
+ key="agentic_skin"
928
+ )
929
+ cry_file = st.file_uploader(
930
+ "Cry audio", type=["wav", "mp3", "ogg"],
931
+ key="agentic_cry"
932
+ )
933
+
934
+ with col_right:
935
+ st.subheader("Workflow Execution")
936
+
937
+ if st.button("Run Agentic Assessment", type="primary", key="run_agentic"):
938
+ _temp_paths = []  # initialized before the try so the finally-cleanup is always defined
+ with st.spinner("Running 6-agent workflow..."):
939
+ try:
940
+ from nexus.agentic_workflow import (
941
+ AgenticWorkflowEngine,
942
+ AgentPatientInfo,
943
+ DangerSign,
944
+ WorkflowInput,
945
+ )
946
+
947
+ # Save uploaded files (track for cleanup)
948
+ _temp_paths = []
949
+ conjunctiva_path = None
950
+ skin_path = None
951
+ cry_path = None
952
+
953
+ if conjunctiva_file:
954
+ conjunctiva_path = _save_upload_to_temp(conjunctiva_file, ".jpg")
955
+ _temp_paths.append(conjunctiva_path)
956
+
957
+ if skin_file:
958
+ skin_path = _save_upload_to_temp(skin_file, ".jpg")
959
+ _temp_paths.append(skin_path)
960
+
961
+ if cry_file:
962
+ cry_path = _save_upload_to_temp(cry_file, ".wav")
963
+ _temp_paths.append(cry_path)
964
+
965
+ # Build workflow input
966
+ signs = [
967
+ DangerSign(
968
+ id=s["id"], label=s["label"],
969
+ severity=s["severity"], present=True,
970
+ )
971
+ for s in danger_signs
972
+ ]
973
+
974
+ info = AgentPatientInfo(patient_type=patient_type)
975
+ workflow_input = WorkflowInput(
976
+ patient_type=patient_type,
977
+ patient_info=info,
978
+ danger_signs=signs,
979
+ conjunctiva_image=conjunctiva_path,
980
+ skin_image=skin_path,
981
+ cry_audio=cry_path,
982
+ )
983
+
984
+ # Run workflow — reuse cached model instances when available
985
+ anemia_det, _ = load_anemia_detector()
986
+ jaundice_det, _ = load_jaundice_detector()
987
+ cry_ana, _ = load_cry_analyzer()
988
+ synth, _ = load_clinical_synthesizer()
989
+
990
+ engine = AgenticWorkflowEngine(
991
+ anemia_detector=anemia_det,
992
+ jaundice_detector=jaundice_det,
993
+ cry_analyzer=cry_ana,
994
+ synthesizer=synth,
995
+ )
996
+ result = engine.execute(workflow_input)
997
+
998
+ st.session_state["agentic_result"] = result
999
+ st.success("Workflow complete!")
1000
+
1001
+ except Exception as e:
1002
+ st.error(f"Workflow error: {e}")
1003
+ finally:
1004
+ for p in _temp_paths:
1005
+ _cleanup_temp(p)
1006
+
1007
+ # Results display
1008
+ if "agentic_result" in st.session_state:
1009
+ result = st.session_state["agentic_result"]
1010
+
1011
+ st.markdown("---")
1012
+
1013
+ # Overall classification
1014
+ severity_colors = {
1015
+ "GREEN": ("#d4edda", "#155724", "Routine care"),
1016
+ "YELLOW": ("#fff3cd", "#856404", "Close monitoring"),
1017
+ "RED": ("#f8d7da", "#721c24", "Urgent referral"),
1018
+ }
1019
+ bg, fg, desc = severity_colors.get(result.who_classification, ("#f8f9fa", "#000", "Unknown"))
1020
+
1021
+ st.markdown(f"""
1022
+ <div style="background: {bg}; color: {fg}; padding: 1.5rem; border-radius: 10px; text-align: center; margin: 1rem 0;">
1023
+ <h2 style="margin: 0;">WHO Classification: {result.who_classification}</h2>
1024
+ <p style="margin: 0.5rem 0 0 0; font-size: 1.1rem;">{desc}</p>
1025
+ </div>
1026
+ """, unsafe_allow_html=True)
1027
+
1028
+ # Key metrics
1029
+ m1, m2, m3, m4 = st.columns(4)
1030
+ with m1:
1031
+ st.metric("Agents Run", len(result.agent_traces))
1032
+ with m2:
1033
+ st.metric("Total Time", f"{result.processing_time_ms:.0f} ms")
1034
+ with m3:
1035
+ referral_text = "Yes" if (result.referral_result and result.referral_result.referral_needed) else "No"
1036
+ st.metric("Referral Needed", referral_text)
1037
+ with m4:
1038
+ triage_score = result.triage_result.score if result.triage_result else 0
1039
+ st.metric("Triage Score", triage_score)
1040
+
1041
+ # Clinical synthesis
1042
+ st.subheader("Clinical Synthesis")
1043
+ st.info(result.clinical_synthesis)
1044
+
1045
+ if result.immediate_actions:
1046
+ st.subheader("Immediate Actions")
1047
+ for action in result.immediate_actions:
1048
+ st.markdown(f"- {action}")
1049
+
1050
+ # Visual pipeline flow with status indicators
1051
+ st.markdown("---")
1052
+ st.subheader("Agent Pipeline Execution")
1053
+
1054
+ agent_meta = {
1055
+ "TriageAgent": {"color": "#1976d2", "bg": "#e3f2fd", "icon": "1", "label": "Triage"},
1056
+ "ImageAnalysisAgent": {"color": "#388e3c", "bg": "#e8f5e9", "icon": "2", "label": "Image (MedSigLIP)"},
1057
+ "AudioAnalysisAgent": {"color": "#f57c00", "bg": "#fff3e0", "icon": "3", "label": "Audio (HeAR)"},
1058
+ "ProtocolAgent": {"color": "#7b1fa2", "bg": "#f3e5f5", "icon": "4", "label": "WHO Protocol"},
1059
+ "ReferralAgent": {"color": "#c62828", "bg": "#fce4ec", "icon": "5", "label": "Referral"},
1060
+ "SynthesisAgent": {"color": "#00838f", "bg": "#e0f7fa", "icon": "6", "label": "Synthesis (MedGemma)"},
1061
+ }
1062
+ status_symbols = {"success": "OK", "skipped": "SKIP", "error": "ERR"}
1063
+
1064
+ # Build trace lookup
1065
+ trace_lookup = {t.agent_name: t for t in result.agent_traces}
1066
+
1067
+ # Pipeline status bar
1068
+ pipeline_html_parts = []
1069
+ for agent_name, meta in agent_meta.items():
1070
+ trace = trace_lookup.get(agent_name)
1071
+ if trace:
1072
+ status_sym = status_symbols.get(trace.status, "?")
1073
+ opacity = "1.0" if trace.status == "success" else "0.5"
1074
+ border_style = f"3px solid {meta['color']}" if trace.status == "success" else "2px dashed #999"
1075
+ time_label = f"{trace.processing_time_ms:.0f}ms"
1076
+ else:
1077
+ status_sym = "---"
1078
+ opacity = "0.3"
1079
+ border_style = "2px dashed #ccc"
1080
+ time_label = ""
1081
+
1082
+ pipeline_html_parts.append(f"""
1083
+ <div style="background: {meta['bg']}; padding: 0.4rem 0.7rem; border-radius: 8px;
1084
+ border: {border_style}; opacity: {opacity}; text-align: center; min-width: 90px;">
1085
+ <div style="font-weight: bold; font-size: 0.8rem; color: {meta['color']};">{meta['label']}</div>
1086
+ <div style="font-size: 0.7rem; color: #666;">{status_sym} {time_label}</div>
1087
+ </div>
1088
+ """)
1089
+
1090
+ pipeline_html = '<div style="display: flex; align-items: center; justify-content: center; gap: 0.3rem; flex-wrap: wrap; margin: 0.5rem 0;">'
1091
+ for i, part in enumerate(pipeline_html_parts):
1092
+ pipeline_html += part
1093
+ if i < len(pipeline_html_parts) - 1:
1094
+ pipeline_html += '<span style="font-size: 1.2rem; color: #999;">&#8594;</span>'
1095
+ pipeline_html += "</div>"
1096
+ st.markdown(pipeline_html, unsafe_allow_html=True)
1097
+
1098
+ # Agent reasoning traces (key feature for Agentic Workflow prize)
1099
+ st.markdown("---")
1100
+ st.subheader("Agent Reasoning Traces")
1101
+
1102
+ for trace in result.agent_traces:
1103
+ meta = agent_meta.get(trace.agent_name, {"color": "#666", "bg": "#f5f5f5", "label": trace.agent_name})
1104
+ status_emoji = status_symbols.get(trace.status, "?")
1105
+
1106
+ header_label = f"{meta['label']} [{status_emoji}] - {trace.confidence:.0%} confidence - {trace.processing_time_ms:.0f}ms"
1107
+ with st.expander(header_label, expanded=(trace.status == "success")):
1108
+ # Status bar
1109
+ st.markdown(f"""
1110
+ <div style="background: {meta['bg']}; padding: 0.8rem 1rem; border-radius: 8px;
1111
+ border-left: 4px solid {meta['color']}; margin-bottom: 0.5rem;">
1112
+ <strong style="color: {meta['color']};">{trace.agent_name}</strong> &nbsp;|&nbsp;
1113
+ Status: <strong>{trace.status}</strong> &nbsp;|&nbsp;
1114
+ Confidence: <strong>{trace.confidence:.1%}</strong> &nbsp;|&nbsp;
1115
+ Time: <strong>{trace.processing_time_ms:.1f}ms</strong>
1116
+ </div>
1117
+ """, unsafe_allow_html=True)
1118
+
1119
+ # Reasoning steps with numbered styling
1120
+ if trace.reasoning:
1121
+ st.markdown("**Reasoning Chain:**")
1122
+ for i, step in enumerate(trace.reasoning, 1):
1123
+ st.markdown(f"**Step {i}.** {step}")
1124
+
1125
+ # Key findings
1126
+ if trace.findings:
1127
+ st.markdown("**Key Findings:**")
1128
+ st.json(trace.findings)
1129
+
1130
+ # Processing time breakdown
1131
+ st.markdown("---")
1132
+ col_chart, col_summary = st.columns([2, 1])
1133
+
1134
+ with col_chart:
1135
+ st.subheader("Processing Time by Agent")
1136
+ import pandas as pd
1137
+ chart_data = pd.DataFrame({
1138
+ "Agent": [agent_meta.get(t.agent_name, {}).get("label", t.agent_name) for t in result.agent_traces],
1139
+ "Time (ms)": [t.processing_time_ms for t in result.agent_traces],
1140
+ })
1141
+ st.bar_chart(chart_data.set_index("Agent"))
1142
+
1143
+ with col_summary:
1144
+ st.subheader("Workflow Summary")
1145
+ total_time = result.processing_time_ms
1146
+ successful = sum(1 for t in result.agent_traces if t.status == "success")
1147
+ skipped = sum(1 for t in result.agent_traces if t.status == "skipped")
1148
+ errors = sum(1 for t in result.agent_traces if t.status == "error")
1149
+ st.markdown(f"""
1150
+ | Metric | Value |
1151
+ |--------|-------|
1152
+ | Total agents | {len(result.agent_traces)} |
1153
+ | Successful | {successful} |
1154
+ | Skipped | {skipped} |
1155
+ | Errors | {errors} |
1156
+ | Total time | {total_time:.0f} ms |
1157
+ | Avg per agent | {total_time / max(len(result.agent_traces), 1):.0f} ms |
1158
+ """)
1159
+
1160
+ # Referral details
1161
+ if result.referral_result and result.referral_result.referral_needed:
1162
+ st.markdown("---")
1163
+ st.subheader("Referral Details")
1164
+ ref = result.referral_result
1165
+ r1, r2, r3 = st.columns(3)
1166
+ with r1:
1167
+ st.metric("Urgency", ref.urgency.upper())
1168
+ with r2:
1169
+ st.metric("Facility", ref.facility_level.title())
1170
+ with r3:
1171
+ st.metric("Timeframe", ref.timeframe)
1172
+ st.warning(f"Reason: {ref.reason}")
1173
+
1174
+
1175
+ # Footer
1176
+ def render_footer():
1177
+ """Render footer."""
1178
+ st.markdown("---")
1179
+ st.markdown("""
1180
+ <div style="text-align: center; color: #666; font-size: 0.9rem;">
1181
+ <p>NEXUS - Built with Google HAI-DEF for MedGemma Impact Challenge 2026</p>
1182
+ <p>⚠️ This is a screening tool only. Always confirm with laboratory tests.</p>
1183
+ </div>
1184
+ """, unsafe_allow_html=True)
1185
+
1186
+
1187
+ if __name__ == "__main__":
1188
+ main()
1189
+ render_footer()
src/nexus/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ NEXUS - AI-Powered Maternal-Neonatal Care Platform
3
+
4
+ This package provides AI-powered diagnostic tools for:
5
+ - Maternal anemia detection via conjunctiva imaging
6
+ - Neonatal jaundice assessment via skin/sclera imaging
7
+ - Birth asphyxia screening via cry audio analysis
8
+ """
9
+
10
+ __version__ = "0.1.0"
src/nexus/agentic_workflow.py ADDED
@@ -0,0 +1,1296 @@
+ """
+ Agentic Clinical Workflow Engine
+
+ Multi-agent system for comprehensive maternal-neonatal assessments.
+ Mirrors the TypeScript architecture in mobile/src/services/agenticWorkflow.ts
+ but adds structured reasoning traces for explainability.
+
+ 6 Agents:
+ - TriageAgent: Initial danger sign screening (rules-based)
+ - ImageAnalysisAgent: MedSigLIP-powered anemia/jaundice detection
+ - AudioAnalysisAgent: HeAR-powered cry/asphyxia analysis
+ - ProtocolAgent: WHO IMNCI classification (rules-based)
+ - ReferralAgent: Urgency routing and referral decision (rules-based)
+ - SynthesisAgent: MedGemma clinical reasoning with full agent context
+
+ HAI-DEF Models Used:
+ - MedSigLIP (google/medsiglip-448) via ImageAnalysisAgent
+ - HeAR (google/hear-pytorch) via AudioAnalysisAgent
+ - MedGemma (google/medgemma-4b-it) via SynthesisAgent
+ """
+
+ import time
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
+
+
+ # ---------------------------------------------------------------------------
+ # Data Types
+ # ---------------------------------------------------------------------------
+
+ PatientType = Literal["pregnant", "newborn"]
+ SeverityLevel = Literal["RED", "YELLOW", "GREEN"]
+ AgentStatus = Literal["success", "skipped", "error"]
+ WorkflowState = Literal[
+     "idle",
+     "triaging",
+     "analyzing_image",
+     "analyzing_audio",
+     "applying_protocol",
+     "determining_referral",
+     "synthesizing",
+     "complete",
+     "error",
+ ]
+
+
+ @dataclass
+ class DangerSign:
+     """A clinical danger sign observed during triage."""
+     id: str
+     label: str
+     severity: Literal["critical", "high", "medium"]
+     present: bool = False
+
+
+ @dataclass
+ class AgentPatientInfo:
+     """Patient information for workflow context."""
+     patient_id: str = ""
+     patient_type: PatientType = "newborn"
+     gestational_weeks: Optional[int] = None
+     gravida: Optional[int] = None
+     para: Optional[int] = None
+     age_hours: Optional[int] = None
+     birth_weight: Optional[int] = None
+     delivery_type: Optional[str] = None
+     apgar_score: Optional[int] = None
+     gestational_age_at_birth: Optional[int] = None
+
+
+ @dataclass
+ class AgentResult:
+     """Structured output from a single agent with reasoning trace."""
+     agent_name: str
+     status: AgentStatus
+     reasoning: List[str] = field(default_factory=list)
+     findings: Dict[str, Any] = field(default_factory=dict)
+     confidence: float = 0.0
+     processing_time_ms: float = 0.0
+
+
+ @dataclass
+ class TriageResult:
+     """Output from TriageAgent."""
+     risk_level: SeverityLevel = "GREEN"
+     critical_signs_detected: bool = False
+     critical_signs: List[str] = field(default_factory=list)
+     immediate_referral_needed: bool = False
+     score: int = 0
+
+
+ @dataclass
+ class ImageAnalysisResult:
+     """Output from ImageAnalysisAgent."""
+     anemia: Optional[Dict[str, Any]] = None
+     jaundice: Optional[Dict[str, Any]] = None
+
+
+ @dataclass
+ class AudioAnalysisResult:
+     """Output from AudioAnalysisAgent."""
+     cry: Optional[Dict[str, Any]] = None
+
+
+ @dataclass
+ class ProtocolResult:
+     """Output from ProtocolAgent."""
+     classification: SeverityLevel = "GREEN"
+     applicable_protocols: List[str] = field(default_factory=list)
+     treatment_recommendations: List[str] = field(default_factory=list)
+     follow_up_schedule: str = ""
+
+
+ @dataclass
+ class ReferralResult:
+     """Output from ReferralAgent."""
+     referral_needed: bool = False
+     urgency: Literal["immediate", "urgent", "routine", "none"] = "none"
+     facility_level: Literal["primary", "secondary", "tertiary"] = "primary"
+     reason: str = "No referral required"
+     timeframe: str = "Not applicable"
+
+
+ @dataclass
+ class WorkflowInput:
+     """Input to the agentic workflow."""
+     patient_type: PatientType
+     patient_info: AgentPatientInfo = field(default_factory=AgentPatientInfo)
+     danger_signs: List[DangerSign] = field(default_factory=list)
+     conjunctiva_image: Optional[Union[str, Path]] = None
+     skin_image: Optional[Union[str, Path]] = None
+     cry_audio: Optional[Union[str, Path]] = None
+     additional_notes: str = ""
+
+
+ @dataclass
+ class WorkflowResult:
+     """Complete workflow output with all agent results and audit trail."""
+     success: bool = False
+     patient_type: PatientType = "newborn"
+     who_classification: SeverityLevel = "GREEN"
+
+     # Individual agent outputs
+     triage_result: Optional[TriageResult] = None
+     image_results: Optional[ImageAnalysisResult] = None
+     audio_results: Optional[AudioAnalysisResult] = None
+     protocol_result: Optional[ProtocolResult] = None
+     referral_result: Optional[ReferralResult] = None
+
+     # Synthesis
+     clinical_synthesis: str = ""
+     recommendation: str = ""
+     immediate_actions: List[str] = field(default_factory=list)
+
+     # Audit trail
+     agent_traces: List[AgentResult] = field(default_factory=list)
+     processing_time_ms: float = 0.0
+     timestamp: str = ""
+
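
The dataclasses above consistently use `field(default_factory=list)` for mutable defaults. A small self-contained sketch (with a hypothetical `Result` type) of the behavior this guarantees: each instance gets its own fresh list rather than one shared across instances.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Result:
    # default_factory builds a fresh list per instance; a bare `= []`
    # default would be rejected by @dataclass precisely because it
    # would end up shared by every instance.
    reasoning: List[str] = field(default_factory=list)

a, b = Result(), Result()
a.reasoning.append("step 1")
print(a.reasoning, b.reasoning)  # b's list stays empty
```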
+ # ---------------------------------------------------------------------------
+ # Individual Agents
+ # ---------------------------------------------------------------------------
+
+ class TriageAgent:
+     """
+     Initial risk stratification based on danger signs, patient info, and
+     clinical decision tree logic.
+
+     Decision tree considers:
+     - Danger sign severity and combinations
+     - Patient demographics (age, weight, gestational age)
+     - Comorbidity patterns (multiple conditions increase risk)
+     - Time-sensitive factors (e.g., jaundice < 24hrs = always RED)
+     """
+
+     def process(
+         self,
+         patient_type: PatientType,
+         danger_signs: List[DangerSign],
+         patient_info: AgentPatientInfo,
+     ) -> tuple[TriageResult, AgentResult]:
+         start = time.time()
+         reasoning: List[str] = []
+         score = 0
+         critical_signs: List[str] = []
+         risk_modifiers: List[str] = []
+
+         reasoning.append(f"[STEP 1/5] Initiating clinical triage for {patient_type} patient")
+
+         # Step 1: Evaluate danger signs with clinical context
+         present_signs = [s for s in danger_signs if s.present]
+         reasoning.append(f"[STEP 2/5] Evaluating {len(present_signs)} present danger signs out of {len(danger_signs)} assessed")
+
+         for sign in present_signs:
+             if sign.severity == "critical":
+                 score += 30
+                 critical_signs.append(sign.label)
+                 reasoning.append(f" CRITICAL: '{sign.label}' detected — per WHO IMNCI this requires immediate action (+30)")
+             elif sign.severity == "high":
+                 score += 15
+                 reasoning.append(f" HIGH: '{sign.label}' detected — warrants close monitoring (+15)")
+             elif sign.severity == "medium":
+                 score += 5
+                 reasoning.append(f" MEDIUM: '{sign.label}' detected — noted for assessment (+5)")
+
+         # Comorbidity check: multiple conditions compound risk
+         if len(present_signs) >= 3:
+             combo_bonus = 10
+             score += combo_bonus
+             risk_modifiers.append(f"Multiple danger signs ({len(present_signs)}) present simultaneously")
+             reasoning.append(f" COMORBIDITY: {len(present_signs)} danger signs present — compounding risk (+{combo_bonus})")
+
+         # Step 2: Patient-specific demographic risk assessment
+         reasoning.append("[STEP 3/5] Assessing demographic risk factors")
+
+         if patient_type == "pregnant":
+             if patient_info.gestational_weeks is not None:
+                 ga = patient_info.gestational_weeks
+                 if ga < 28:
+                     score += 15
+                     risk_modifiers.append(f"Extreme preterm ({ga} weeks)")
+                     reasoning.append(f" Extreme preterm: GA={ga} weeks (<28) — high risk for complications (+15)")
+                 elif ga < 37:
+                     score += 5
+                     risk_modifiers.append(f"Preterm ({ga} weeks)")
+                     reasoning.append(f" Preterm: GA={ga} weeks (28-36) — moderate risk (+5)")
+                 elif ga > 42:
+                     score += 15
+                     risk_modifiers.append(f"Post-term ({ga} weeks)")
+                     reasoning.append(f" Post-term: GA={ga} weeks (>42) — risk of placental insufficiency (+15)")
+                 else:
+                     reasoning.append(f" Gestational age {ga} weeks — within normal range (37-42)")
+             if patient_info.gravida is not None and patient_info.gravida >= 5:
+                 score += 5
+                 risk_modifiers.append(f"Grand multigravida (G{patient_info.gravida})")
+                 reasoning.append(f" Grand multigravida: G{patient_info.gravida} — increased obstetric risk (+5)")
+
+         elif patient_type == "newborn":
+             if patient_info.birth_weight is not None:
+                 bw = patient_info.birth_weight
+                 if bw < 1500:
+                     score += 20
+                     risk_modifiers.append(f"Very low birth weight ({bw}g)")
+                     reasoning.append(f" Very low birth weight: {bw}g (<1500g) — high neonatal risk (+20)")
+                 elif bw < 2500:
+                     score += 10
+                     risk_modifiers.append(f"Low birth weight ({bw}g)")
+                     reasoning.append(f" Low birth weight: {bw}g (<2500g) — moderate risk (+10)")
+                 else:
+                     reasoning.append(f" Birth weight {bw}g — within normal range")
+
+             if patient_info.apgar_score is not None:
+                 apgar = patient_info.apgar_score
+                 if apgar < 4:
+                     score += 25
+                     risk_modifiers.append(f"Severe depression (APGAR {apgar})")
+                     reasoning.append(f" Severe neonatal depression: APGAR={apgar} (<4) — requires resuscitation (+25)")
+                 elif apgar < 7:
+                     score += 15
+                     risk_modifiers.append(f"Moderate depression (APGAR {apgar})")
+                     reasoning.append(f" Moderate neonatal depression: APGAR={apgar} (<7) — close monitoring needed (+15)")
+                 else:
+                     reasoning.append(f" APGAR score {apgar} — within normal range")
+
+             if patient_info.age_hours is not None:
+                 age = patient_info.age_hours
+                 if age < 6:
+                     score += 10
+                     risk_modifiers.append(f"Critical neonatal period ({age}h)")
+                     reasoning.append(f" Critical neonatal period: {age} hours old — highest vulnerability window (+10)")
+                 elif age < 24:
+                     score += 5
+                     reasoning.append(f" First day of life: {age} hours — increased monitoring needed (+5)")
+
+             if patient_info.gestational_age_at_birth is not None and patient_info.gestational_age_at_birth < 37:
+                 score += 10
+                 risk_modifiers.append(f"Premature birth ({patient_info.gestational_age_at_birth} weeks)")
+                 reasoning.append(f" Premature birth at {patient_info.gestational_age_at_birth} weeks — increased susceptibility (+10)")
+
+         # Step 3: Clinical decision tree
+         reasoning.append("[STEP 4/5] Applying clinical decision tree")
+
+         if score >= 30 or len(critical_signs) > 0:
+             risk_level: SeverityLevel = "RED"
+             reasoning.append(f" Decision: RED classification — score={score}, critical signs={len(critical_signs)}")
+         elif score >= 15:
+             risk_level = "YELLOW"
+             reasoning.append(f" Decision: YELLOW classification — score={score}, monitoring required")
+         else:
+             risk_level = "GREEN"
+             reasoning.append(f" Decision: GREEN classification — score={score}, routine care")
+
+         critical_detected = len(critical_signs) > 0
+         immediate_referral = risk_level == "RED" and critical_detected
+
+         # Step 4: Summary with clinical rationale
+         reasoning.append("[STEP 5/5] Triage conclusion")
+         reasoning.append(f" Total triage score: {score}")
+         reasoning.append(f" Risk classification: {risk_level} ({self._risk_rationale(risk_level)})")
+         if risk_modifiers:
+             reasoning.append(f" Risk modifiers: {'; '.join(risk_modifiers)}")
+         if immediate_referral:
+             reasoning.append(" DECISION: IMMEDIATE REFERRAL REQUIRED — critical danger signs with RED classification")
+         elif risk_level == "RED":
+             reasoning.append(" DECISION: URGENT referral recommended — RED classification without critical signs")
+
+         elapsed = (time.time() - start) * 1000
+
+         result = TriageResult(
+             risk_level=risk_level,
+             critical_signs_detected=critical_detected,
+             critical_signs=critical_signs,
+             immediate_referral_needed=immediate_referral,
+             score=score,
+         )
+
+         trace = AgentResult(
+             agent_name="TriageAgent",
+             status="success",
+             reasoning=reasoning,
+             findings={
+                 "risk_level": risk_level,
+                 "score": score,
+                 "critical_signs": critical_signs,
+                 "risk_modifiers": risk_modifiers,
+                 "immediate_referral": immediate_referral,
+             },
+             confidence=1.0,
+             processing_time_ms=elapsed,
+         )
+
+         return result, trace
+
+     @staticmethod
+     def _risk_rationale(level: str) -> str:
+         return {
+             "RED": "immediate intervention required per WHO IMNCI",
+             "YELLOW": "close monitoring with 24-48h follow-up",
+             "GREEN": "routine care with standard follow-up schedule",
+         }.get(level, "")
+
+
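
Stripped of the trace bookkeeping, the decision tree in steps 4-5 of TriageAgent reduces to a score threshold plus a critical-sign override. A self-contained sketch of just that mapping (the 30 and 15 thresholds come from the code above; the function name is hypothetical):

```python
def classify(score: int, critical_signs: list) -> str:
    """RED at score >= 30 or any critical sign; YELLOW at >= 15; else GREEN."""
    if score >= 30 or critical_signs:
        return "RED"
    if score >= 15:
        return "YELLOW"
    return "GREEN"

# A single critical sign forces RED even at a low numeric score; in the
# full agent each critical sign also contributes +30, so the two paths agree.
print(classify(5, ["Convulsions"]))  # RED
print(classify(20, []))              # YELLOW
print(classify(10, []))              # GREEN
```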
346
+ class ImageAnalysisAgent:
+     """
+     Visual analysis using MedSigLIP for anemia and jaundice detection.
+
+     HAI-DEF Model: MedSigLIP (google/medsiglip-448)
+     Reuses existing AnemiaDetector and JaundiceDetector instances.
+     """
+
+     def __init__(
+         self,
+         anemia_detector: Optional[Any] = None,
+         jaundice_detector: Optional[Any] = None,
+     ):
+         self._anemia_detector = anemia_detector
+         self._jaundice_detector = jaundice_detector
+
+     def _get_anemia_detector(self) -> Any:
+         if self._anemia_detector is None:
+             from .anemia_detector import AnemiaDetector
+             self._anemia_detector = AnemiaDetector()
+         return self._anemia_detector
+
+     def _get_jaundice_detector(self) -> Any:
+         if self._jaundice_detector is None:
+             from .jaundice_detector import JaundiceDetector
+             self._jaundice_detector = JaundiceDetector()
+         return self._jaundice_detector
+
+     def process(
+         self,
+         patient_type: PatientType,
+         conjunctiva_image: Optional[Union[str, Path]] = None,
+         skin_image: Optional[Union[str, Path]] = None,
+     ) -> tuple[ImageAnalysisResult, AgentResult]:
+         start = time.time()
+         reasoning: List[str] = []
+         result = ImageAnalysisResult()
+         confidence_scores: List[float] = []
+
+         reasoning.append(f"Starting image analysis for {patient_type} patient")
+
+         # Anemia screening (both maternal and newborn)
+         if conjunctiva_image:
+             reasoning.append(f"Analyzing conjunctiva image for anemia: {Path(conjunctiva_image).name}")
+             try:
+                 detector = self._get_anemia_detector()
+                 anemia_result = detector.detect(conjunctiva_image)
+                 result.anemia = anemia_result
+                 conf = anemia_result.get("confidence", 0)
+                 confidence_scores.append(conf)
+
+                 if anemia_result.get("is_anemic"):
+                     reasoning.append(
+                         f"ANEMIA DETECTED: confidence={conf:.1%}, "
+                         f"risk_level={anemia_result.get('risk_level', 'unknown')}"
+                     )
+                 else:
+                     reasoning.append(f"No anemia detected (confidence={conf:.1%})")
+
+                 reasoning.append(f"Model used: {anemia_result.get('model', 'MedSigLIP')}")
+             except Exception as e:
+                 reasoning.append(f"Anemia analysis failed: {e}")
+                 result.anemia = {
+                     "is_anemic": False,
+                     "confidence": 0.0,
+                     "risk_level": "low",
+                     "recommendation": "Analysis failed - please retry",
+                     "anemia_score": 0.0,
+                     "healthy_score": 0.0,
+                     "model": "error",
+                 }
+         else:
+             reasoning.append("No conjunctiva image provided - skipping anemia screening")
+
+         # Jaundice detection (newborn or if skin image provided)
+         if skin_image:
+             reasoning.append(f"Analyzing skin image for jaundice: {Path(skin_image).name}")
+             try:
+                 detector = self._get_jaundice_detector()
+                 jaundice_result = detector.detect(skin_image)
+                 result.jaundice = jaundice_result
+                 conf = jaundice_result.get("confidence", 0)
+                 confidence_scores.append(conf)
+
+                 if jaundice_result.get("has_jaundice"):
+                     reasoning.append(
+                         f"JAUNDICE DETECTED: severity={jaundice_result.get('severity', 'unknown')}, "
+                         f"estimated bilirubin={jaundice_result.get('estimated_bilirubin', 'N/A')} mg/dL, "
+                         f"phototherapy={'needed' if jaundice_result.get('needs_phototherapy') else 'not needed'}"
+                     )
+                 else:
+                     reasoning.append(f"No significant jaundice detected (confidence={conf:.1%})")
+
+                 reasoning.append(f"Model used: {jaundice_result.get('model', 'MedSigLIP')}")
+             except Exception as e:
+                 reasoning.append(f"Jaundice analysis failed: {e}")
+                 result.jaundice = {
+                     "has_jaundice": False,
+                     "confidence": 0.0,
+                     "severity": "none",
+                     "estimated_bilirubin": 0.0,
+                     "needs_phototherapy": False,
+                     "recommendation": "Analysis failed - please retry",
+                     "model": "error",
+                 }
+         else:
+             reasoning.append("No skin image provided - skipping jaundice detection")
+
+         has_findings = result.anemia is not None or result.jaundice is not None
+         elapsed = (time.time() - start) * 1000
+         avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.0
+
+         trace = AgentResult(
+             agent_name="ImageAnalysisAgent",
+             status="success" if has_findings else "skipped",
+             reasoning=reasoning,
+             findings={
+                 "anemia_detected": result.anemia.get("is_anemic", False) if result.anemia else None,
+                 "jaundice_detected": result.jaundice.get("has_jaundice", False) if result.jaundice else None,
+             },
+             confidence=avg_confidence,
+             processing_time_ms=elapsed,
+         )
+
+         return result, trace
+
+
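
`ImageAnalysisAgent.__init__` accepts pre-built detector instances and defers the real imports to the `_get_*_detector` helpers, so tests can inject stubs and never touch MedSigLIP weights. A sketch of that pattern, with hypothetical `StubAnemiaDetector` and `LazyAgent` classes standing in for the real ones:

```python
class StubAnemiaDetector:
    """Stands in for the MedSigLIP-backed detector in tests."""
    def detect(self, image_path):
        return {"is_anemic": True, "confidence": 0.9, "risk_level": "moderate"}

class LazyAgent:
    def __init__(self, anemia_detector=None):
        self._anemia_detector = anemia_detector  # None => load lazily on demand

    def _get_anemia_detector(self):
        if self._anemia_detector is None:
            # The real agent imports AnemiaDetector here, deferring the
            # heavyweight model load until a detector is actually needed.
            raise RuntimeError("would load MedSigLIP here")
        return self._anemia_detector

agent = LazyAgent(anemia_detector=StubAnemiaDetector())
print(agent._get_anemia_detector().detect("conjunctiva.jpg")["is_anemic"])  # True
```

Because the injected stub short-circuits `_get_anemia_detector`, the import (and model download) never happens in a test run.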
473
+ class AudioAnalysisAgent:
+     """
+     Acoustic analysis using HeAR for cry pattern and asphyxia detection.
+
+     HAI-DEF Model: HeAR (google/hear-pytorch)
+     Reuses existing CryAnalyzer instance.
+     """
+
+     def __init__(self, cry_analyzer: Optional[Any] = None):
+         self._cry_analyzer = cry_analyzer
+
+     def _get_cry_analyzer(self) -> Any:
+         if self._cry_analyzer is None:
+             from .cry_analyzer import CryAnalyzer
+             self._cry_analyzer = CryAnalyzer()
+         return self._cry_analyzer
+
+     def process(
+         self,
+         cry_audio: Optional[Union[str, Path]] = None,
+     ) -> tuple[AudioAnalysisResult, AgentResult]:
+         start = time.time()
+         reasoning: List[str] = []
+         result = AudioAnalysisResult()
+
+         if not cry_audio:
+             reasoning.append("No cry audio provided - skipping audio analysis")
+             elapsed = (time.time() - start) * 1000
+             trace = AgentResult(
+                 agent_name="AudioAnalysisAgent",
+                 status="skipped",
+                 reasoning=reasoning,
+                 findings={},
+                 confidence=0.0,
+                 processing_time_ms=elapsed,
+             )
+             return result, trace
+
+         reasoning.append(f"Analyzing cry audio: {Path(cry_audio).name}")
+
+         try:
+             analyzer = self._get_cry_analyzer()
+             cry_result = analyzer.analyze(cry_audio)
+             result.cry = cry_result
+
+             risk = cry_result.get("asphyxia_risk", 0)
+             reasoning.append(f"Model used: {cry_result.get('model', 'HeAR')}")
+             reasoning.append(f"Cry type detected: {cry_result.get('cry_type', 'unknown')}")
+             reasoning.append(f"Asphyxia risk score: {risk:.1%}")
+
+             features = cry_result.get("features", {})
+             if features:
+                 reasoning.append(
+                     f"Acoustic features: F0={features.get('f0_mean', 0):.0f}Hz, "
+                     f"duration={features.get('duration', 0):.1f}s, "
+                     f"voiced_ratio={features.get('voiced_ratio', 0):.2f}"
+                 )
+
+             if cry_result.get("is_abnormal"):
+                 reasoning.append(
+                     f"ABNORMAL CRY PATTERN: risk_level={cry_result.get('risk_level', 'unknown')}"
+                 )
+             else:
+                 reasoning.append("Normal cry pattern detected")
+
+             # Higher confidence when risk score is far from 0.5 (clear result)
+             confidence = 0.5 + abs(risk - 0.5)
+             confidence = max(0.5, min(1.0, confidence))
+
+         except Exception as e:
+             reasoning.append(f"Cry analysis failed: {e}")
+             result.cry = {
+                 "is_abnormal": False,
+                 "asphyxia_risk": 0.0,
+                 "cry_type": "unknown",
+                 "risk_level": "low",
+                 "recommendation": "Analysis failed - please retry",
+                 "features": {},
+                 "model": "error",
+             }
+             confidence = 0.0
+
+         elapsed = (time.time() - start) * 1000
+
+         # The failure fallback above still populates result.cry, so a bare
+         # truthiness check would always report "success"; inspect the
+         # "model" marker instead to surface errors in the audit trail.
+         failed = not result.cry or result.cry.get("model") == "error"
+         trace = AgentResult(
+             agent_name="AudioAnalysisAgent",
+             status="error" if failed else "success",
+             reasoning=reasoning,
+             findings={
+                 "is_abnormal": result.cry.get("is_abnormal", False) if result.cry else None,
+                 "asphyxia_risk": result.cry.get("asphyxia_risk", 0) if result.cry else None,
+             },
+             confidence=confidence,
+             processing_time_ms=elapsed,
+         )
+
+         return result, trace
+
+
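
The confidence heuristic in `AudioAnalysisAgent` rewards decisive risk scores: `0.5 + abs(risk - 0.5)`, clamped to [0.5, 1.0], so a risk near 0 or 1 yields high confidence while an ambiguous 0.5 yields the floor. As a standalone function (the name is hypothetical):

```python
def decisiveness_confidence(risk: float) -> float:
    """Confidence grows with distance from the ambiguous midpoint 0.5."""
    confidence = 0.5 + abs(risk - 0.5)
    return max(0.5, min(1.0, confidence))

print(decisiveness_confidence(0.5))   # 0.5 (maximally ambiguous)
print(decisiveness_confidence(0.9))   # 0.9
print(decisiveness_confidence(0.05))  # 0.95
```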
572
+ class ProtocolAgent:
+     """
+     Applies WHO IMNCI guidelines with clinical reasoning for severity
+     classification and evidence-based treatment recommendations.
+
+     Reasoning process:
+     1. Evaluate each condition against WHO IMNCI thresholds
+     2. Check for protocol conflicts (e.g., anemia + jaundice comorbidity)
+     3. Apply condition-specific treatment algorithms
+     4. Generate time-bound follow-up schedule
+     """
+
+     def process(
+         self,
+         patient_type: PatientType,
+         triage: TriageResult,
+         image: ImageAnalysisResult,
+         audio: Optional[AudioAnalysisResult] = None,
+     ) -> tuple[ProtocolResult, AgentResult]:
+         start = time.time()
+         reasoning: List[str] = []
+         protocols: List[str] = []
+         recommendations: List[str] = []
+         classification: SeverityLevel = triage.risk_level
+         conditions_found: List[str] = []
+
+         reasoning.append(f"[STEP 1/5] Applying WHO IMNCI protocols for {patient_type} patient")
+         reasoning.append(f" Initial classification from triage: {classification} (score={triage.score})")
+
+         # ---- Maternal protocols ----
+         if patient_type == "pregnant":
+             protocols.append("WHO IMNCI Maternal Care")
+             reasoning.append("[STEP 2/5] Evaluating maternal conditions")
+
+             if image.anemia and image.anemia.get("is_anemic"):
+                 protocols.append("Anemia Management Protocol")
+                 conditions_found.append("anemia")
+                 est_hb = image.anemia.get("estimated_hemoglobin", 0)
+                 risk_level = image.anemia.get("risk_level", "unknown")
+
+                 reasoning.append(f" Anemia detected: risk={risk_level}, est. Hb={est_hb} g/dL")
+
+                 # WHO thresholds: pregnant women Hb<11 = anemia, Hb<7 = severe
+                 # (Non-pregnant women Hb<12; neonates vary by age)
+                 severe_threshold = 7.0
+                 moderate_threshold = 11.0
+                 reasoning.append(f" Using WHO maternal thresholds: severe<{severe_threshold}, moderate<{moderate_threshold} g/dL")
+
+                 if est_hb and est_hb < severe_threshold:
+                     classification = "RED"
+                     recommendations.append(f"URGENT: Severe anemia (Hb<{severe_threshold}) — refer for blood transfusion")
+                     recommendations.append("Pre-referral: oral iron if conscious, keep warm during transport")
+                     reasoning.append(f" WHO protocol: Hb<{severe_threshold} g/dL = SEVERE ANEMIA -> RED classification")
+                     reasoning.append(" Treatment: Blood transfusion required per WHO IMNCI anemia protocol")
+                 elif est_hb and est_hb < moderate_threshold:
+                     if classification != "RED":
+                         classification = "YELLOW"
+                     recommendations.append("Initiate iron supplementation (60mg elemental iron + 400mcg folic acid daily)")
+                     recommendations.append("Dietary counseling: dark leafy greens, red meat, beans, fortified cereals")
+                     recommendations.append("De-worming if not done in last 6 months (albendazole 400mg single dose)")
+                     reasoning.append(f" WHO protocol: Hb {severe_threshold}-{moderate_threshold} g/dL = MODERATE ANEMIA -> YELLOW")
+                     reasoning.append(" Treatment: Iron supplementation + dietary counseling per WHO ANC guidelines")
+                 else:
+                     recommendations.append("Monitor hemoglobin levels, encourage iron-rich diet")
+                     reasoning.append(" Mild anemia or screening positive — continue monitoring")
+
+             if triage.critical_signs_detected:
+                 protocols.append("Emergency Obstetric Care Protocol")
+                 recommendations.append("Immediate assessment for emergency obstetric conditions")
+                 reasoning.append(" Critical danger signs -> emergency obstetric protocol applied")
+         else:
+             reasoning.append("[STEP 2/5] Patient is newborn — skipping maternal protocols")
+
+         # ---- Newborn protocols ----
+         if patient_type == "newborn":
+             protocols.append("WHO IMNCI Newborn Care")
+             reasoning.append("[STEP 3/5] Evaluating neonatal conditions")
+
+             # Jaundice — with age-specific AAP/WHO thresholds
+             if image.jaundice and image.jaundice.get("has_jaundice"):
+                 protocols.append("Neonatal Jaundice Protocol")
+                 conditions_found.append("jaundice")
+                 est_bili = image.jaundice.get("estimated_bilirubin", 0)
+                 est_bili_ml = image.jaundice.get("estimated_bilirubin_ml")
+                 severity = image.jaundice.get("severity", "unknown")
+                 bili_value = est_bili_ml if est_bili_ml is not None else est_bili
+
+                 reasoning.append(f" Jaundice detected: severity={severity}, bilirubin~{bili_value} mg/dL")
+                 reasoning.append(f" Bilirubin method: {image.jaundice.get('bilirubin_method', 'color analysis')}")
+
+                 # Age-specific phototherapy thresholds (AAP 2004 / WHO)
+                 # For low-risk term newborns (>= 38 weeks):
+                 #   Age(h)   Phototherapy   Exchange
+                 #   24       12             19
+                 #   48       15             22
+                 #   72       18             24
+                 #   96+      20             25
+                 # Conservative defaults for a term newborn >= 96h old; an
+                 # age-aware caller should tighten these using the table above
+                 # and patient_info.age_hours.
+                 photo_threshold = 20.0
+                 exchange_threshold = 25.0
+                 reasoning.append(f" Using phototherapy threshold={photo_threshold} mg/dL, exchange={exchange_threshold} mg/dL")
+
+                 if bili_value and bili_value > exchange_threshold:
+                     classification = "RED"
+                     recommendations.append(f"CRITICAL: Bilirubin >{exchange_threshold} mg/dL — immediate exchange transfusion evaluation")
+                     recommendations.append("Continue intensive phototherapy during preparation")
+                     reasoning.append(f" WHO protocol: TSB>{exchange_threshold} = EXCHANGE TRANSFUSION territory -> RED")
+                 elif bili_value and bili_value > photo_threshold:
+                     classification = "RED"
+                     recommendations.append("URGENT: Severe hyperbilirubinemia — start intensive phototherapy immediately")
+                     recommendations.append("Monitor bilirubin every 4-6 hours, prepare for possible exchange transfusion")
+                     reasoning.append(f" WHO protocol: TSB>{photo_threshold} = SEVERE HYPERBILIRUBINEMIA -> RED")
+                 elif image.jaundice.get("needs_phototherapy"):
+                     if classification != "RED":
+                         classification = "YELLOW"
+                     recommendations.append("Initiate phototherapy (standard irradiance)")
+                     recommendations.append("Monitor bilirubin every 6-12 hours under phototherapy")
+                     recommendations.append("Ensure adequate breastfeeding (8-12 feeds per day)")
+                     reasoning.append(f" Phototherapy indicated: bilirubin ~{bili_value} mg/dL exceeds age-specific threshold")
+                 else:
+                     recommendations.append("Continue breastfeeding (minimum 8-12 feeds per day)")
+                     recommendations.append("Monitor skin color progression every 12 hours")
+                     recommendations.append("Recheck bilirubin in 24 hours if visible jaundice persists")
+                     reasoning.append(f" Mild jaundice ({bili_value} mg/dL) — monitoring and breastfeeding")
+
+             # Cry / asphyxia
+             if audio and audio.cry and audio.cry.get("is_abnormal"):
+                 protocols.append("Birth Asphyxia Assessment Protocol")
+                 conditions_found.append("abnormal_cry")
+                 asphyxia_risk = audio.cry.get("asphyxia_risk", 0)
+                 cry_type = audio.cry.get("cry_type", "unknown")
+
+                 reasoning.append(f" Abnormal cry: type={cry_type}, asphyxia_risk={asphyxia_risk:.1%}")
+
+                 if asphyxia_risk > 0.7:
+                     classification = "RED"
+                     recommendations.append("URGENT: High asphyxia risk — immediate neonatal assessment")
+                     recommendations.append("Check airway, breathing, circulation (ABC)")
+                     recommendations.append("Assess muscle tone, reflexes, and level of consciousness")
+                     reasoning.append(" WHO protocol: High asphyxia risk (>70%) -> RED, immediate assessment")
+                 elif asphyxia_risk > 0.4:
+                     if classification != "RED":
+                         classification = "YELLOW"
+                     recommendations.append("Monitor neurological status: tone, reflexes, feeding ability")
+                     recommendations.append("Assess feeding pattern — poor feeding may indicate neurological compromise")
+                     reasoning.append(f" Moderate asphyxia risk ({asphyxia_risk:.1%}) -> YELLOW, close monitoring")
+                 else:
+                     reasoning.append(f" Low asphyxia risk ({asphyxia_risk:.1%}) — documented but not concerning")
+
+             # Neonatal anemia
+             if image.anemia and image.anemia.get("is_anemic"):
+                 protocols.append("Neonatal Anemia Protocol")
+                 conditions_found.append("neonatal_anemia")
+                 recommendations.append("Check hematocrit and reticulocyte count")
+                 recommendations.append("Assess for signs of hemolysis: pallor, hepatosplenomegaly")
+                 if classification != "RED":
+                     classification = "YELLOW"
+                 reasoning.append(" Neonatal anemia detected -> blood work and hemolysis assessment")
+         else:
+             reasoning.append("[STEP 3/5] Patient is pregnant — skipping neonatal protocols")
+
+         # Step 4: Comorbidity analysis and protocol conflict resolution
+         reasoning.append("[STEP 4/5] Comorbidity and conflict analysis")
+         if len(conditions_found) >= 2:
+             reasoning.append(f" Multiple conditions detected: {', '.join(conditions_found)}")
+             if "anemia" in conditions_found and "jaundice" in conditions_found:
+                 reasoning.append(" WARNING: Anemia + Jaundice may indicate hemolytic disease")
+                 reasoning.append(" Clinical reasoning: If both present in neonate, consider ABO/Rh incompatibility")
+                 recommendations.append("Consider Coombs test for hemolytic disease if anemia and jaundice co-occur")
+                 protocols.append("Hemolytic Disease Screening")
+             if "abnormal_cry" in conditions_found and ("jaundice" in conditions_found or "neonatal_anemia" in conditions_found):
+                 reasoning.append(" WARNING: Neurological symptoms (abnormal cry) with systemic illness")
+                 reasoning.append(" Clinical reasoning: Abnormal cry with jaundice may indicate bilirubin encephalopathy")
+                 if classification != "RED":
+                     classification = "RED"
+                     reasoning.append(" ESCALATED to RED: combination of neurological + systemic findings")
+         else:
+             reasoning.append(" Single condition or no conditions — no comorbidity conflicts")
+
+         # Step 5: Follow-up schedule
+         reasoning.append("[STEP 5/5] Determining follow-up schedule")
760
+
761
+ if classification == "RED":
762
+ follow_up = "Immediate referral — reassess after higher-level care"
763
+ reasoning.append(f" RED: Immediate referral required, no outpatient follow-up")
764
+ elif classification == "YELLOW":
765
+ follow_up = "Follow-up in 2-3 days, or immediately if condition worsens"
766
+ reasoning.append(f" YELLOW: 2-3 day follow-up with worsening precautions")
767
+ else:
768
+ follow_up = (
769
+ "Routine follow-up in 1 week"
770
+ if patient_type == "newborn"
771
+ else "Routine antenatal follow-up as scheduled"
772
+ )
773
+ reasoning.append(f" GREEN: Routine follow-up — {follow_up}")
774
+
775
+ reasoning.append(f" Final WHO IMNCI classification: {classification}")
776
+ reasoning.append(f" Protocols applied ({len(protocols)}): {', '.join(protocols)}")
777
+
778
+ elapsed = (time.time() - start) * 1000
779
+
780
+ result = ProtocolResult(
781
+ classification=classification,
782
+ applicable_protocols=protocols,
783
+ treatment_recommendations=recommendations,
784
+ follow_up_schedule=follow_up,
785
+ )
786
+
787
+ trace = AgentResult(
788
+ agent_name="ProtocolAgent",
789
+ status="success",
790
+ reasoning=reasoning,
791
+ findings={
792
+ "classification": classification,
793
+ "protocols_count": len(protocols),
794
+ "recommendations_count": len(recommendations),
795
+ "conditions_found": conditions_found,
796
+ },
797
+ confidence=1.0,
798
+ processing_time_ms=elapsed,
799
+ )
800
+
801
+ return result, trace
802
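The branches above only ever move the WHO IMNCI classification toward greater severity (GREEN to YELLOW to RED), guarded by checks like `if classification != "RED"`. A minimal sketch of that monotonic escalation rule (the `escalate` helper is hypothetical, not part of the module):

```python
# Hypothetical helper illustrating the escalation rule used in ProtocolAgent:
# a classification may move toward RED but never back toward GREEN.
SEVERITY_ORDER = {"GREEN": 0, "YELLOW": 1, "RED": 2}

def escalate(current: str, proposed: str) -> str:
    """Return the more severe of two WHO IMNCI classifications."""
    if SEVERITY_ORDER[proposed] > SEVERITY_ORDER[current]:
        return proposed
    return current
```

Expressing the rule once this way avoids the scattered `if classification != "RED"` guards, at the cost of an extra lookup table.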
+
+
+class ReferralAgent:
+    """
+    Clinical referral decision agent with structured reasoning.
+
+    Considers:
+    - Triage severity and critical danger signs
+    - Protocol classification and specific condition thresholds
+    - Facility capability requirements (phototherapy, transfusion, NICU)
+    - Transport safety and pre-referral treatment
+    - Generates structured referral note for receiving facility
+    """
+
+    def process(
+        self,
+        patient_type: PatientType,
+        triage: TriageResult,
+        protocol: ProtocolResult,
+        image: ImageAnalysisResult,
+        audio: Optional[AudioAnalysisResult] = None,
+    ) -> tuple[ReferralResult, AgentResult]:
+        start = time.time()
+        reasoning: List[str] = []
+        referral_needed = False
+        urgency: Literal["immediate", "urgent", "routine", "none"] = "none"
+        facility_level: Literal["primary", "secondary", "tertiary"] = "primary"
+        reasons: List[str] = []
+        pre_referral_actions: List[str] = []
+        capabilities_needed: List[str] = []
+
+        reasoning.append(f"[STEP 1/4] Evaluating referral necessity for {patient_type} patient")
+
+        # Step 1: Evaluate critical/immediate triggers
+        if triage.immediate_referral_needed:
+            referral_needed = True
+            urgency = "immediate"
+            facility_level = "tertiary"
+            reasons.append(f"Critical danger signs: {', '.join(triage.critical_signs)}")
+            capabilities_needed.append("Emergency care")
+            reasoning.append(f" TRIGGER: Critical danger signs ({', '.join(triage.critical_signs)}) -> IMMEDIATE referral to tertiary")
+
+        # Step 2: Protocol-driven referral assessment
+        reasoning.append("[STEP 2/4] Assessing condition-specific referral criteria")
+
+        if protocol.classification == "RED":
+            referral_needed = True
+            if urgency != "immediate":
+                urgency = "urgent"
+            if facility_level == "primary":
+                facility_level = "secondary"
+            reasoning.append(" RED classification -> referral required (minimum: urgent to secondary)")
+
+        # Condition-specific evaluation with facility capability matching
+        if patient_type == "pregnant":
+            if image.anemia and image.anemia.get("is_anemic"):
+                est_hb = image.anemia.get("estimated_hemoglobin", 99)
+                if est_hb < 7:
+                    referral_needed = True
+                    if urgency != "immediate":
+                        urgency = "urgent"
+                    if facility_level == "primary":
+                        facility_level = "secondary"
+                    reasons.append(f"Severe anemia (est. Hb={est_hb} g/dL) — blood transfusion needed")
+                    capabilities_needed.append("Blood bank / transfusion services")
+                    pre_referral_actions.append("Oral iron if conscious and able to swallow")
+                    pre_referral_actions.append("Keep patient warm during transport")
+                    pre_referral_actions.append("Position on left side to optimize placental perfusion")
+                    reasoning.append(" Severe anemia (Hb<7): requires blood transfusion -> secondary facility")
+                    reasoning.append(" Pre-referral: oral iron, warmth, left lateral position")
+
+        if patient_type == "newborn":
+            if image.jaundice and image.jaundice.get("needs_phototherapy"):
+                referral_needed = True
+                if urgency != "immediate":
+                    urgency = "urgent"
+                if facility_level != "tertiary":
+                    facility_level = "secondary"
+                est_bili = image.jaundice.get("estimated_bilirubin_ml") or image.jaundice.get("estimated_bilirubin", 0)
+                reasons.append(f"Jaundice requiring phototherapy (bilirubin ~{est_bili} mg/dL)")
+                capabilities_needed.append("Phototherapy unit")
+                pre_referral_actions.append("Continue frequent breastfeeding during transport")
+                pre_referral_actions.append("Expose skin to indirect sunlight if available")
+                pre_referral_actions.append("Keep baby warm — avoid hypothermia")
+                reasoning.append(f" Phototherapy needed (bilirubin ~{est_bili} mg/dL): requires phototherapy unit -> secondary")
+
+                if est_bili and est_bili > 20:
+                    urgency = "immediate"
+                    facility_level = "tertiary"
+                    capabilities_needed.append("Exchange transfusion capability")
+                    reasoning.append(" Severe hyperbilirubinemia (>20 mg/dL): may need exchange transfusion -> tertiary")
+
+            if audio and audio.cry and audio.cry.get("asphyxia_risk", 0) > 0.7:
+                referral_needed = True
+                urgency = "immediate"
+                facility_level = "tertiary"
+                reasons.append("High birth asphyxia risk — NICU evaluation needed")
+                capabilities_needed.append("NICU / neonatal resuscitation")
+                pre_referral_actions.append("Maintain clear airway")
+                pre_referral_actions.append("Provide warmth and gentle stimulation")
+                pre_referral_actions.append("Monitor breathing during transport")
+                reasoning.append(" High asphyxia risk (>70%): requires NICU -> IMMEDIATE to tertiary")
+
+            elif audio and audio.cry and audio.cry.get("asphyxia_risk", 0) > 0.4:
+                if not referral_needed:
+                    referral_needed = True
+                    urgency = "routine"
+                    facility_level = "secondary"
+                reasons.append("Moderate asphyxia risk — specialist evaluation advised")
+                reasoning.append(" Moderate asphyxia risk: specialist evaluation -> routine referral to secondary")
+
+        # Step 3: Synthesize and verify referral decision
+        reasoning.append("[STEP 3/4] Synthesizing referral decision")
+
+        if protocol.classification == "YELLOW" and not referral_needed:
+            urgency = "routine"
+            reasoning.append(" YELLOW classification without specific referral triggers -> routine follow-up")
+
+        # Determine timeframe
+        timeframe_map = {
+            "immediate": "Within 1 hour — arrange emergency transport",
+            "urgent": "Within 4-6 hours — arrange priority transport",
+            "routine": "Within 24-48 hours — schedule outpatient referral",
+            "none": "Not applicable — manage at current facility",
+        }
+        timeframe = timeframe_map[urgency]
+
+        # Step 4: Generate referral summary
+        reasoning.append("[STEP 4/4] Referral decision summary")
+        reason_text = "; ".join(reasons) if reasons else "No referral required"
+
+        if referral_needed:
+            reasoning.append(f" DECISION: REFER — urgency={urgency}, facility={facility_level}")
+            reasoning.append(f" Reasons: {reason_text}")
+            reasoning.append(f" Timeframe: {timeframe}")
+            if capabilities_needed:
+                reasoning.append(f" Required capabilities: {', '.join(capabilities_needed)}")
+            if pre_referral_actions:
+                reasoning.append(f" Pre-referral actions: {'; '.join(pre_referral_actions)}")
+        else:
+            reasoning.append(" DECISION: No referral needed — manage at current level")
+            reasoning.append(" Follow protocol recommendations and scheduled follow-up")
+
+        elapsed = (time.time() - start) * 1000
+
+        result = ReferralResult(
+            referral_needed=referral_needed,
+            urgency=urgency,
+            facility_level=facility_level,
+            reason=reason_text,
+            timeframe=timeframe,
+        )
+
+        trace = AgentResult(
+            agent_name="ReferralAgent",
+            status="success",
+            reasoning=reasoning,
+            findings={
+                "referral_needed": referral_needed,
+                "urgency": urgency,
+                "facility_level": facility_level,
+                "capabilities_needed": capabilities_needed,
+                "pre_referral_actions": pre_referral_actions,
+            },
+            confidence=1.0,
+            processing_time_ms=elapsed,
+        )
+
+        return result, trace
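The repeated guards in `ReferralAgent.process` (`if urgency != "immediate"`, `if facility_level != "tertiary"`) implement a never-downgrade rule over ordered levels. The same rule can be sketched as a max over explicit rankings (the helper names are hypothetical, not part of the module):

```python
# Hypothetical helpers showing the never-downgrade rule as a max over rankings.
URGENCY_RANK = {"none": 0, "routine": 1, "urgent": 2, "immediate": 3}
FACILITY_RANK = {"primary": 0, "secondary": 1, "tertiary": 2}

def raise_urgency(current: str, proposed: str) -> str:
    """Keep whichever urgency is higher; never downgrade."""
    return max(current, proposed, key=URGENCY_RANK.__getitem__)

def raise_facility(current: str, proposed: str) -> str:
    """Keep whichever facility level is higher; never downgrade."""
    return max(current, proposed, key=FACILITY_RANK.__getitem__)
```

Centralizing the ordering makes it harder for a new condition branch to accidentally overwrite "tertiary" with "secondary".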
+
+
+class SynthesisAgent:
+    """
+    Clinical reasoning and synthesis using MedGemma.
+
+    HAI-DEF Model: MedGemma (google/medgemma-4b-it)
+    Reuses existing ClinicalSynthesizer instance.
+    Passes full agent reasoning context to MedGemma for richer synthesis.
+    """
+
+    def __init__(self, synthesizer: Optional[Any] = None):
+        self._synthesizer = synthesizer
+
+    def _get_synthesizer(self) -> Any:
+        if self._synthesizer is None:
+            from .clinical_synthesizer import ClinicalSynthesizer
+            self._synthesizer = ClinicalSynthesizer()
+        return self._synthesizer
+
+    def process(
+        self,
+        patient_type: PatientType,
+        triage: TriageResult,
+        image: ImageAnalysisResult,
+        audio: Optional[AudioAnalysisResult],
+        protocol: ProtocolResult,
+        referral: ReferralResult,
+        agent_traces: List[AgentResult],
+    ) -> tuple[Dict[str, Any], AgentResult]:
+        start = time.time()
+        reasoning: List[str] = []
+
+        reasoning.append("Synthesizing all agent findings with MedGemma")
+
+        # Build findings dict for the synthesizer
+        findings: Dict[str, Any] = {}
+        if image.anemia:
+            findings["anemia"] = image.anemia
+            reasoning.append("Including anemia findings in synthesis")
+        if image.jaundice:
+            findings["jaundice"] = image.jaundice
+            reasoning.append("Including jaundice findings in synthesis")
+        if audio and audio.cry:
+            findings["cry"] = audio.cry
+            reasoning.append("Including cry analysis findings in synthesis")
+
+        # Add agent context for richer synthesis
+        findings["patient_info"] = {"type": patient_type}
+        findings["agent_context"] = {
+            "triage_score": triage.score,
+            "triage_risk": triage.risk_level,
+            "critical_signs": triage.critical_signs,
+            "protocol_classification": protocol.classification,
+            "applicable_protocols": protocol.applicable_protocols,
+            "referral_needed": referral.referral_needed,
+            "referral_urgency": referral.urgency,
+        }
+
+        # Build reasoning trace summary for MedGemma prompt
+        trace_summary = []
+        for trace in agent_traces:
+            trace_summary.append(f"{trace.agent_name}: {'; '.join(trace.reasoning[-3:])}")
+        findings["agent_reasoning_summary"] = "\n".join(trace_summary)
+
+        reasoning.append(f"Passing {len(agent_traces)} agent traces as context")
+
+        try:
+            synthesizer = self._get_synthesizer()
+            synthesis = synthesizer.synthesize(findings)
+            reasoning.append(f"Synthesis completed using: {synthesis.get('model', 'unknown')}")
+            reasoning.append(f"Severity level: {synthesis.get('severity_level', 'N/A')}")
+            reasoning.append(f"Referral needed: {synthesis.get('referral_needed', 'N/A')}")
+
+            confidence = 0.85 if "MedGemma" in synthesis.get("model", "") else 0.75
+        except Exception as e:
+            reasoning.append(f"Synthesis failed: {e}")
+            synthesis = {
+                "summary": f"Assessment for {patient_type} patient. Classification: {protocol.classification}.",
+                "severity_level": protocol.classification,
+                "severity_description": f"WHO IMNCI {protocol.classification} classification",
+                "immediate_actions": protocol.treatment_recommendations or ["Continue routine care"],
+                "referral_needed": referral.referral_needed,
+                "referral_urgency": referral.urgency,
+                "follow_up": protocol.follow_up_schedule,
+                "urgent_conditions": triage.critical_signs,
+                "model": "Fallback (agent context)",
+                "generated_at": datetime.now().isoformat(),
+            }
+            confidence = 0.6
+
+        elapsed = (time.time() - start) * 1000
+
+        trace = AgentResult(
+            agent_name="SynthesisAgent",
+            status="success",
+            reasoning=reasoning,
+            findings={
+                "model": synthesis.get("model", "unknown"),
+                "severity_level": synthesis.get("severity_level", "unknown"),
+            },
+            confidence=confidence,
+            processing_time_ms=elapsed,
+        )
+
+        return synthesis, trace
+
+
+# ---------------------------------------------------------------------------
+# Workflow Engine
+# ---------------------------------------------------------------------------
+
+WorkflowCallback = Callable[[WorkflowState, float], None]
+
+
+class AgenticWorkflowEngine:
+    """
+    Orchestrates the 6-agent clinical workflow pipeline.
+
+    Pipeline: Triage -> Image -> Audio -> Protocol -> Referral -> Synthesis
+    Early-exit on critical danger signs (RED + critical -> skip to Synthesis)
+
+    Each agent emits a structured AgentResult with reasoning traces
+    that form a complete audit trail of the clinical decision process.
+    """
+
+    AGENTS = [
+        "TriageAgent",
+        "ImageAnalysisAgent",
+        "AudioAnalysisAgent",
+        "ProtocolAgent",
+        "ReferralAgent",
+        "SynthesisAgent",
+    ]
+
+    def __init__(
+        self,
+        anemia_detector: Optional[Any] = None,
+        jaundice_detector: Optional[Any] = None,
+        cry_analyzer: Optional[Any] = None,
+        synthesizer: Optional[Any] = None,
+        on_state_change: Optional[WorkflowCallback] = None,
+    ):
+        self._triage = TriageAgent()
+        self._image = ImageAnalysisAgent(anemia_detector, jaundice_detector)
+        self._audio = AudioAnalysisAgent(cry_analyzer)
+        self._protocol = ProtocolAgent()
+        self._referral = ReferralAgent()
+        self._synthesis = SynthesisAgent(synthesizer)
+        self._state: WorkflowState = "idle"
+        self._on_state_change = on_state_change
+
+    def _transition(self, state: WorkflowState, progress: float) -> None:
+        self._state = state
+        if self._on_state_change:
+            self._on_state_change(state, progress)
+
+    @property
+    def state(self) -> WorkflowState:
+        return self._state
+
+    def execute(self, workflow_input: WorkflowInput) -> WorkflowResult:
+        """
+        Execute the full agentic workflow pipeline.
+
+        Args:
+            workflow_input: Complete input with patient info, images, audio, danger signs.
+
+        Returns:
+            WorkflowResult with all agent outputs, reasoning traces, and clinical synthesis.
+        """
+        start = time.time()
+        agent_traces: List[AgentResult] = []
+        patient_type = workflow_input.patient_type
+
+        try:
+            # Step 1: Triage (10% progress)
+            self._transition("triaging", 10.0)
+            triage_result, triage_trace = self._triage.process(
+                patient_type,
+                workflow_input.danger_signs,
+                workflow_input.patient_info,
+            )
+            agent_traces.append(triage_trace)
+
+            # Early exit for critical cases
+            if triage_result.immediate_referral_needed:
+                self._transition("complete", 100.0)
+                return self._build_early_referral(
+                    workflow_input, triage_result, agent_traces, start
+                )
+
+            # Step 2: Image Analysis (30% progress)
+            self._transition("analyzing_image", 30.0)
+            image_result, image_trace = self._image.process(
+                patient_type,
+                workflow_input.conjunctiva_image,
+                workflow_input.skin_image,
+            )
+            agent_traces.append(image_trace)
+
+            # Step 3: Audio Analysis (50% progress)
+            self._transition("analyzing_audio", 50.0)
+            audio_result, audio_trace = self._audio.process(
+                workflow_input.cry_audio,
+            )
+            agent_traces.append(audio_trace)
+
+            # Step 4: Protocol Application (70% progress)
+            self._transition("applying_protocol", 70.0)
+            protocol_result, protocol_trace = self._protocol.process(
+                patient_type, triage_result, image_result, audio_result
+            )
+            agent_traces.append(protocol_trace)
+
+            # Step 5: Referral Decision (85% progress)
+            self._transition("determining_referral", 85.0)
+            referral_result, referral_trace = self._referral.process(
+                patient_type, triage_result, protocol_result,
+                image_result, audio_result,
+            )
+            agent_traces.append(referral_trace)
+
+            # Step 6: Clinical Synthesis with MedGemma (95% progress)
+            self._transition("synthesizing", 95.0)
+            synthesis, synthesis_trace = self._synthesis.process(
+                patient_type, triage_result, image_result,
+                audio_result, protocol_result, referral_result,
+                agent_traces,
+            )
+            agent_traces.append(synthesis_trace)
+
+            # Build final result
+            self._transition("complete", 100.0)
+            elapsed = (time.time() - start) * 1000
+
+            return WorkflowResult(
+                success=True,
+                patient_type=patient_type,
+                who_classification=protocol_result.classification,
+                triage_result=triage_result,
+                image_results=image_result,
+                audio_results=audio_result,
+                protocol_result=protocol_result,
+                referral_result=referral_result,
+                clinical_synthesis=synthesis.get("summary", ""),
+                recommendation=synthesis.get("immediate_actions", ["Continue routine care"])[0],
+                immediate_actions=synthesis.get("immediate_actions", []),
+                agent_traces=agent_traces,
+                processing_time_ms=elapsed,
+                timestamp=datetime.now().isoformat(),
+            )
+
+        except Exception as e:
+            self._transition("error", 0.0)
+            elapsed = (time.time() - start) * 1000
+            error_trace = AgentResult(
+                agent_name="WorkflowEngine",
+                status="error",
+                reasoning=[f"Workflow failed: {e}"],
+                findings={"error": str(e)},
+                confidence=0.0,
+                processing_time_ms=elapsed,
+            )
+            agent_traces.append(error_trace)
+
+            return WorkflowResult(
+                success=False,
+                patient_type=patient_type,
+                who_classification="RED",
+                agent_traces=agent_traces,
+                clinical_synthesis=f"Workflow error: {e}. Please retry or seek immediate medical consultation.",
+                recommendation="Seek immediate medical consultation due to assessment error",
+                immediate_actions=["Seek immediate medical consultation"],
+                processing_time_ms=elapsed,
+                timestamp=datetime.now().isoformat(),
+            )
+
+    def _build_early_referral(
+        self,
+        workflow_input: WorkflowInput,
+        triage: TriageResult,
+        agent_traces: List[AgentResult],
+        start_time: float,
+    ) -> WorkflowResult:
+        """Build result for early-exit when critical danger signs are detected."""
+        elapsed = (time.time() - start_time) * 1000
+
+        critical_text = ", ".join(triage.critical_signs)
+        synthesis_text = (
+            f"URGENT: Critical danger signs detected ({critical_text}). "
+            f"Immediate referral to higher-level facility is required. "
+            f"This patient requires emergency care that cannot be provided at the current level."
+        )
+
+        return WorkflowResult(
+            success=True,
+            patient_type=workflow_input.patient_type,
+            who_classification="RED",
+            triage_result=triage,
+            image_results=ImageAnalysisResult(),
+            audio_results=AudioAnalysisResult(),
+            protocol_result=ProtocolResult(
+                classification="RED",
+                applicable_protocols=["Emergency Referral Protocol"],
+                treatment_recommendations=["IMMEDIATE REFERRAL REQUIRED"],
+                follow_up_schedule="After emergency care",
+            ),
+            referral_result=ReferralResult(
+                referral_needed=True,
+                urgency="immediate",
+                facility_level="tertiary",
+                reason=f"Critical danger signs detected: {critical_text}",
+                timeframe="Immediately - within 1 hour",
+            ),
+            clinical_synthesis=synthesis_text,
+            recommendation="IMMEDIATE REFERRAL to tertiary care facility",
+            immediate_actions=[
+                "Arrange emergency transport",
+                "Call receiving facility",
+                "Provide pre-referral treatment as per protocol",
+                "Accompany patient with referral note",
+            ],
+            agent_traces=agent_traces,
+            processing_time_ms=elapsed,
+            timestamp=datetime.now().isoformat(),
+        )
src/nexus/anemia_detector.py ADDED
@@ -0,0 +1,580 @@
+"""
+Anemia Detector Module
+
+Uses MedSigLIP from Google HAI-DEF for anemia detection from conjunctiva images.
+Implements zero-shot classification with medical text prompts per NEXUS_MASTER_PLAN.md.
+
+HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
+Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
+"""
+
+import os
+import torch
+import torch.nn as nn
+from PIL import Image
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+import numpy as np
+
+try:
+    from transformers import AutoProcessor, AutoModel
+    HAS_TRANSFORMERS = True
+except ImportError:
+    HAS_TRANSFORMERS = False
+
+# HAI-DEF MedSigLIP model IDs to try in order of preference
+MEDSIGLIP_MODEL_IDS = [
+    "google/medsiglip-448",            # MedSigLIP - official HAI-DEF model
+    "google/siglip-base-patch16-224",  # SigLIP 224 - fallback
+]
+
+
+class AnemiaDetector:
+    """
+    Detects anemia from conjunctiva (inner eyelid) images using MedSigLIP.
+
+    Uses zero-shot classification with medical prompts for detection.
+    HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
+    Fallback: siglip-base-patch16-224
+    """
+
+    # Medical text prompts for zero-shot classification (optimized for MedSigLIP)
+    # Expanded prompt set with specific clinical language for better discrimination
+    ANEMIC_PROMPTS = [
+        "pale conjunctiva with visible pallor indicating anemia",
+        "conjunctival pallor grade 2 or higher with reduced vascularity",
+        "white or very pale inner eyelid mucosa suggesting low hemoglobin",
+        "conjunctiva showing significant pallor and poor blood perfusion",
+        "anemic eye with pale pink to white palpebral conjunctiva",
+        "inner eyelid lacking red coloration consistent with severe anemia",
+        "conjunctiva with washed out appearance and faint vascular pattern",
+        "pale mucous membrane of the lower eyelid suggesting iron deficiency",
+    ]
+
+    HEALTHY_PROMPTS = [
+        "healthy red conjunctiva with rich vascular pattern",
+        "well-perfused bright pink inner eyelid with visible blood vessels",
+        "normal conjunctiva showing deep red-pink coloration",
+        "conjunctiva with healthy blood supply and strong red color",
+        "richly vascularized palpebral conjunctiva with normal hemoglobin",
+        "inner eyelid with vibrant red-pink mucosa and clear vessels",
+        "non-anemic conjunctiva showing robust red perfusion",
+        "conjunctival mucosa with normal deep pink to red appearance",
+    ]
+
+    def __init__(
+        self,
+        model_name: Optional[str] = None,  # Auto-select MedSigLIP
+        device: Optional[str] = None,
+        threshold: float = 0.5,
+    ):
+        """
+        Initialize the Anemia Detector with MedSigLIP.
+
+        Args:
+            model_name: HuggingFace model name (auto-selects HAI-DEF MedSigLIP if None)
+            device: Device to run model on (auto-detected if None)
+            threshold: Classification threshold for anemia detection
+        """
+        if not HAS_TRANSFORMERS:
+            raise ImportError("transformers library required. Install with: pip install transformers")
+
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.threshold = threshold
+        self._model_loaded = False
+        self.classifier = None  # Can be set by pipeline for trained classification
+
+        # Determine which models to try
+        models_to_try = [model_name] if model_name else MEDSIGLIP_MODEL_IDS
+
+        # HuggingFace token for gated models
+        hf_token = os.environ.get("HF_TOKEN")
+
+        # Try loading models in order of preference
+        for candidate_model in models_to_try:
+            print(f"Loading HAI-DEF model: {candidate_model}")
+            try:
+                self.processor = AutoProcessor.from_pretrained(
+                    candidate_model, token=hf_token
+                )
+                self.model = AutoModel.from_pretrained(
+                    candidate_model, token=hf_token
+                ).to(self.device)
+                self.model_name = candidate_model
+                self._model_loaded = True
+                print(f"Successfully loaded: {candidate_model}")
+                break
+            except Exception as e:
+                print(f"Warning: Could not load {candidate_model}: {e}")
+                continue
+
+        if not self._model_loaded:
+            raise RuntimeError(
+                f"Could not load any MedSigLIP model. Tried: {models_to_try}. "
+                "Install transformers and ensure internet access."
+            )
+
+        self.model.eval()
+
+        # Pre-compute text embeddings for efficiency
+        self._precompute_text_embeddings()
+
+        # Try to auto-load trained classifier
+        self._auto_load_classifier()
+
+        # Indicate which model variant is being used
+        is_medsiglip = "medsiglip" in self.model_name
+        model_type = "MedSigLIP" if is_medsiglip else "SigLIP (fallback)"
+        classifier_status = "with trained classifier" if self.classifier else "zero-shot"
+        print(f"Anemia Detector (HAI-DEF {model_type}, {classifier_status}) initialized on {self.device}")
+
+    def _auto_load_classifier(self) -> None:
+        """Auto-load trained anemia classifier if available."""
+        if self.classifier is not None:
+            return  # Already set externally
+
+        try:
+            import joblib
+        except ImportError:
+            return
+
+        default_paths = [
+            Path(__file__).parent.parent.parent / "models" / "linear_probes" / "anemia_classifier.joblib",
+            Path("models/linear_probes/anemia_classifier.joblib"),
+        ]
+
+        for path in default_paths:
+            if path.exists():
+                try:
+                    self.classifier = joblib.load(path)
+                    print(f"Auto-loaded anemia classifier from {path}")
+                    return
+                except Exception as e:
+                    print(f"Warning: Could not load classifier from {path}: {e}")
+
+    # Logit temperature for softmax conversion (lower = more spread, higher = sharper)
+    LOGIT_SCALE = 30.0
+
+    def _precompute_text_embeddings(self) -> None:
+        """Pre-compute text embeddings for zero-shot classification using SigLIP.
+
+        Stores individual prompt embeddings for max-similarity scoring,
+        which outperforms mean-pooled embeddings for medical image classification.
+        """
+        all_prompts = self.ANEMIC_PROMPTS + self.HEALTHY_PROMPTS
+
+        with torch.no_grad():
+            # SigLIP uses different API than CLIP
+            inputs = self.processor(
+                text=all_prompts,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+            ).to(self.device)
+
+            # Get text embeddings - support multiple output APIs
+            if hasattr(self.model, 'get_text_features'):
+                text_embeddings = self.model.get_text_features(**inputs)
+            else:
+                outputs = self.model(**inputs)
+                if hasattr(outputs, 'text_embeds'):
+                    text_embeddings = outputs.text_embeds
+                elif hasattr(outputs, 'text_model_output'):
+                    text_embeddings = outputs.text_model_output.pooler_output
+                else:
+                    text_outputs = self.model.text_model(**inputs)
+                    text_embeddings = text_outputs.pooler_output
+
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # Store individual embeddings for max-similarity scoring
+        n_anemic = len(self.ANEMIC_PROMPTS)
+        self.anemic_embeddings_all = text_embeddings[:n_anemic]   # (N, D)
+        self.healthy_embeddings_all = text_embeddings[n_anemic:]  # (M, D)
+
+        # Also keep mean embeddings as fallback
+        self.anemic_embeddings = self.anemic_embeddings_all.mean(dim=0, keepdim=True)
+        self.healthy_embeddings = self.healthy_embeddings_all.mean(dim=0, keepdim=True)
+        self.anemic_embeddings = self.anemic_embeddings / self.anemic_embeddings.norm(dim=-1, keepdim=True)
+        self.healthy_embeddings = self.healthy_embeddings / self.healthy_embeddings.norm(dim=-1, keepdim=True)
201
+ def preprocess_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
202
+ """
203
+ Preprocess image for analysis.
204
+
205
+ Args:
206
+ image: Path to image or PIL Image
207
+
208
+ Returns:
209
+ Preprocessed PIL Image
210
+ """
211
+ if isinstance(image, (str, Path)):
212
+ image = Image.open(image).convert("RGB")
213
+ elif not isinstance(image, Image.Image):
214
+ raise ValueError(f"Expected str, Path, or PIL Image, got {type(image)}")
215
+
216
+ return image
217
+
218
+ def detect(self, image: Union[str, Path, Image.Image]) -> Dict:
219
+ """
220
+ Detect anemia from conjunctiva image.
221
+
222
+ Uses trained classifier if available, otherwise falls back to
223
+ zero-shot classification with MedSigLIP.
224
+
225
+ Args:
226
+ image: Conjunctiva image (path or PIL Image)
227
+
228
+ Returns:
229
+ Dictionary containing:
230
+ - is_anemic: Boolean indicating anemia detection
231
+ - confidence: Confidence score (0-1)
232
+ - anemia_score: Raw anemia probability
233
+ - healthy_score: Raw healthy probability
234
+ - risk_level: "high", "medium", or "low"
235
+ - recommendation: Clinical recommendation
236
+ """
237
+ # Preprocess image
238
+ pil_image = self.preprocess_image(image)
239
+
240
+ # Get image embedding using SigLIP
241
+ with torch.no_grad():
242
+ inputs = self.processor(
243
+ images=pil_image,
244
+ return_tensors="pt",
245
+ ).to(self.device)
246
+
247
+ # Get image embeddings - support multiple output APIs
248
+ if hasattr(self.model, 'get_image_features'):
249
+ image_embedding = self.model.get_image_features(**inputs)
250
+ else:
251
+ outputs = self.model(**inputs)
252
+ if hasattr(outputs, 'image_embeds'):
253
+ image_embedding = outputs.image_embeds
254
+ elif hasattr(outputs, 'vision_model_output'):
255
+ image_embedding = outputs.vision_model_output.pooler_output
256
+ else:
257
+ vision_outputs = self.model.vision_model(**inputs)
258
+ image_embedding = vision_outputs.pooler_output
259
+
260
+ image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
261
+
262
+ # Use trained classifier if available, otherwise zero-shot
263
+ if self.classifier is not None:
264
+ anemia_prob, healthy_prob, model_method = self._classify_with_trained_model(image_embedding)
265
+ else:
266
+ anemia_prob, healthy_prob, model_method = self._classify_zero_shot(image_embedding)
267
+
268
+ # Determine risk level
269
+ if anemia_prob > 0.7:
270
+ risk_level = "high"
271
+ recommendation = "URGENT: Refer for blood test immediately. High likelihood of anemia."
272
+ elif anemia_prob > 0.5:
273
+ risk_level = "medium"
274
+ recommendation = "Schedule blood test within 48 hours. Moderate anemia indicators present."
275
+ else:
276
+ risk_level = "low"
277
+ recommendation = "No immediate concern. Routine follow-up recommended."
278
+
279
+ is_medsiglip = "medsiglip" in self.model_name.lower()
280
+ base_model = "MedSigLIP (HAI-DEF)" if is_medsiglip else "SigLIP (fallback)"
281
+
282
+ return {
283
+ "is_anemic": anemia_prob > self.threshold,
284
+ "confidence": max(anemia_prob, healthy_prob),
285
+ "anemia_score": anemia_prob,
286
+ "healthy_score": healthy_prob,
287
+ "risk_level": risk_level,
288
+ "recommendation": recommendation,
289
+ "model": self.model_name,
290
+ "model_type": f"{base_model} + {model_method}",
291
+ }
292
+
293
+ def _classify_with_trained_model(self, image_embedding: torch.Tensor) -> Tuple[float, float, str]:
294
+ """
295
+ Classify using trained classifier on embeddings.
296
+
297
+ Args:
298
+ image_embedding: Normalized image embedding from MedSigLIP
299
+
300
+ Returns:
301
+ Tuple of (anemia_prob, healthy_prob, method_name)
302
+ """
303
+ # Convert embedding to numpy for sklearn classifiers
304
+ embedding_np = image_embedding.cpu().numpy().reshape(1, -1)
305
+
306
+ # Handle different classifier types
307
+ if hasattr(self.classifier, 'predict_proba'):
308
+ # Sklearn classifier with probability support
309
+ proba = self.classifier.predict_proba(embedding_np)
310
+ # Assume binary: [healthy, anemic] or [anemic, healthy]
311
+ if proba.shape[1] >= 2:
312
+ # Check classifier classes to determine order
313
+ if hasattr(self.classifier, 'classes_'):
314
+ classes = list(self.classifier.classes_)
315
+ if 1 in classes:
316
+ anemia_idx = classes.index(1)
317
+ else:
318
+ anemia_idx = 1 # Default assumption
319
+ else:
320
+ anemia_idx = 1
321
+ anemia_prob = float(proba[0, anemia_idx])
322
+ healthy_prob = 1.0 - anemia_prob
323
+ else:
324
+ anemia_prob = float(proba[0, 0])
325
+ healthy_prob = 1.0 - anemia_prob
326
+ return anemia_prob, healthy_prob, "Trained Classifier"
327
+
328
+ elif hasattr(self.classifier, 'predict'):
329
+ # Classifier without probability - use binary prediction
330
+ prediction = self.classifier.predict(embedding_np)
331
+ anemia_prob = float(prediction[0])
332
+ healthy_prob = 1.0 - anemia_prob
333
+ return anemia_prob, healthy_prob, "Trained Classifier (binary)"
334
+
335
+ elif isinstance(self.classifier, nn.Module):
336
+ # PyTorch classifier
337
+ self.classifier.eval()
338
+ with torch.no_grad():
339
+ logits = self.classifier(image_embedding)
340
+ probs = torch.softmax(logits, dim=-1)
341
+ if probs.shape[-1] >= 2:
342
+ anemia_prob = probs[0, 1].item()
343
+ healthy_prob = probs[0, 0].item()
344
+ else:
345
+ anemia_prob = probs[0, 0].item()
346
+ healthy_prob = 1.0 - anemia_prob
347
+ return anemia_prob, healthy_prob, "Trained Classifier (PyTorch)"
348
+
349
+ else:
350
+ # Unknown classifier type - fall back to zero-shot
351
+ print(f"Warning: Unknown classifier type {type(self.classifier)}, using zero-shot")
352
+ return self._classify_zero_shot(image_embedding)
353
+
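The `predict_proba` branch above locates the anemia column via `classes_` rather than assuming column order. A minimal numpy sketch of that indexing (the stub classifier is hypothetical, standing in for a fitted sklearn model where label 1 = anemic):

```python
import numpy as np

class StubClassifier:
    # Hypothetical stand-in for a fitted sklearn binary classifier.
    classes_ = np.array([0, 1])

    def predict_proba(self, X):
        # Fixed probabilities for illustration: [healthy, anemic].
        return np.tile([0.2, 0.8], (len(X), 1))

clf = StubClassifier()
classes = list(clf.classes_)
anemia_idx = classes.index(1) if 1 in classes else 1
proba = clf.predict_proba(np.zeros((1, 768)))
anemia_prob = float(proba[0, anemia_idx])   # 0.8
healthy_prob = 1.0 - anemia_prob
```

Looking up the column by label guards against classifiers fitted with labels in a different order.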
354
+ def _classify_zero_shot(self, image_embedding: torch.Tensor) -> Tuple[float, float, str]:
355
+ """
356
+ Classify using zero-shot with max-similarity scoring.
357
+
358
+ Uses the maximum cosine similarity across all prompts per class
359
+ rather than mean-pooled embeddings, which provides better
360
+ discrimination for medical image classification.
361
+
362
+ Args:
363
+ image_embedding: Normalized image embedding from MedSigLIP
364
+
365
+ Returns:
366
+ Tuple of (anemia_prob, healthy_prob, method_name)
367
+ """
368
+ # Max-similarity: take the best-matching prompt per class
369
+ anemia_sims = (image_embedding @ self.anemic_embeddings_all.T).squeeze(0)
370
+ healthy_sims = (image_embedding @ self.healthy_embeddings_all.T).squeeze(0)
371
+
372
+ # Ensure at least 1-D for .max() to work on single-image inputs
373
+ if anemia_sims.dim() == 0:
374
+ anemia_sims = anemia_sims.unsqueeze(0)
375
+ if healthy_sims.dim() == 0:
376
+ healthy_sims = healthy_sims.unsqueeze(0)
377
+
378
+ anemia_sim = anemia_sims.max().item()
379
+ healthy_sim = healthy_sims.max().item()
380
+
381
+ # Convert to probabilities with tuned temperature
382
+ logits = torch.tensor([anemia_sim, healthy_sim], device="cpu") * self.LOGIT_SCALE
383
+ probs = torch.softmax(logits, dim=0)
384
+ anemia_prob = probs[0].item()
385
+ healthy_prob = probs[1].item()
386
+
387
+ return anemia_prob, healthy_prob, "Zero-Shot"
388
+
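A dependency-free sketch of the max-similarity scoring used above (the prompt similarities here are made-up numbers):

```python
import math

LOGIT_SCALE = 30.0

def zero_shot_probs(anemia_sims, healthy_sims, scale=LOGIT_SCALE):
    # Best-matching prompt per class, then a temperature-scaled
    # two-way softmax (computed stably by subtracting the max logit).
    a = max(anemia_sims) * scale
    h = max(healthy_sims) * scale
    m = max(a, h)
    ea, eh = math.exp(a - m), math.exp(h - m)
    return ea / (ea + eh), eh / (ea + eh)

p_anemia, p_healthy = zero_shot_probs([0.31, 0.28], [0.25, 0.22])
# logit gap = (0.31 - 0.25) * 30 = 1.8, so p_anemia ≈ 0.86
```

The scale matters because raw cosine similarities cluster in a narrow band; multiplying by 30 before the softmax spreads them into usable probabilities.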
389
+ def detect_batch(
390
+ self,
391
+ images: List[Union[str, Path, Image.Image]],
392
+ batch_size: int = 8,
393
+ ) -> List[Dict]:
394
+ """
395
+ Detect anemia from multiple images.
396
+
397
+ Args:
398
+ images: List of conjunctiva images
399
+ batch_size: Batch size for processing
400
+
401
+ Returns:
402
+ List of detection results
403
+ """
404
+ results = []
405
+
406
+ for i in range(0, len(images), batch_size):
407
+ batch = images[i:i + batch_size]
408
+
409
+ # Process batch
410
+ pil_images = [self.preprocess_image(img) for img in batch]
411
+
412
+ with torch.no_grad():
413
+ inputs = self.processor(
414
+ images=pil_images,
415
+ return_tensors="pt",
416
+ padding=True,
417
+ ).to(self.device)
418
+
419
+ # Get image embeddings - support multiple output APIs
420
+ if hasattr(self.model, 'get_image_features'):
421
+ image_embeddings = self.model.get_image_features(**inputs)
422
+ else:
423
+ outputs = self.model(**inputs)
424
+ if hasattr(outputs, 'image_embeds'):
425
+ image_embeddings = outputs.image_embeds
426
+ elif hasattr(outputs, 'vision_model_output'):
427
+ image_embeddings = outputs.vision_model_output.pooler_output
428
+ else:
429
+ vision_outputs = self.model.vision_model(**inputs)
430
+ image_embeddings = vision_outputs.pooler_output
431
+
432
+ image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
433
+
434
+ # Compute max-similarities for each image
435
+ for j, img_emb in enumerate(image_embeddings):
436
+ img_emb = img_emb.unsqueeze(0)
437
+
438
+ # Use trained classifier if available, otherwise zero-shot
439
+ if self.classifier is not None:
440
+ anemia_prob, healthy_prob, _ = self._classify_with_trained_model(img_emb)
441
+ # Skip zero-shot path below
442
+ if anemia_prob > 0.7:
443
+ risk_level = "high"
444
+ recommendation = "URGENT: Refer for blood test immediately."
445
+ elif anemia_prob > 0.5:
446
+ risk_level = "medium"
447
+ recommendation = "Schedule blood test within 48 hours."
448
+ else:
449
+ risk_level = "low"
450
+ recommendation = "No immediate concern."
451
+
452
+ results.append({
453
+ "is_anemic": anemia_prob > self.threshold,
454
+ "confidence": max(anemia_prob, healthy_prob),
455
+ "anemia_score": anemia_prob,
456
+ "healthy_score": healthy_prob,
457
+ "risk_level": risk_level,
458
+ "recommendation": recommendation,
459
+ })
460
+ continue
461
+
462
+ anemia_sims = (img_emb @ self.anemic_embeddings_all.T).squeeze(0)
463
+ healthy_sims = (img_emb @ self.healthy_embeddings_all.T).squeeze(0)
464
+
465
+ if anemia_sims.dim() == 0:
466
+ anemia_sims = anemia_sims.unsqueeze(0)
467
+ if healthy_sims.dim() == 0:
468
+ healthy_sims = healthy_sims.unsqueeze(0)
469
+
470
+ anemia_sim = anemia_sims.max().item()
471
+ healthy_sim = healthy_sims.max().item()
472
+
473
+ logits = torch.tensor([anemia_sim, healthy_sim], device="cpu") * self.LOGIT_SCALE
474
+ probs = torch.softmax(logits, dim=0)
475
+ anemia_prob = probs[0].item()
476
+ healthy_prob = probs[1].item()
477
+
478
+ if anemia_prob > 0.7:
479
+ risk_level = "high"
480
+ recommendation = "URGENT: Refer for blood test immediately."
481
+ elif anemia_prob > 0.5:
482
+ risk_level = "medium"
483
+ recommendation = "Schedule blood test within 48 hours."
484
+ else:
485
+ risk_level = "low"
486
+ recommendation = "No immediate concern."
487
+
488
+ results.append({
489
+ "is_anemic": anemia_prob > self.threshold,
490
+ "confidence": max(anemia_prob, healthy_prob),
491
+ "anemia_score": anemia_prob,
492
+ "healthy_score": healthy_prob,
493
+ "risk_level": risk_level,
494
+ "recommendation": recommendation,
495
+ })
496
+
497
+ return results
498
+
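The slicing loop in `detect_batch` follows the standard fixed-size chunking pattern, which can be isolated as:

```python
def batches(items, batch_size=8):
    # Yield consecutive slices of at most batch_size items;
    # the final batch may be shorter.
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

chunks = [b for b in batches(list(range(5)), batch_size=2)]
# [[0, 1], [2, 3], [4]]
```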
499
+ def analyze_color_features(self, image: Union[str, Path, Image.Image]) -> Dict:
500
+ """
501
+ Analyze color features of conjunctiva image.
502
+
503
+ This provides interpretable features based on medical literature
+ correlating conjunctival pallor with anemia.
505
+
506
+ Args:
507
+ image: Conjunctiva image
508
+
509
+ Returns:
510
+ Dictionary with color analysis results
511
+ """
512
+ pil_image = self.preprocess_image(image)
513
+ img_array = np.array(pil_image)
514
+
515
+ # Extract RGB channels
516
+ r_channel = img_array[:, :, 0].astype(float)
517
+ g_channel = img_array[:, :, 1].astype(float)
518
+ b_channel = img_array[:, :, 2].astype(float)
519
+
520
+ # Calculate color statistics
521
+ mean_r = np.mean(r_channel)
522
+ mean_g = np.mean(g_channel)
523
+ mean_b = np.mean(b_channel)
524
+
525
+ # Red ratio (higher in healthy, lower in anemic)
526
+ total_intensity = mean_r + mean_g + mean_b
527
+ red_ratio = mean_r / total_intensity if total_intensity > 0 else 0
528
+
529
+ # Pallor index (higher means more pale/anemic)
530
+ # Based on reduced red-to-green ratio in anemic conjunctiva
531
+ pallor_index = 1 - (mean_r / (mean_g + 1e-6))
532
+ pallor_index = max(0, min(1, (pallor_index + 0.5) / 1.5))
533
+
534
+ # Hemoglobin estimation (rough approximation)
535
+ # Normal Hb: 12-16 g/dL for women, 14-18 for men
536
+ # This is a rough estimate based on color analysis
537
+ estimated_hb = 8 + (red_ratio * 12)
538
+
539
+ return {
540
+ "mean_red": mean_r,
541
+ "mean_green": mean_g,
542
+ "mean_blue": mean_b,
543
+ "red_ratio": red_ratio,
544
+ "pallor_index": pallor_index,
545
+ "estimated_hemoglobin": round(estimated_hb, 1),
546
+ "interpretation": "Low hemoglobin" if pallor_index > 0.5 else "Normal hemoglobin",
547
+ }
548
+
549
+
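The color statistics above can be checked on a tiny synthetic patch (pixel values are arbitrary):

```python
import numpy as np

img = np.zeros((4, 4, 3), dtype=np.uint8)
img[..., 0], img[..., 1], img[..., 2] = 180, 90, 80  # reddish patch

mean_r, mean_g, mean_b = (img[..., c].astype(float).mean() for c in range(3))
red_ratio = mean_r / (mean_r + mean_g + mean_b)                # ≈ 0.514
pallor_index = 1 - (mean_r / (mean_g + 1e-6))
pallor_index = max(0.0, min(1.0, (pallor_index + 0.5) / 1.5))  # clamps to 0
estimated_hb = 8 + red_ratio * 12                              # ≈ 14.2
```

On this strongly red patch the pallor index clamps to 0 and the rough hemoglobin estimate lands in the normal range, matching the "Normal hemoglobin" interpretation branch.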
550
+ def test_detector():
551
+ """Test the anemia detector with sample images."""
552
+ print("Testing Anemia Detector...")
553
+
554
+ detector = AnemiaDetector()
555
+
556
+ # Test with sample images from dataset
557
+ data_dir = Path(__file__).parent.parent.parent / "data" / "raw" / "eyes-defy-anemia"
558
+
559
+ if data_dir.exists():
560
+ # Find sample images
561
+ sample_images = list(data_dir.rglob("*.jpg"))[:3]
562
+
563
+ for img_path in sample_images:
564
+ print(f"\nAnalyzing: {img_path.name}")
565
+ result = detector.detect(img_path)
566
+ print(f" Anemia detected: {result['is_anemic']}")
567
+ print(f" Confidence: {result['confidence']:.2%}")
568
+ print(f" Risk level: {result['risk_level']}")
569
+ print(f" Recommendation: {result['recommendation']}")
570
+
571
+ # Color analysis
572
+ color_info = detector.analyze_color_features(img_path)
573
+ print(f" Estimated Hb: {color_info['estimated_hemoglobin']} g/dL")
574
+ else:
575
+ print(f"Dataset not found at {data_dir}")
576
+ print("Please run download_datasets.py first")
577
+
578
+
579
+ if __name__ == "__main__":
580
+ test_detector()
src/nexus/clinical_synthesizer.py ADDED
@@ -0,0 +1,548 @@
1
+ """
2
+ Clinical Synthesizer Module
3
+
4
+ Uses MedGemma from Google HAI-DEF for clinical reasoning and synthesis.
5
+ Combines findings from MedSigLIP (images) and HeAR (audio) into actionable recommendations.
6
+
7
+ HAI-DEF Model: MedGemma 4B (google/medgemma-4b-it or google/medgemma-1.5-4b-it)
8
+ Supports 4-bit quantization via BitsAndBytes for low-VRAM deployment.
9
+ """
10
+
11
+ import torch
12
+ from typing import Dict, Optional, List
13
+ from datetime import datetime
14
+
15
+ try:
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM
17
+ HAS_TRANSFORMERS = True
18
+ except ImportError:
19
+ HAS_TRANSFORMERS = False
20
+
21
+ try:
22
+ from transformers import BitsAndBytesConfig
23
+ HAS_BITSANDBYTES = True
24
+ except ImportError:
25
+ HAS_BITSANDBYTES = False
26
+
27
+
28
+ class ClinicalSynthesizer:
29
+ """
30
+ Synthesizes clinical findings using MedGemma.
31
+
32
+ HAI-DEF Model: MedGemma 4B (google/medgemma-4b-it or google/medgemma-1.5-4b-it)
33
+ Method: Prompt engineering (no fine-tuning required)
34
+ Quantization: 4-bit NF4 via BitsAndBytes for low-VRAM deployment
35
+
36
+ Output:
37
+ - Integrated diagnosis suggestions
38
+ - Severity assessment (GREEN/YELLOW/RED)
39
+ - Treatment recommendations (WHO IMNCI)
40
+ - Referral decision with urgency
41
+ - CHW-friendly explanations
42
+ """
43
+
44
+ # WHO IMNCI severity colors
45
+ SEVERITY_LEVELS = {
46
+ "GREEN": "Routine care - no immediate concern",
47
+ "YELLOW": "Close monitoring - may need referral",
48
+ "RED": "Urgent referral - immediate action required",
49
+ }
50
+
51
+ # MedGemma model candidates in preference order
52
+ MEDGEMMA_MODEL_IDS = [
53
+ "google/medgemma-1.5-4b-it", # Newer, better performance
54
+ "google/medgemma-4b-it", # Original HAI-DEF model
55
+ ]
56
+
57
+ def __init__(
58
+ self,
59
+ model_name: Optional[str] = None,
60
+ device: Optional[str] = None,
61
+ use_medgemma: bool = True,
62
+ use_4bit: bool = True,
63
+ ):
64
+ """
65
+ Initialize the Clinical Synthesizer with MedGemma.
66
+
67
+ Args:
68
+ model_name: HuggingFace model name for MedGemma (auto-selects if None)
69
+ device: Device to run model on
70
+ use_medgemma: Whether to use MedGemma (True) or rule-based (False)
71
+ use_4bit: Whether to use 4-bit quantization (reduces VRAM from ~8GB to ~2GB)
72
+ """
73
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
74
+ self._user_model_name = model_name # None if user didn't specify
75
+ self.model_name = model_name or self.MEDGEMMA_MODEL_IDS[0]
76
+ self.model = None
77
+ self.tokenizer = None
78
+ self.use_medgemma = use_medgemma
79
+ self.use_4bit = use_4bit
80
+ self._medgemma_available = False
81
+
82
+ if use_medgemma and HAS_TRANSFORMERS:
83
+ self._load_medgemma()
84
+ else:
85
+ print("MedGemma not available. Using rule-based clinical synthesis.")
86
+ self.use_medgemma = False
87
+
88
+ print("Clinical Synthesizer (HAI-DEF MedGemma) initialized")
89
+
90
+ def _load_medgemma(self) -> None:
91
+ """Load MedGemma model from HuggingFace with 4-bit quantization.
92
+
93
+ Tries model candidates in preference order:
94
+ 1. google/medgemma-1.5-4b-it (newer, better performance)
95
+ 2. google/medgemma-4b-it (original HAI-DEF model)
96
+
97
+ Uses BitsAndBytes NF4 quantization to reduce VRAM from ~8GB to ~2GB,
98
+ which fixes CUDA OOM errors on consumer GPUs.
99
+ """
100
+ import os
101
+ hf_token = os.environ.get("HF_TOKEN")
102
+ if not hf_token:
103
+ print("Warning: HF_TOKEN not set. MedGemma is a gated model and requires authentication.")
104
+ print("Set HF_TOKEN environment variable with your HuggingFace token.")
105
+
106
+ # Determine models to try — if user explicitly passed a model_name,
107
+ # only try that one; otherwise try all candidates in preference order.
108
+ models_to_try = [self._user_model_name] if self._user_model_name else self.MEDGEMMA_MODEL_IDS
109
+
110
+ # Build quantization config for 4-bit loading
111
+ bnb_config = None
112
+ if self.use_4bit and self.device == "cuda" and HAS_BITSANDBYTES:
113
+ try:
114
+ bnb_config = BitsAndBytesConfig(
115
+ load_in_4bit=True,
116
+ bnb_4bit_quant_type="nf4",
117
+ bnb_4bit_use_double_quant=True,
118
+ bnb_4bit_compute_dtype=torch.float16,
119
+ )
120
+ print("4-bit quantization enabled (NF4 + double quant)")
121
+ except Exception as e:
122
+ print(f"Warning: Could not create BitsAndBytes config: {e}")
123
+ bnb_config = None
124
+
125
+ for candidate_model in models_to_try:
126
+ try:
127
+ print(f"Loading MedGemma model: {candidate_model}")
128
+ self.tokenizer = AutoTokenizer.from_pretrained(
129
+ candidate_model, token=hf_token
130
+ )
131
+
132
+ load_kwargs = {
133
+ "token": hf_token,
134
+ "device_map": "auto" if self.device == "cuda" else None,
135
+ }
136
+
137
+ if bnb_config is not None:
138
+ # 4-bit quantized loading (~2GB VRAM)
139
+ load_kwargs["quantization_config"] = bnb_config
140
+ else:
141
+ # Standard loading with fp16/fp32
142
+ load_kwargs["torch_dtype"] = (
143
+ torch.float16 if self.device == "cuda" else torch.float32
144
+ )
145
+
146
+ self.model = AutoModelForCausalLM.from_pretrained(
147
+ candidate_model, **load_kwargs
148
+ )
149
+
150
+ if self.device == "cpu" and bnb_config is None:
151
+ self.model = self.model.to(self.device)
152
+
153
+ self.model_name = candidate_model
154
+ self._medgemma_available = True
155
+ quant_status = "4-bit NF4" if bnb_config is not None else "fp16/fp32"
156
+ print(f"MedGemma loaded successfully: {candidate_model} ({quant_status})")
157
+ return
158
+
159
+ except Exception as e:
160
+ print(f"Warning: Could not load {candidate_model}: {e}")
161
+ continue
162
+
163
+ print("Could not load any MedGemma model. Falling back to rule-based synthesis.")
164
+ self.model = None
165
+ self.tokenizer = None
166
+ self.use_medgemma = False
167
+ self._medgemma_available = False
168
+
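The candidate loop in `_load_medgemma` is a first-success fallback over a preference-ordered list. The general shape, with an illustrative loader and made-up model ids:

```python
def load_first_available(candidates, loader):
    # Try each model id in preference order; return the first that loads,
    # or (None, None) if every candidate fails.
    for name in candidates:
        try:
            return name, loader(name)
        except Exception:
            continue
    return None, None

def fake_loader(name):
    # Illustrative: simulate a gated checkpoint rejecting the request.
    if "gated" in name:
        raise RuntimeError("401: model is gated")
    return object()

chosen, model = load_first_available(
    ["gated/model-a", "open/model-b"], fake_loader
)
# chosen == "open/model-b"
```

Catching per-candidate exceptions lets the synthesizer degrade to rule-based output instead of crashing when a gated model is unreachable.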
169
+ @staticmethod
170
+ def _sanitize(value: object) -> str:
171
+ """Sanitize a value for safe inclusion in a prompt.
172
+
173
+ Strips control characters and truncates excessively long strings to
174
+ prevent prompt injection via adversarial findings.
175
+ """
176
+ text = str(value) if value is not None else "N/A"
177
+ # Remove characters that could break prompt structure
178
+ text = text.replace("\x00", "").replace("\r", "")
179
+ # Truncate overly long values
180
+ if len(text) > 500:
181
+ text = text[:500] + "..."
182
+ return text
183
+
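The sanitization rules above, exercised directly:

```python
def sanitize(value, max_len=500):
    # Mirrors _sanitize: stringify, strip NUL and CR, truncate long values.
    text = str(value) if value is not None else "N/A"
    text = text.replace("\x00", "").replace("\r", "")
    if len(text) > max_len:
        text = text[:max_len] + "..."
    return text

assert sanitize(None) == "N/A"
assert sanitize("line1\r\nline2\x00") == "line1\nline2"
assert len(sanitize("x" * 600)) == 503  # 500 chars + "..."
```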
184
+ def _build_prompt(self, findings: Dict) -> str:
185
+ """
186
+ Build clinical synthesis prompt for MedGemma.
187
+
188
+ Args:
189
+ findings: Dictionary with anemia, jaundice, cry analysis results.
190
+ May include 'agent_context' and 'agent_reasoning_summary'
191
+ when called from the agentic workflow engine.
192
+
193
+ Returns:
194
+ Formatted prompt for MedGemma
195
+ """
196
+ # Extract findings with safe defaults
197
+ anemia = findings.get("anemia", {})
198
+ jaundice = findings.get("jaundice", {})
199
+ cry = findings.get("cry", {})
200
+ symptoms = self._sanitize(findings.get("symptoms", "None reported"))
201
+ patient_info = findings.get("patient_info", {})
202
+ agent_context = findings.get("agent_context", {})
203
+ agent_reasoning = self._sanitize(findings.get("agent_reasoning_summary", ""))
204
+
205
+ prompt = f"""You are a pediatric health assistant helping community health workers in low-resource settings.
206
+
207
+ PATIENT INFORMATION:
208
+ - Age: {patient_info.get("age", "Not specified")}
209
+ - Weight: {patient_info.get("weight", "Not specified")}
210
+ - Location: {patient_info.get("location", "Rural health post")}
211
+ - Patient Type: {patient_info.get("type", "Not specified")}
212
+
213
+ ASSESSMENT FINDINGS:
214
+
215
+ 1. ANEMIA SCREENING (Conjunctiva Analysis):
216
+ - Result: {"Anemia detected" if anemia.get("is_anemic") else "No anemia detected"}
217
+ - Confidence: {anemia.get("confidence", "N/A")}
218
+ - Severity: {anemia.get("severity", anemia.get("risk_level", "N/A"))}
219
+ - Estimated Hemoglobin: {anemia.get("estimated_hemoglobin", "N/A")} g/dL
220
+
221
+ 2. JAUNDICE SCREENING (Skin Analysis):
222
+ - Result: {"Jaundice detected" if jaundice.get("has_jaundice") else "No jaundice detected"}
223
+ - Confidence: {jaundice.get("confidence", "N/A")}
224
+ - Severity: {jaundice.get("severity", "N/A")}
225
+ - Estimated Bilirubin: {jaundice.get("estimated_bilirubin", "N/A")} mg/dL
226
+ - Needs Phototherapy: {jaundice.get("needs_phototherapy", "N/A")}
227
+
228
+ 3. CRY ANALYSIS (Audio):
229
+ - Result: {"Abnormal cry pattern" if cry.get("is_abnormal") else "Normal cry pattern"}
230
+ - Asphyxia Risk: {cry.get("asphyxia_risk", "N/A")}
231
+ - Cry Type: {cry.get("cry_type", "N/A")}
232
+
233
+ 4. REPORTED SYMPTOMS:
234
+ {symptoms}
235
+ """
236
+
237
+ # Add agentic workflow context if available
238
+ if agent_context:
239
+ prompt += f"""
240
+ 5. MULTI-AGENT ASSESSMENT CONTEXT:
241
+ - Triage Score: {agent_context.get("triage_score", "N/A")} (Risk: {agent_context.get("triage_risk", "N/A")})
242
+ - Critical Danger Signs: {", ".join(agent_context.get("critical_signs", [])) or "None"}
243
+ - WHO IMNCI Classification: {agent_context.get("protocol_classification", "N/A")}
244
+ - Applicable Protocols: {", ".join(agent_context.get("applicable_protocols", [])) or "N/A"}
245
+ - Referral Decision: {"YES" if agent_context.get("referral_needed") else "NO"} (Urgency: {agent_context.get("referral_urgency", "N/A")})
246
+ """
247
+
248
+ if agent_reasoning:
249
+ prompt += f"""
250
+ 6. AGENT REASONING TRAIL:
251
+ {agent_reasoning}
252
+ """
253
+
254
+ prompt += """
255
+ Based on these findings, provide a clinical assessment following WHO IMNCI protocols:
256
+
257
+ 1. ASSESSMENT SUMMARY (2-3 sentences in simple language)
258
+ 2. SEVERITY LEVEL (GREEN = routine care, YELLOW = close monitoring, RED = urgent referral)
259
+ 3. IMMEDIATE ACTIONS for the CHW (bullet points, simple steps)
260
+ 4. REFERRAL RECOMMENDATION (Yes/No, and if yes, urgency level)
261
+ 5. FOLLOW-UP PLAN (when to reassess)
262
+
263
+ Use simple language appropriate for a community health worker with basic training.
264
+ Focus on actionable steps they can take immediately.
265
+ """
266
+ return prompt
267
+
268
+ def synthesize(self, findings: Dict) -> Dict:
269
+ """
270
+ Synthesize all findings into clinical recommendations.
271
+
272
+ Args:
273
+ findings: Dictionary with anemia, jaundice, cry analysis results
274
+
275
+ Returns:
276
+ Clinical summary and recommendations
277
+ """
278
+ if self.use_medgemma and self.model is not None:
279
+ return self._synthesize_with_medgemma(findings)
280
+ else:
281
+ return self._synthesize_rule_based(findings)
282
+
283
+ def _synthesize_with_medgemma(self, findings: Dict) -> Dict:
284
+ """Synthesize using MedGemma model.
285
+
286
+ Falls back to rule-based synthesis if generation fails (e.g. CUDA OOM,
287
+ device-side assertion, or any other runtime error).
288
+ """
289
+ try:
290
+ prompt = self._build_prompt(findings)
291
+
292
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
293
+ # For models loaded with device_map="auto", route inputs to the
294
+ # embedding layer's device to avoid CPU/CUDA mismatch.
295
+ try:
296
+ input_device = self.model.get_input_embeddings().weight.device
297
+ except Exception:
298
+ input_device = self.device
299
+ inputs = {k: v.to(input_device) for k, v in inputs.items()}
300
+
301
+ prompt_len = inputs["input_ids"].shape[-1]
302
+
303
+ with torch.no_grad():
304
+ outputs = self.model.generate(
305
+ **inputs,
306
+ max_new_tokens=500,
307
+ temperature=0.7,
308
+ do_sample=True,
309
+ top_p=0.9,
310
+ )
311
+
312
+ # Extract only the generated tokens (after the prompt)
313
+ generated_ids = outputs[0][prompt_len:]
314
+ response = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
315
+
316
+ # Guard against empty or very short responses
317
+ if len(response) < 20:
318
+ return self._synthesize_rule_based(findings)
319
+
320
+ # Determine display name for the model
321
+ if "1.5" in self.model_name:
322
+ display_name = "MedGemma 1.5 4B"
323
+ else:
324
+ display_name = "MedGemma 4B"
325
+
326
+ return {
327
+ "summary": response,
328
+ "model": display_name,
329
+ "model_id": self.model_name,
330
+ "generated_at": datetime.now().isoformat(),
331
+ "findings_used": list(findings.keys()),
332
+ }
333
+ except Exception as e:
334
+ print(f"MedGemma generation failed: {e}. Falling back to rule-based synthesis.")
335
+ # Disable MedGemma to avoid repeated CUDA errors that corrupt the
336
+ # device context and break subsequent GPU operations.
337
+ self.use_medgemma = False
338
+ self._medgemma_available = False
339
+ self.model = None
340
+ try:
341
+ torch.cuda.empty_cache()
342
+ except Exception:
343
+ pass
344
+ return self._synthesize_rule_based(findings)
345
+
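The prompt-stripping step above slices off the first `prompt_len` tokens before decoding. The same indexing in plain Python (tensor slicing behaves identically; the token ids are made up):

```python
output_ids = [101, 2023, 2003, 1037, 3231, 102]  # prompt + generated ids
prompt_len = 4
generated_ids = output_ids[prompt_len:]
# Only the newly generated tokens remain for decoding.
```

Without this slice, the decoded response would echo the entire clinical prompt back to the CHW.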
346
+ def _synthesize_rule_based(self, findings: Dict) -> Dict:
347
+ """
348
+ Rule-based clinical synthesis (fallback when MedGemma unavailable).
349
+
350
+ Follows WHO IMNCI protocols for maternal and neonatal care.
351
+ """
352
+ # Extract findings
353
+ anemia = findings.get("anemia", {})
354
+ jaundice = findings.get("jaundice", {})
355
+ cry = findings.get("cry", {})
356
+
357
+ # Determine overall severity
358
+ severity_score = 0
359
+ urgent_conditions = []
360
+ actions = []
361
+ referral_needed = False
362
+ referral_urgency = "none"
363
+
364
+ # Assess anemia
365
+ if anemia.get("is_anemic"):
366
+ if anemia.get("risk_level") == "high":
367
+ severity_score += 3
368
+ urgent_conditions.append("Severe anemia")
369
+ actions.append("Refer for blood transfusion if Hb < 7 g/dL")
370
+ referral_needed = True
371
+ referral_urgency = "urgent"
372
+ elif anemia.get("risk_level") == "medium":
373
+ severity_score += 2
374
+ urgent_conditions.append("Moderate anemia")
375
+ actions.append("Start iron supplementation")
376
+ actions.append("Schedule blood test within 48 hours")
377
+ else:
378
+ severity_score += 1
379
+ actions.append("Monitor hemoglobin levels")
380
+ actions.append("Encourage iron-rich foods")
381
+
382
+ # Assess jaundice
383
+ if jaundice.get("has_jaundice"):
384
+ if jaundice.get("needs_phototherapy"):
385
+ severity_score += 3
386
+ urgent_conditions.append("Severe jaundice requiring phototherapy")
387
+ actions.append("URGENT: Start phototherapy immediately")
388
+ actions.append("Refer to hospital if phototherapy unavailable")
389
+ referral_needed = True
390
+ referral_urgency = "immediate"
391
+ elif jaundice.get("severity") in ["moderate", "severe"]:
392
+ severity_score += 2
393
+ urgent_conditions.append("Moderate jaundice")
394
+ actions.append("Expose baby to indirect sunlight")
395
+ actions.append("Ensure frequent breastfeeding")
396
+ actions.append("Recheck in 12-24 hours")
397
+ else:
398
+ severity_score += 1
399
+ actions.append("Continue breastfeeding")
400
+ actions.append("Monitor skin color")
401
+
402
+ # Assess cry analysis
403
+ if cry.get("is_abnormal"):
404
+ if cry.get("asphyxia_risk", 0) > 0.6:
405
+ severity_score += 3
406
+ urgent_conditions.append("Signs of birth asphyxia")
407
+ actions.append("URGENT: Check airway, breathing, circulation")
408
+ actions.append("Provide warmth and stimulation")
409
+ actions.append("Immediate referral for evaluation")
410
+ referral_needed = True
411
+ referral_urgency = "immediate"
412
+ else:
413
+ severity_score += 1
414
+ actions.append("Monitor cry patterns")
415
+ actions.append("Assess feeding and alertness")
416
+
417
+ # Determine overall severity level
418
+ if severity_score >= 5 or referral_urgency == "immediate":
419
+ severity_level = "RED"
420
+ summary = f"URGENT ATTENTION NEEDED. {', '.join(urgent_conditions)}. Immediate medical intervention required."
421
+ elif severity_score >= 2:
422
+ severity_level = "YELLOW"
423
+ summary = f"Close monitoring required. {', '.join(urgent_conditions) if urgent_conditions else 'Some abnormal findings detected'}. Follow recommended actions."
424
+ else:
425
+ severity_level = "GREEN"
426
+ summary = "Routine care. No immediate concerns detected. Continue standard monitoring."
427
+
428
+ # Default actions if none specified
429
+ if not actions:
430
+ actions = [
431
+ "Continue routine care",
432
+ "Ensure adequate nutrition",
433
+ "Schedule follow-up in 1 week",
434
+ ]
435
+
436
+ # Follow-up plan
437
+ if severity_level == "RED":
438
+ follow_up = "Immediate referral. Follow up after hospital evaluation."
439
+ elif severity_level == "YELLOW":
440
+ follow_up = "Reassess in 24-48 hours. Refer if condition worsens."
441
+ else:
442
+ follow_up = "Routine follow-up in 1-2 weeks."
443
+
444
+ return {
445
+ "summary": summary,
446
+ "severity_level": severity_level,
447
+ "severity_description": self.SEVERITY_LEVELS[severity_level],
448
+ "immediate_actions": actions,
449
+ "referral_needed": referral_needed,
450
+ "referral_urgency": referral_urgency,
451
+ "follow_up": follow_up,
452
+ "urgent_conditions": urgent_conditions,
453
+ "model": "Rule-based (WHO IMNCI)",
454
+ "generated_at": datetime.now().isoformat(),
455
+ }
456
+
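The rule-based triage above maps a cumulative severity score to WHO IMNCI colors; condensed to its thresholds:

```python
def severity_color(score, urgency="none"):
    # RED: score >= 5 or any immediate-urgency finding;
    # YELLOW: score >= 2; GREEN otherwise.
    if score >= 5 or urgency == "immediate":
        return "RED"
    if score >= 2:
        return "YELLOW"
    return "GREEN"

assert severity_color(6) == "RED"
assert severity_color(0, urgency="immediate") == "RED"
assert severity_color(3) == "YELLOW"
assert severity_color(1) == "GREEN"
```

Note the urgency override: a single immediate-referral finding (e.g. phototherapy-level jaundice) forces RED regardless of the numeric score.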
457
+ def get_who_protocol(self, condition: str) -> Dict:
458
+ """
459
+ Get WHO IMNCI protocol for a specific condition.
460
+
461
+ Args:
462
+ condition: Condition name (anemia, jaundice, asphyxia)
463
+
464
+ Returns:
465
+ Protocol details
466
+ """
467
+ protocols = {
468
+ "anemia": {
469
+ "name": "Maternal Anemia Management",
470
+ "source": "WHO IMNCI Guidelines",
471
+ "steps": [
472
+ "Assess pallor of conjunctiva, palms, and nail beds",
473
+ "If severe pallor: Urgent referral",
474
+ "If some pallor: Iron supplementation + folic acid",
475
+ "Counsel on iron-rich foods",
476
+ "Follow up in 4 weeks",
477
+ ],
478
+ "referral_criteria": "Hb < 7 g/dL or severe pallor with symptoms",
479
+ },
480
+ "jaundice": {
481
+ "name": "Neonatal Jaundice Management",
482
+ "source": "WHO IMNCI Guidelines",
483
+ "steps": [
484
+ "Check for yellow skin/eyes within first 24 hours",
485
+ "If jaundice in first 24 hours: URGENT referral",
486
+ "If moderate jaundice: Frequent breastfeeding, sun exposure",
487
+ "If bilirubin > 15 mg/dL: Phototherapy",
488
+ "If bilirubin > 25 mg/dL: Exchange transfusion",
489
+ ],
490
+ "referral_criteria": "Jaundice < 24 hours old, bilirubin > 20 mg/dL",
491
+ },
492
+ "asphyxia": {
493
+ "name": "Birth Asphyxia Management",
494
+ "source": "WHO Neonatal Resuscitation Guidelines",
495
+ "steps": [
496
+ "Assess APGAR score at 1 and 5 minutes",
497
+ "Clear airway if needed",
498
+ "Provide warmth and stimulation",
499
+ "If not breathing: Begin resuscitation",
500
+ "Refer for evaluation if abnormal cry or poor feeding",
501
+ ],
502
+ "referral_criteria": "APGAR < 7, abnormal cry, seizures, poor feeding",
503
+ },
504
+ }
505
+ return protocols.get(condition.lower(), {"error": "Protocol not found"})
506
+
507
+
508
+ def test_synthesizer():
509
+ """Test the clinical synthesizer."""
510
+ print("Testing Clinical Synthesizer...")
511
+
512
+ synthesizer = ClinicalSynthesizer(use_medgemma=False) # Use rule-based for testing
513
+
514
+ # Test case: Multiple findings
515
+ findings = {
516
+ "anemia": {
517
+ "is_anemic": True,
518
+ "confidence": 0.85,
519
+ "risk_level": "medium",
520
+ "estimated_hemoglobin": 9.5,
521
+ },
522
+ "jaundice": {
523
+ "has_jaundice": True,
524
+ "confidence": 0.75,
525
+ "severity": "mild",
526
+ "estimated_bilirubin": 8.5,
527
+ "needs_phototherapy": False,
528
+ },
529
+ "cry": {
530
+ "is_abnormal": False,
531
+ "asphyxia_risk": 0.2,
532
+ "cry_type": "hunger",
533
+ },
534
+ "symptoms": "Mother reports baby seems tired after feeding",
535
+ }
536
+
537
+ result = synthesizer.synthesize(findings)
538
+
539
+ print("\n=== Clinical Synthesis Result ===")
540
+ print(f"Summary: {result['summary']}")
541
+ print(f"Severity: {result.get('severity_level', 'N/A')}")
542
+ print(f"Referral Needed: {result.get('referral_needed', 'N/A')}")
543
+ print(f"Actions: {result.get('immediate_actions', [])}")
544
+ print(f"Follow-up: {result.get('follow_up', 'N/A')}")
545
+
546
+
547
+ if __name__ == "__main__":
548
+ test_synthesizer()
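The severity triage in `synthesize()` above reduces to a small score-to-color mapping. A minimal standalone sketch of that rule (the `triage_level` helper here is illustrative only, not part of the module's API; thresholds copied from the diff):

```python
def triage_level(severity_score: int, referral_urgency: str = "none") -> str:
    """Map a cumulative severity score to a WHO-style triage color.

    Mirrors the thresholds in ClinicalSynthesizer.synthesize():
    score >= 5 (or an 'immediate' referral) -> RED,
    score >= 2 -> YELLOW, otherwise GREEN.
    """
    if severity_score >= 5 or referral_urgency == "immediate":
        return "RED"
    if severity_score >= 2:
        return "YELLOW"
    return "GREEN"


print(triage_level(1))                # GREEN
print(triage_level(3))                # YELLOW
print(triage_level(2, "immediate"))   # RED
```

Note that `referral_urgency == "immediate"` short-circuits the score, so a single urgent finding is enough to escalate to RED even when the cumulative score stays low.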
src/nexus/cry_analyzer.py ADDED
@@ -0,0 +1,662 @@
+ """
+ Cry Analyzer Module
+
+ Uses HeAR from Google HAI-DEF for infant cry analysis and birth asphyxia detection.
+ Implements embedding extraction + linear classifier per NEXUS_MASTER_PLAN.md.
+
+ HAI-DEF Model: HeAR (Health Acoustic Representations)
+ Source: https://github.com/Google-Health/google-health/tree/master/health_acoustic_representations
+ """
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+ import warnings
+ import os
+
+ try:
+ import librosa
+ import soundfile as sf
+ HAS_AUDIO = True
+ except ImportError:
+ HAS_AUDIO = False
+
+ try:
+ from sklearn.linear_model import LogisticRegression
+ import joblib
+ HAS_SKLEARN = True
+ except ImportError:
+ HAS_SKLEARN = False
+
+ # HeAR PyTorch via HuggingFace
+ try:
+ from transformers import AutoModel as HearAutoModel
+ HAS_HEAR_PYTORCH = True
+ except ImportError:
+ HAS_HEAR_PYTORCH = False
+
+
+ class CryAnalyzer:
+ """
+ Analyzes infant cry audio for birth asphyxia detection using HeAR.
+
+ HAI-DEF Model: HeAR (google/hear-pytorch)
+ Method: Embedding extraction + acoustic feature analysis
+
+ Process:
+ 1. Split audio into 2-second chunks (HeAR requirement)
+ 2. Extract HeAR embeddings (512-dim per chunk)
+ 3. Aggregate embeddings (mean pooling)
+ 4. Classify with trained linear model or rule-based fallback
+ """
+
+ # HeAR model configuration
+ SAMPLE_RATE = 16000  # Hz - HeAR requires 16kHz
+ CHUNK_DURATION = 2.0  # seconds - HeAR chunk size
+ CHUNK_SIZE = 32000  # samples (2 seconds at 16kHz)
+ EMBEDDING_DIM = 512  # HeAR embedding dimension
+
+ # Acoustic feature thresholds (fallback if HeAR unavailable)
+ NORMAL_F0_RANGE = (250, 450)  # Hz
+ ASPHYXIA_F0_THRESHOLD = 500  # Hz - higher F0 indicates distress
+ MIN_CRY_DURATION = 0.5  # seconds
+
+ # HeAR model ID on HuggingFace (PyTorch)
+ HEAR_MODEL_ID = "google/hear-pytorch"
+
+ # Default classifier path (relative to project root)
+ DEFAULT_CLASSIFIER_PATHS = [
+ Path(__file__).parent.parent.parent / "models" / "linear_probes" / "cry_classifier.joblib",
+ Path("models/linear_probes/cry_classifier.joblib"),
+ ]
+
+ # Cry type labels from trained classifier
+ CRY_TYPE_LABELS = {
+ 0: "belly_pain",
+ 1: "burping",
+ 2: "discomfort",
+ 3: "hungry",
+ 4: "tired",
+ }
+
+ def __init__(
+ self,
+ device: Optional[str] = None,
+ classifier_path: Optional[str] = None,
+ use_hear: bool = True,
+ ):
+ """
+ Initialize the Cry Analyzer with HeAR.
+
+ Args:
+ device: Device to run model on
+ classifier_path: Path to trained linear classifier (optional, auto-detected)
+ use_hear: Whether to use HeAR embeddings (True) or acoustic features (False)
+ """
+ if not HAS_AUDIO:
+ raise ImportError("librosa and soundfile required. Install with: pip install librosa soundfile")
+
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+ self.classifier_path = classifier_path
+ self.classifier = None
+ self.hear_model = None
+ self.use_hear = use_hear
+ self._hear_available = False
+
+ # Try to load HeAR model
+ if use_hear:
+ self._load_hear_model()
+
+ # Load trained classifier: explicit path first, then auto-detect
+ self._load_classifier(classifier_path)
+
+ mode = "HeAR" if self._hear_available else "Acoustic Features (HeAR unavailable)"
+ classifier_status = "with trained classifier" if self.classifier else "heuristic scoring"
+ print(f"Cry Analyzer (HAI-DEF {mode}, {classifier_status}) initialized on {self.device}")
+
+ def _load_classifier(self, classifier_path: Optional[str] = None) -> None:
+ """Load trained cry classifier from file.
+
+ Searches explicit path first, then default locations.
+ """
+ if not HAS_SKLEARN:
+ return
+
+ paths_to_try = []
+ if classifier_path:
+ paths_to_try.append(Path(classifier_path))
+ paths_to_try.extend(self.DEFAULT_CLASSIFIER_PATHS)
+
+ for path in paths_to_try:
+ if path.exists():
+ try:
+ self.classifier = joblib.load(path)
+ self.classifier_path = str(path)
+ print(f"Loaded cry classifier from {path}")
+ return
+ except Exception as e:
+ print(f"Warning: Could not load classifier from {path}: {e}")
+
+ def _load_hear_model(self) -> None:
+ """Load HeAR model from HuggingFace (PyTorch).
+
+ HeAR (Health Acoustic Representations) is a Google HAI-DEF model
+ for health-related audio analysis. It produces 512-dimensional
+ embeddings from 2-second audio chunks at 16kHz.
+ """
+ if not HAS_HEAR_PYTORCH:
+ print("Warning: transformers not available. Install with: pip install transformers")
+ print("Falling back to acoustic feature extraction (deterministic)")
+ self._hear_available = False
+ return
+
+ hf_token = os.environ.get("HF_TOKEN")
+
+ try:
+ print(f"Loading HeAR model from HuggingFace: {self.HEAR_MODEL_ID}")
+ self.hear_model = HearAutoModel.from_pretrained(
+ self.HEAR_MODEL_ID,
+ token=hf_token,
+ trust_remote_code=True,
+ )
+ self.hear_model = self.hear_model.to(self.device)
+ self.hear_model.eval()
+ self._hear_available = True
+ print("HeAR model loaded successfully (PyTorch)")
+
+ except Exception as e:
+ print(f"Warning: Could not load HeAR model: {e}")
+ print("Falling back to acoustic feature extraction (deterministic)")
+ self.hear_model = None
+ self._hear_available = False
+
+ def _split_audio_chunks(self, audio: np.ndarray) -> List[np.ndarray]:
+ """
+ Split audio into 2-second chunks for HeAR processing.
+
+ Args:
+ audio: Audio signal array (16kHz)
+
+ Returns:
+ List of audio chunks (each 2 seconds / 32000 samples)
+ """
+ chunks = []
+ for i in range(0, len(audio), self.CHUNK_SIZE):
+ chunk = audio[i:i + self.CHUNK_SIZE]
+ if len(chunk) < self.CHUNK_SIZE:
+ # Pad with zeros if needed
+ chunk = np.pad(chunk, (0, self.CHUNK_SIZE - len(chunk)))
+ chunks.append(chunk)
+ return chunks
+
+ def extract_hear_embeddings(self, audio: np.ndarray) -> np.ndarray:
+ """
+ Extract HeAR embeddings from audio using PyTorch.
+
+ HeAR is a ViT model that expects mel-PCEN spectrograms, not raw audio.
+ Pipeline: raw audio (32000 samples) → preprocess_audio() → (1, 1, 192, 128)
+ → ViT forward pass → pool last_hidden_state → embedding
+
+ Args:
+ audio: Audio signal (16kHz)
+
+ Returns:
+ Aggregated embedding (HeAR hidden_size dim, or 8-dim fallback)
+ """
+ if not self._hear_available or self.hear_model is None:
+ # Fallback: use acoustic features as pseudo-embeddings
+ # This is deterministic - same audio always produces same features
+ features = self.extract_features(audio, self.SAMPLE_RATE)
+ # Create a feature vector from acoustic features
+ feature_vector = np.array([
+ features.get("f0_mean", 0),
+ features.get("f0_std", 0),
+ features.get("f0_range", 0),
+ features.get("voiced_ratio", 0),
+ features.get("spectral_centroid_mean", 0),
+ features.get("spectral_bandwidth_mean", 0),
+ features.get("zcr_mean", 0),
+ features.get("rms_mean", 0),
+ ])
+ return feature_vector
+
+ from .hear_preprocessing import preprocess_audio
+
+ # Split into 2-second chunks for HeAR
+ chunks = self._split_audio_chunks(audio)
+
+ # Extract embeddings for each chunk using HeAR (PyTorch)
+ embeddings = []
+ with torch.no_grad():
+ for chunk in chunks:
+ # Convert raw audio to tensor: (1, 32000)
+ chunk_tensor = torch.tensor(
+ chunk.astype(np.float32)
+ ).unsqueeze(0).to(self.device)
+
+ # Preprocess: raw audio → mel-PCEN spectrogram (1, 1, 192, 128)
+ spectrogram = preprocess_audio(chunk_tensor)
+
+ # Forward pass: HeAR ViT expects pixel_values
+ output = self.hear_model(
+ pixel_values=spectrogram,
+ return_dict=True,
+ )
+
+ # Extract embedding from ViT output
+ if hasattr(output, 'pooler_output') and output.pooler_output is not None:
+ embedding = output.pooler_output
+ elif hasattr(output, 'last_hidden_state'):
+ # Mean pool over sequence dimension (skip CLS token)
+ embedding = output.last_hidden_state[:, 1:, :].mean(dim=1)
+ elif isinstance(output, torch.Tensor):
+ embedding = output
+ else:
+ embedding = list(output.values())[0] if hasattr(output, 'values') else output[0]
+
+ embeddings.append(embedding.cpu().numpy().squeeze())
+
+ # Aggregate embeddings (mean pooling across chunks)
+ aggregated = np.mean(embeddings, axis=0)
+ return aggregated
+
+ def load_audio(
+ self,
+ audio_path: Union[str, Path],
+ sr: int = None,
+ ) -> Tuple[np.ndarray, int]:
+ """
+ Load audio file.
+
+ Args:
+ audio_path: Path to audio file
+ sr: Target sample rate (defaults to the 16 kHz HeAR rate if None)
+
+ Returns:
+ Tuple of (audio_array, sample_rate)
+ """
+ sr = sr or self.SAMPLE_RATE
+ audio, _ = librosa.load(audio_path, sr=sr)
+ return audio, sr
+
+ def extract_features(self, audio: np.ndarray, sr: int) -> Dict:
+ """
+ Extract acoustic features from cry audio.
+
+ Features based on cry analysis literature:
+ - Fundamental frequency (F0)
+ - MFCCs (mel-frequency cepstral coefficients)
+ - Spectral features
+ - Temporal features
+
+ Args:
+ audio: Audio signal array
+ sr: Sample rate
+
+ Returns:
+ Dictionary of extracted features
+ """
+ features = {}
+
+ # Ensure minimum length
+ if len(audio) < sr * self.MIN_CRY_DURATION:
+ # Pad if too short
+ audio = np.pad(audio, (0, int(sr * self.MIN_CRY_DURATION) - len(audio)))
+
+ # Duration
+ features["duration"] = len(audio) / sr
+
+ # Fundamental frequency (F0) using pyin
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ f0, voiced_flag, voiced_probs = librosa.pyin(
+ audio,
+ fmin=librosa.note_to_hz('C2'),
+ fmax=librosa.note_to_hz('C7'),
+ sr=sr,
+ )
+
+ # F0 statistics (ignoring unvoiced frames)
+ f0_valid = f0[~np.isnan(f0)]
+ if len(f0_valid) > 0:
+ features["f0_mean"] = float(np.mean(f0_valid))
+ features["f0_std"] = float(np.std(f0_valid))
+ features["f0_min"] = float(np.min(f0_valid))
+ features["f0_max"] = float(np.max(f0_valid))
+ features["f0_range"] = features["f0_max"] - features["f0_min"]
+ else:
+ features["f0_mean"] = 0
+ features["f0_std"] = 0
+ features["f0_min"] = 0
+ features["f0_max"] = 0
+ features["f0_range"] = 0
+
+ # Voiced ratio (cry vs silence)
+ features["voiced_ratio"] = float(np.mean(voiced_flag))
+
+ # MFCCs
+ mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
+ for i in range(13):
+ features[f"mfcc_{i}_mean"] = float(np.mean(mfccs[i]))
+ features[f"mfcc_{i}_std"] = float(np.std(mfccs[i]))
+
+ # Spectral features
+ spectral_centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
+ spectral_bandwidth = librosa.feature.spectral_bandwidth(y=audio, sr=sr)
+ spectral_rolloff = librosa.feature.spectral_rolloff(y=audio, sr=sr)
+
+ features["spectral_centroid_mean"] = float(np.mean(spectral_centroid))
+ features["spectral_bandwidth_mean"] = float(np.mean(spectral_bandwidth))
+ features["spectral_rolloff_mean"] = float(np.mean(spectral_rolloff))
+
+ # Zero crossing rate (higher in noisy/irregular cries)
+ zcr = librosa.feature.zero_crossing_rate(audio)
+ features["zcr_mean"] = float(np.mean(zcr))
+ features["zcr_std"] = float(np.std(zcr))
+
+ # RMS energy
+ rms = librosa.feature.rms(y=audio)
+ features["rms_mean"] = float(np.mean(rms))
+ features["rms_std"] = float(np.std(rms))
+
+ # Tempo estimation (cry rhythm)
+ onset_env = librosa.onset.onset_strength(y=audio, sr=sr)
+ tempo = librosa.feature.tempo(onset_envelope=onset_env, sr=sr)
+ features["tempo"] = float(tempo[0]) if len(tempo) > 0 else 0
+
+ return features
+
+ def analyze(self, audio_path: Union[str, Path]) -> Dict:
+ """
+ Analyze cry audio for health indicators.
+
+ Uses HeAR embeddings + classifier when available, falls back to
+ rule-based acoustic analysis when HeAR is unavailable.
+
+ Args:
+ audio_path: Path to cry audio file
+
+ Returns:
+ Dictionary containing:
+ - is_abnormal: Boolean indicating abnormal cry
+ - asphyxia_risk: Risk score for birth asphyxia (0-1)
+ - cry_type: Detected cry type
+ - features: Extracted acoustic features
+ - risk_level: "low", "medium", "high"
+ - recommendation: Clinical recommendation
+ """
+ # Load audio
+ audio, sr = self.load_audio(audio_path)
+
+ # Extract acoustic features (always needed for cry_type and feature reporting)
+ features = self.extract_features(audio, sr)
+
+ # Determine cry type based on acoustic features
+ cry_type = self._classify_cry_type(features)
+
+ # Try HeAR-based classification first
+ classified_cry_type = None
+ if self._hear_available or (self.classifier is not None and HAS_SKLEARN):
+ asphyxia_risk, model_used, classified_cry_type = self._analyze_with_hear(audio)
+ else:
+ asphyxia_risk, model_used = self._analyze_with_rules(features)
+
+ # Use classifier's cry type if available, otherwise rule-based
+ if classified_cry_type is not None:
+ cry_type = classified_cry_type
+
+ # Determine risk level and recommendation based on risk score
+ if asphyxia_risk > 0.6:
+ risk_level = "high"
+ is_abnormal = True
+ recommendation = "URGENT: High-pitched abnormal cry detected. Assess for birth asphyxia immediately. Check APGAR score and vital signs."
+ elif asphyxia_risk > 0.3:
+ risk_level = "medium"
+ is_abnormal = True
+ recommendation = "CAUTION: Some abnormal cry characteristics. Monitor closely and reassess in 30 minutes."
+ else:
+ risk_level = "low"
+ is_abnormal = False
+ recommendation = "Normal cry pattern. Continue routine care."
+
+ return {
+ "is_abnormal": is_abnormal,
+ "asphyxia_risk": round(asphyxia_risk, 3),
+ "cry_type": cry_type,
+ "risk_level": risk_level,
+ "recommendation": recommendation,
+ "features": {
+ "f0_mean": round(features["f0_mean"], 1),
+ "f0_std": round(features["f0_std"], 1),
+ "duration": round(features["duration"], 2),
+ "voiced_ratio": round(features["voiced_ratio"], 2),
+ },
+ "model": model_used,
+ "model_note": self._get_model_note(model_used),
+ }
+
+ def _analyze_with_hear(self, audio: np.ndarray) -> Tuple[float, str, Optional[str]]:
+ """
+ Analyze cry using HeAR embeddings.
+
+ Args:
+ audio: Audio signal array (16kHz)
+
+ Returns:
+ Tuple of (asphyxia_risk, model_name, predicted_cry_type)
+ """
+ # Extract HeAR embeddings
+ embeddings = self.extract_hear_embeddings(audio)
+
+ # Use trained classifier if available
+ if self.classifier is not None and HAS_SKLEARN:
+ embeddings_2d = embeddings.reshape(1, -1)
+
+ # Multi-class cry type classification
+ prediction = int(self.classifier.predict(embeddings_2d)[0])
+ predicted_type = self.CRY_TYPE_LABELS.get(prediction, "unknown")
+
+ # Get class probabilities for confidence
+ if hasattr(self.classifier, 'predict_proba'):
+ proba = self.classifier.predict_proba(embeddings_2d)[0]
+ confidence = float(max(proba))
+
+ # Derive asphyxia risk from cry type probabilities
+ # belly_pain and discomfort cries are most associated with distress
+ pain_classes = {"belly_pain": 0, "discomfort": 2}
+ distress_prob = sum(
+ proba[idx] for name, idx in pain_classes.items()
+ if idx < len(proba)
+ )
+ # Scale distress probability to asphyxia risk
+ asphyxia_risk = min(1.0, distress_prob * 0.8)
+ else:
+ confidence = 0.7
+ asphyxia_risk = 0.5 if predicted_type in ("belly_pain", "discomfort") else 0.2
+
+ return asphyxia_risk, "HeAR + Classifier", predicted_type
+
+ # No classifier: use embedding-based heuristic
+ embedding_mean = float(np.mean(embeddings))
+ embedding_std = float(np.std(embeddings))
+ embedding_max = float(np.max(np.abs(embeddings)))
+
+ risk_score = 0.0
+ if embedding_std > 0.5:
+ risk_score += 0.3
+ if embedding_max > 2.0:
+ risk_score += 0.2
+ if abs(embedding_mean) > 0.3:
+ risk_score += 0.2
+
+ return min(risk_score, 1.0), "HeAR (uncalibrated)", None
+
+ def _analyze_with_rules(self, features: Dict) -> Tuple[float, str]:
+ """
+ Analyze cry using rule-based acoustic features.
+
+ Fallback when HeAR is unavailable.
+
+ Args:
+ features: Extracted acoustic features
+
+ Returns:
+ Tuple of (asphyxia_risk, model_name)
+ """
+ # Rule-based asphyxia risk assessment
+ # Based on medical literature on cry acoustics
+ asphyxia_indicators = 0
+ max_indicators = 5
+
+ # High F0 (> 500 Hz) is associated with asphyxia
+ if features["f0_mean"] > self.ASPHYXIA_F0_THRESHOLD:
+ asphyxia_indicators += 1
+
+ # High F0 variability
+ if features["f0_std"] > 100:
+ asphyxia_indicators += 1
+
+ # Wide F0 range
+ if features["f0_range"] > 300:
+ asphyxia_indicators += 1
+
+ # Low voiced ratio (fragmented cry)
+ if features["voiced_ratio"] < 0.3:
+ asphyxia_indicators += 1
+
+ # High zero crossing rate (irregular)
+ if features["zcr_mean"] > 0.15:
+ asphyxia_indicators += 1
+
+ asphyxia_risk = asphyxia_indicators / max_indicators
+ return asphyxia_risk, "Acoustic Features"
+
+ def _get_model_note(self, model_used: str) -> str:
+ """Get descriptive note for the model used."""
+ notes = {
+ "HeAR + Classifier": "HAI-DEF HeAR embeddings with trained linear classifier",
+ "HeAR (uncalibrated)": "HAI-DEF HeAR embeddings with heuristic scoring (no trained classifier)",
+ "Acoustic Features": "Deterministic acoustic feature extraction (HeAR unavailable)",
+ }
+ return notes.get(model_used, model_used)
+
+ def _classify_cry_type(self, features: Dict) -> str:
+ """
+ Classify cry type based on acoustic features.
+
+ Categories based on donate-a-cry corpus:
+ - hunger: Regular rhythm, moderate pitch
+ - pain: High pitch, irregular
+ - discomfort: Variable pitch, whimpering
+ - tired: Low energy, fragmented
+ - belly_pain: High pitch, straining patterns
+ """
+ f0_mean = features["f0_mean"]
+ f0_std = features["f0_std"]
+ rms_mean = features["rms_mean"]
+ voiced_ratio = features["voiced_ratio"]
+
+ # Simple rule-based classification
+ if f0_mean > 500 and f0_std > 80:
+ return "pain"
+ elif f0_mean > 450 and rms_mean > 0.1:
+ return "belly_pain"
+ elif voiced_ratio < 0.4 and rms_mean < 0.05:
+ return "tired"
+ elif f0_std < 50 and voiced_ratio > 0.5:
+ return "hunger"
+ else:
+ return "discomfort"
+
+ def analyze_batch(
+ self,
+ audio_paths: List[Union[str, Path]],
+ ) -> List[Dict]:
+ """
+ Analyze multiple cry audio files.
+
+ Args:
+ audio_paths: List of paths to audio files
+
+ Returns:
+ List of analysis results
+ """
+ results = []
+ for path in audio_paths:
+ try:
+ result = self.analyze(path)
+ result["file"] = str(path)
+ results.append(result)
+ except Exception as e:
+ results.append({
+ "file": str(path),
+ "error": str(e),
+ "is_abnormal": None,
+ })
+ return results
+
+ def get_spectrogram(
+ self,
+ audio_path: Union[str, Path],
+ n_mels: int = 128,
+ ) -> np.ndarray:
+ """
+ Generate mel spectrogram for visualization.
+
+ Args:
+ audio_path: Path to audio file
+ n_mels: Number of mel bands
+
+ Returns:
+ Mel spectrogram array (dB scale)
+ """
+ audio, sr = self.load_audio(audio_path)
+
+ mel_spec = librosa.feature.melspectrogram(
+ y=audio,
+ sr=sr,
+ n_mels=n_mels,
+ )
+ mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
+
+ return mel_spec_db
+
+
+ def test_analyzer():
+ """Test the cry analyzer with sample audio files."""
+ print("Testing Cry Analyzer...")
+
+ analyzer = CryAnalyzer()
+
+ # Check for available audio files
+ data_dirs = [
+ Path(__file__).parent.parent.parent / "data" / "raw" / "cryceleb" / "audio",
+ Path(__file__).parent.parent.parent / "data" / "raw" / "donate-a-cry",
+ Path(__file__).parent.parent.parent / "data" / "raw" / "infant-cry-dataset" / "cry",
+ ]
+
+ audio_files = []
+ for data_dir in data_dirs:
+ if data_dir.exists():
+ audio_files.extend(list(data_dir.rglob("*.wav"))[:2])
+
+ if audio_files:
+ for audio_path in audio_files[:5]:
+ print(f"\nAnalyzing: {audio_path.name}")
+ try:
+ result = analyzer.analyze(audio_path)
+ print(f" Abnormal cry: {result['is_abnormal']}")
+ print(f" Asphyxia risk: {result['asphyxia_risk']:.1%}")
+ print(f" Cry type: {result['cry_type']}")
+ print(f" Risk level: {result['risk_level']}")
+ print(f" F0 mean: {result['features']['f0_mean']} Hz")
+ except Exception as e:
+ print(f" Error: {e}")
+ else:
+ print("No audio files found. Please download datasets first.")
+
+
+ if __name__ == "__main__":
+ test_analyzer()
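The chunking step in `_split_audio_chunks` is the part of the HeAR pipeline that is easy to get wrong (the final partial chunk must be zero-padded to exactly 32000 samples). A minimal self-contained sketch of the same logic, using NumPy only (the `split_into_chunks` function here is illustrative, not the module's API):

```python
import numpy as np

CHUNK_SIZE = 32000  # 2 seconds at 16 kHz, the chunk length HeAR expects


def split_into_chunks(audio: np.ndarray, chunk_size: int = CHUNK_SIZE) -> list:
    """Split audio into fixed-size chunks, zero-padding the last one.

    Mirrors CryAnalyzer._split_audio_chunks in the diff above.
    """
    chunks = []
    for start in range(0, len(audio), chunk_size):
        chunk = audio[start:start + chunk_size]
        if len(chunk) < chunk_size:
            # Pad the trailing partial chunk with zeros to full length
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks


# 5 seconds of audio at 16 kHz -> 3 chunks, the last half-empty but padded
audio = np.zeros(80000, dtype=np.float32)
chunks = split_into_chunks(audio)
print(len(chunks), chunks[-1].shape)  # 3 (32000,)
```

Every chunk then goes through `preprocess_audio` and the ViT independently, and the per-chunk embeddings are mean-pooled into a single clip-level vector.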
src/nexus/hear_preprocessing.py ADDED
@@ -0,0 +1,320 @@
1
+ """
2
+ HeAR Audio Preprocessing Module
3
+
4
+ Converts raw audio waveforms into mel-PCEN spectrograms required by the
5
+ HeAR (Health Acoustic Representations) ViT model.
6
+
7
+ Pipeline: raw audio (batch, 32000) → normalize → STFT → power spectrogram
8
+ → mel filterbank (128 bins) → PCEN → resize → (batch, 1, 192, 128)
9
+
10
+ Adapted from Google's official HeAR preprocessing:
11
+ https://github.com/Google-Health/google-health/tree/master/health_acoustic_representations
12
+
13
+ Copyright 2025 Google LLC (original implementation)
14
+ Licensed under the Apache License, Version 2.0
15
+ """
16
+
17
+ import math
18
+ from typing import Callable, Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+
24
+ def _enclosing_power_of_two(value: int) -> int:
25
+ """Smallest power of 2 >= value."""
26
+ return int(2 ** math.ceil(math.log2(value))) if value > 0 else 1
27
+
28
+
29
+ def _compute_stft(
30
+ signals: torch.Tensor,
31
+ frame_length: int,
32
+ frame_step: int,
33
+ fft_length: Optional[int] = None,
34
+ window_fn: Optional[Callable[[int], torch.Tensor]] = torch.hann_window,
35
+ pad_end: bool = True,
36
+ ) -> torch.Tensor:
37
+ """Short-time Fourier Transform.
38
+
39
+ Args:
40
+ signals: [..., samples] real-valued tensor.
41
+ frame_length: Window length in samples.
42
+ frame_step: Step size in samples.
43
+ fft_length: FFT size (defaults to smallest power of 2 >= frame_length).
44
+ window_fn: Window function (default: Hann).
45
+ pad_end: Pad signal end with zeros.
46
+
47
+ Returns:
48
+ [..., frames, fft_length//2 + 1] complex64 tensor.
49
+ """
50
+ if signals.ndim < 1:
51
+ raise ValueError(f"Input signals must have rank >= 1, got {signals.ndim}")
52
+
53
+ if fft_length is None:
54
+ fft_length = _enclosing_power_of_two(frame_length)
55
+
56
+ if pad_end:
57
+ n_frames = (
58
+ math.ceil(signals.shape[-1] / frame_step)
59
+ if signals.shape[-1] > 0
60
+ else 0
61
+ )
62
+ padded_length = (
63
+ max(0, (n_frames - 1) * frame_step + frame_length)
64
+ if n_frames > 0
65
+ else frame_length
66
+ )
67
+ padding_needed = max(0, padded_length - signals.shape[-1])
68
+ if padding_needed > 0:
69
+ signals = F.pad(signals, (0, padding_needed))
70
+
71
+ framed_signals = signals.unfold(-1, frame_length, frame_step)
72
+
73
+ if framed_signals.shape[-2] == 0:
74
+ return torch.empty(
75
+ *signals.shape[:-1],
76
+ 0,
77
+ fft_length // 2 + 1,
78
+ dtype=torch.complex64,
79
+ device=signals.device,
80
+ )
81
+
82
+ if window_fn is not None:
83
+ window = (
84
+ window_fn(frame_length)
85
+ .to(framed_signals.device)
86
+ .to(framed_signals.dtype)
87
+ )
88
+ framed_signals = framed_signals * window
89
+
90
+ return torch.fft.rfft(framed_signals, n=fft_length, dim=-1)
91
+
92
+
93
+ def _ema(
94
+ inputs: torch.Tensor,
95
+ num_channels: int,
96
+ smooth_coef: float,
97
+ initial_state: Optional[torch.Tensor] = None,
98
+ ) -> torch.Tensor:
99
+ """Exponential Moving Average for PCEN smoothing.
100
+
101
+ Args:
102
+ inputs: (batch, timesteps, channels) tensor.
103
+ num_channels: Number of channels.
104
+ smooth_coef: EMA smoothing coefficient.
105
+ initial_state: Optional (batch, channels) initial state.
106
+
107
+ Returns:
108
+ (batch, timesteps, channels) EMA output.
109
+ """
110
+ batch_size, timesteps, _ = inputs.shape
111
+
112
+ if initial_state is None:
113
+ ema_state = torch.zeros(
114
+ (batch_size, num_channels), dtype=torch.float32, device=inputs.device
115
+ )
116
+ else:
117
+ ema_state = initial_state
118
+
119
+ identity_kernel = (
120
+ torch.eye(num_channels, dtype=torch.float32, device=inputs.device)
121
+ * smooth_coef
122
+ )
123
+ identity_recurrent_kernel = (
124
+ torch.eye(num_channels, dtype=torch.float32, device=inputs.device)
125
+ * (1.0 - smooth_coef)
126
+ )
127
+
128
+ output_sequence = []
129
+ start = initial_state is not None
130
+ if start:
131
+ output_sequence.append(ema_state)
132
+
133
+ for t in range(start, timesteps):
134
+ current_input = inputs[:, t, :]
135
+ output = torch.matmul(current_input, identity_kernel) + torch.matmul(
136
+ ema_state, identity_recurrent_kernel
137
+ )
138
+ ema_state = output
139
+ output_sequence.append(output)
140
+
141
+ return torch.stack(output_sequence, dim=1)
+ 
+ 
+ def _pcen_function(
+     inputs: torch.Tensor,
+     num_channels: int = 128,
+     alpha: float = 0.8,
+     smooth_coef: float = 0.04,
+     delta: float = 2.0,
+     root: float = 2.0,
+     floor: float = 1e-8,
+ ) -> torch.Tensor:
+     """Per-Channel Energy Normalization.
+ 
+     See https://arxiv.org/abs/1607.05666
+     """
+     alpha_param = torch.ones(num_channels, device=inputs.device, dtype=inputs.dtype) * alpha
+     delta_param = torch.ones(num_channels, device=inputs.device, dtype=inputs.dtype) * delta
+     root_param = torch.ones(num_channels, device=inputs.device, dtype=inputs.dtype) * root
+ 
+     alpha_param = torch.minimum(alpha_param, torch.ones_like(alpha_param))
+     root_param = torch.maximum(root_param, torch.ones_like(root_param))
+ 
+     ema_smoother = _ema(
+         inputs,
+         num_channels=num_channels,
+         smooth_coef=smooth_coef,
+         initial_state=inputs[:, 0] if inputs.ndim > 1 else None,
+     ).to(inputs.device)
+ 
+     one_over_root = 1.0 / root_param
+     output = (
+         inputs / (floor + ema_smoother) ** alpha_param + delta_param
+     ) ** one_over_root - delta_param**one_over_root
+     return output
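Ignoring the per-channel parameter tensors, the PCEN output above is the pointwise map `(x / (floor + m)**alpha + delta)**(1/root) - delta**(1/root)`, where `m` is the EMA-smoothed energy. A scalar sketch with the same default constants (hypothetical helper, illustration only):

```python
def pcen(x, m, alpha=0.8, delta=2.0, root=2.0, floor=1e-8):
    """Point-wise PCEN: gain-normalize x by the smoothed energy m, then compress."""
    one_over_root = 1.0 / root
    return (x / (floor + m) ** alpha + delta) ** one_over_root - delta ** one_over_root

# Zero energy maps to exactly zero output, regardless of the smoother value:
print(pcen(0.0, m=1.0))  # 0.0
```

The trailing `- delta**(1/root)` term is what anchors the zero point: it subtracts the compressed offset so silence stays at zero.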
+ 
+ 
+ def _hertz_to_mel(frequencies_hertz: torch.Tensor) -> torch.Tensor:
+     """Convert Hz to mel scale."""
+     return 2595.0 * torch.log10(1.0 + frequencies_hertz / 700.0)
+ 
+ 
+ def _linear_to_mel_weight_matrix(
+     device: torch.device,
+     num_mel_bins: int = 128,
+     num_spectrogram_bins: int = 201,
+     sample_rate: float = 16000,
+     lower_edge_hertz: float = 0.0,
+     upper_edge_hertz: float = 8000.0,
+     dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+     """Mel filterbank matrix: [num_spectrogram_bins, num_mel_bins]."""
+     zero = torch.tensor(0.0, dtype=dtype, device=device)
+     nyquist_hertz = torch.tensor(sample_rate / 2.0, dtype=dtype, device=device)
+     lower_edge = torch.tensor(lower_edge_hertz, dtype=dtype, device=device)
+     upper_edge = torch.tensor(upper_edge_hertz, dtype=dtype, device=device)
+ 
+     # Skip the DC bin so the first filter passes no 0 Hz energy.
+     bands_to_zero = 1
+     linear_frequencies = torch.linspace(
+         zero, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device
+     )[bands_to_zero:]
+     spectrogram_bins_mel = _hertz_to_mel(linear_frequencies).unsqueeze(1)
+ 
+     band_edges_mel = torch.linspace(
+         _hertz_to_mel(lower_edge),
+         _hertz_to_mel(upper_edge),
+         num_mel_bins + 2,
+         dtype=dtype,
+         device=device,
+     )
+     # Sliding windows of 3 edges -> (lower, center, upper) per mel band.
+     band_edges_mel = band_edges_mel.unfold(0, 3, 1)
+ 
+     lower_edge_mel = band_edges_mel[:, 0].unsqueeze(0)
+     center_mel = band_edges_mel[:, 1].unsqueeze(0)
+     upper_edge_mel = band_edges_mel[:, 2].unsqueeze(0)
+ 
+     lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
+         center_mel - lower_edge_mel
+     )
+     upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
+         upper_edge_mel - center_mel
+     )
+ 
+     mel_weights_matrix = torch.maximum(
+         zero, torch.minimum(lower_slopes, upper_slopes)
+     )
+ 
+     # Re-insert the zeroed DC row so the shape matches the STFT output.
+     return F.pad(
+         mel_weights_matrix, (0, 0, bands_to_zero, 0), mode="constant", value=0.0
+     )
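Each row of the filterbank built above is a triangular window in mel space: it rises linearly from the lower edge to the center, falls to the upper edge, and is clamped at zero outside, i.e. `max(0, min(lower_slope, upper_slope))`. A scalar sketch of a single filter weight (hypothetical helper name, illustration only):

```python
def tri_weight(mel, lower, center, upper):
    """Triangular mel filter weight at one spectrogram bin."""
    lower_slope = (mel - lower) / (center - lower)   # rises 0 -> 1 on [lower, center]
    upper_slope = (upper - mel) / (upper - center)   # falls 1 -> 0 on [center, upper]
    return max(0.0, min(lower_slope, upper_slope))

# Peak weight 1.0 at the center, 0 at and beyond the band edges:
print(tri_weight(5.0, 0.0, 5.0, 10.0))  # 1.0
```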
+ 
+ 
+ def _torch_resize_bilinear_tf_compat(
+     images: torch.Tensor,
+     size: tuple,
+ ) -> torch.Tensor:
+     """Bilinear resize matching TF's tf.image.resize behavior.
+ 
+     Args:
+         images: [C, H, W] or [B, C, H, W] float tensor.
+         size: (new_height, new_width).
+ 
+     Returns:
+         Resized tensor with same rank as input.
+     """
+     new_height, new_width = size
+     images = images.to(torch.float32)
+ 
+     was_3d = False
+     if images.dim() == 3:
+         images = images.unsqueeze(0)
+         was_3d = True
+ 
+     resized = F.interpolate(
+         images,
+         size=(new_height, new_width),
+         mode="bilinear",
+         align_corners=False,
+         antialias=False,
+     )
+ 
+     if was_3d:
+         resized = resized.squeeze(0)
+ 
+     return resized
+ 
+ 
+ def _mel_pcen(x: torch.Tensor) -> torch.Tensor:
+     """Mel spectrogram + PCEN normalization."""
+     x = x.float()
+     # Scale to [-1, 1]
+     x -= torch.min(x)
+     x = x / (torch.max(x) + 1e-8)
+     x = (x * 2) - 1
+ 
+     frame_length = 16 * 25  # 400 samples = 25 ms at 16 kHz
+     frame_step = 160  # 10 ms hop
+ 
+     stft = _compute_stft(
+         x,
+         frame_length=frame_length,
+         fft_length=frame_length,
+         frame_step=frame_step,
+         window_fn=torch.hann_window,
+         pad_end=True,
+     )
+     spectrograms = torch.square(torch.abs(stft))
+ 
+     mel_transform = _linear_to_mel_weight_matrix(x.device)
+     mel_spectrograms = torch.matmul(spectrograms, mel_transform)
+     return _pcen_function(mel_spectrograms)
+ 
+ 
+ def preprocess_audio(audio: torch.Tensor) -> torch.Tensor:
+     """Convert raw audio waveform to mel-PCEN spectrogram for HeAR.
+ 
+     Args:
+         audio: [batch, samples] tensor. 2-second clips at 16kHz (32000 samples).
+ 
+     Returns:
+         [batch, 1, 192, 128] mel-PCEN spectrogram tensor.
+     """
+     if audio.ndim != 2:
+         raise ValueError(f"Input audio must have rank 2, got rank {audio.ndim}")
+ 
+     if audio.shape[1] < 32000:
+         n = 32000 - audio.shape[1]
+         audio = F.pad(audio, pad=(0, n), mode="constant", value=0)
+     elif audio.shape[1] > 32000:
+         raise ValueError(
+             f"Input audio must have <= 32000 samples, got {audio.shape[1]}"
+         )
+ 
+     spectrogram = _mel_pcen(audio)
+     # Add channel dimension: [B, H, W] → [B, 1, H, W]
+     spectrogram = torch.unsqueeze(spectrogram, dim=1)
+     return _torch_resize_bilinear_tf_compat(spectrogram, size=(192, 128))
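`preprocess_audio` accepts clips of at most 2 s (32000 samples at 16 kHz) and zero-pads shorter ones on the right. The same length policy in plain Python (hypothetical `pad_to_length` helper, lists instead of tensors):

```python
def pad_to_length(samples, target=32000):
    """Right-pad with zeros up to `target` samples; reject clips that are too long."""
    if len(samples) > target:
        raise ValueError(f"clip has {len(samples)} samples, max {target}")
    return samples + [0.0] * (target - len(samples))

padded = pad_to_length([0.5, 0.5, 0.5], target=5)
print(padded)  # [0.5, 0.5, 0.5, 0.0, 0.0]
```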
src/nexus/jaundice_detector.py ADDED
@@ -0,0 +1,716 @@
+ """
+ Jaundice Detector Module
+ 
+ Uses MedSigLIP from Google HAI-DEF for jaundice detection from neonatal skin images.
+ Implements zero-shot classification with medical text prompts per NEXUS_MASTER_PLAN.md.
+ 
+ HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
+ Documentation: https://developers.google.com/health-ai-developer-foundations/medsiglip
+ """
+ 
+ import os
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+ import numpy as np
+ 
+ try:
+     from transformers import AutoProcessor, AutoModel
+     HAS_TRANSFORMERS = True
+ except ImportError:
+     HAS_TRANSFORMERS = False
+ 
+ # HAI-DEF MedSigLIP model IDs to try in order of preference
+ MEDSIGLIP_MODEL_IDS = [
+     "google/medsiglip-448",  # MedSigLIP - official HAI-DEF model
+     "google/siglip-base-patch16-224",  # SigLIP 224 - fallback
+ ]
+ 
+ 
+ class _BilirubinRegressor(nn.Module):
+     """3-layer MLP regression head with BatchNorm for bilirubin prediction (mg/dL).
+ 
+     Must match the architecture in scripts/training/finetune_bilirubin_regression.py
+     so that saved state_dict keys align.
+     """
+ 
+     def __init__(self, input_dim: int = 1152, hidden_dim: int = 256):
+         super().__init__()
+         mid_dim = hidden_dim * 2  # 512
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, mid_dim),
+             nn.BatchNorm1d(mid_dim),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(mid_dim, hidden_dim),
+             nn.BatchNorm1d(hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.15),
+             nn.Linear(hidden_dim, 1),
+         )
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x).squeeze(-1)
+ 
+ 
+ class _BilirubinRegressorV1(nn.Module):
+     """Original 2-layer MLP for backwards compatibility with older checkpoints."""
+ 
+     def __init__(self, input_dim: int = 1152, hidden_dim: int = 256):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(hidden_dim, 1),
+         )
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x).squeeze(-1)
+ 
+ 
+ class JaundiceDetector:
+     """
+     Detects neonatal jaundice from skin/sclera images using MedSigLIP.
+ 
+     Uses zero-shot classification with medical prompts and
+     color analysis for bilirubin estimation.
+ 
+     HAI-DEF Model: google/medsiglip-448 (MedSigLIP)
+     Fallback: siglip-base-patch16-224
+     """
+ 
+     # Medical text prompts for zero-shot classification (optimized for MedSigLIP)
+     # Expanded with Kramer zone references, skin-tone context, severity gradation
+     JAUNDICE_PROMPTS = [
+         "newborn with visible yellow discoloration of skin indicating jaundice",
+         "neonatal skin showing yellow-orange pigmentation from hyperbilirubinemia",
+         "jaundiced infant with icteric sclera and yellow skin tone",
+         "baby with yellow skin extending to trunk and limbs Kramer zone 3",
+         "neonatal jaundice with deep yellow skin requiring phototherapy",
+         "newborn showing yellow staining of skin and conjunctiva from bilirubin",
+         "infant with moderate to severe jaundice visible on face and chest",
+         "yellow discoloration of neonatal skin consistent with elevated bilirubin",
+     ]
+ 
+     NORMAL_PROMPTS = [
+         "healthy newborn with normal pink skin color without jaundice",
+         "infant with normal skin pigmentation and no yellow discoloration",
+         "newborn baby with clear healthy skin and no icterus",
+         "normal neonatal skin showing pink to brown coloration without yellowing",
+         "healthy baby skin with no signs of hyperbilirubinemia",
+         "newborn with well-perfused normal colored skin and clear sclera",
+         "infant with healthy natural skin tone and no bilirubin staining",
+         "normal newborn skin without yellow or orange discoloration",
+     ]
+ 
+     # Bilirubin risk thresholds (mg/dL)
+     BILIRUBIN_THRESHOLDS = {
+         "low": 5.0,  # Normal range
+         "moderate": 12.0,  # Monitor closely
+         "high": 15.0,  # Consider phototherapy
+         "critical": 20.0,  # Urgent phototherapy
+         "exchange": 25.0,  # Exchange transfusion territory
+     }
+ 
+     # Logit temperature for softmax conversion
+     LOGIT_SCALE = 30.0
+ 
+     def __init__(
+         self,
+         model_name: Optional[str] = None,  # Auto-select MedSigLIP
+         device: Optional[str] = None,
+         threshold: float = 0.5,
+     ):
+         """
+         Initialize the Jaundice Detector with MedSigLIP.
+ 
+         Args:
+             model_name: HuggingFace model name (auto-selects HAI-DEF MedSigLIP if None)
+             device: Device to run model on (auto-detected if None)
+             threshold: Classification threshold for jaundice detection
+         """
+         if not HAS_TRANSFORMERS:
+             raise ImportError(
+                 "transformers library required. Install with: pip install transformers"
+             )
+ 
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.threshold = threshold
+         self._model_loaded = False
+         self.classifier = None  # Can be set by pipeline for trained classification
+         self.regressor = None  # Bilirubin regression head (MedSigLIP embeddings -> mg/dL)
+ 
+         # Determine which models to try
+         models_to_try = [model_name] if model_name else MEDSIGLIP_MODEL_IDS
+ 
+         # HuggingFace token for gated models
+         hf_token = os.environ.get("HF_TOKEN")
+ 
+         # Try loading models in order of preference
+         for candidate_model in models_to_try:
+             print(f"Loading HAI-DEF model: {candidate_model}")
+             try:
+                 self.processor = AutoProcessor.from_pretrained(
+                     candidate_model, token=hf_token
+                 )
+                 self.model = AutoModel.from_pretrained(
+                     candidate_model, token=hf_token
+                 ).to(self.device)
+                 self.model_name = candidate_model
+                 self._model_loaded = True
+                 print(f"Successfully loaded: {candidate_model}")
+                 break
+             except Exception as e:
+                 print(f"Warning: Could not load {candidate_model}: {e}")
+                 continue
+ 
+         if not self._model_loaded:
+             raise RuntimeError(
+                 f"Could not load any MedSigLIP model. Tried: {models_to_try}. "
+                 "Install transformers and ensure internet access."
+             )
+ 
+         self.model.eval()
+ 
+         # Pre-compute text embeddings
+         self._precompute_text_embeddings()
+ 
+         # Try to auto-load trained classifier
+         self._auto_load_classifier()
+ 
+         # Try to load bilirubin regression model
+         self._load_regressor()
+ 
+         # Indicate which model variant is being used
+         is_medsiglip = "medsiglip" in self.model_name
+         model_type = "MedSigLIP" if is_medsiglip else "SigLIP (fallback)"
+         classifier_status = "trained classifier" if self.classifier else "zero-shot"
+         regressor_status = "with regressor" if self.regressor else "color-based only"
+         print(
+             f"Jaundice Detector (HAI-DEF {model_type}, {classifier_status}, "
+             f"{regressor_status}) initialized on {self.device}"
+         )
+ 
+     def _auto_load_classifier(self) -> None:
+         """Auto-load trained jaundice classifier if available."""
+         if self.classifier is not None:
+             return
+ 
+         try:
+             import joblib
+         except ImportError:
+             return
+ 
+         default_paths = [
+             Path(__file__).parent.parent.parent / "models" / "linear_probes" / "jaundice_classifier.joblib",
+             Path("models/linear_probes/jaundice_classifier.joblib"),
+         ]
+ 
+         for path in default_paths:
+             if path.exists():
+                 try:
+                     self.classifier = joblib.load(path)
+                     print(f"Auto-loaded jaundice classifier from {path}")
+                     return
+                 except Exception as e:
+                     print(f"Warning: Could not load classifier from {path}: {e}")
+ 
+     def _precompute_text_embeddings(self) -> None:
+         """Pre-compute text embeddings for zero-shot classification using SigLIP.
+ 
+         Stores individual prompt embeddings for max-similarity scoring.
+         """
+         all_prompts = self.JAUNDICE_PROMPTS + self.NORMAL_PROMPTS
+ 
+         with torch.no_grad():
+             inputs = self.processor(
+                 text=all_prompts,
+                 return_tensors="pt",
+                 padding="max_length",
+                 truncation=True,
+             ).to(self.device)
+ 
+             # Get text embeddings - support multiple output APIs
+             if hasattr(self.model, 'get_text_features'):
+                 text_embeddings = self.model.get_text_features(**inputs)
+             else:
+                 outputs = self.model(**inputs)
+                 if hasattr(outputs, 'text_embeds'):
+                     text_embeddings = outputs.text_embeds
+                 elif hasattr(outputs, 'text_model_output'):
+                     text_embeddings = outputs.text_model_output.pooler_output
+                 else:
+                     text_outputs = self.model.text_model(**inputs)
+                     text_embeddings = text_outputs.pooler_output
+ 
+         text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+ 
+         # Store individual embeddings for max-similarity scoring
+         n_jaundice = len(self.JAUNDICE_PROMPTS)
+         self.jaundice_embeddings_all = text_embeddings[:n_jaundice]  # (N, D)
+         self.normal_embeddings_all = text_embeddings[n_jaundice:]  # (M, D)
+ 
+         # Also keep mean embeddings as fallback
+         self.jaundice_embeddings = self.jaundice_embeddings_all.mean(dim=0, keepdim=True)
+         self.normal_embeddings = self.normal_embeddings_all.mean(dim=0, keepdim=True)
+         self.jaundice_embeddings = self.jaundice_embeddings / self.jaundice_embeddings.norm(dim=-1, keepdim=True)
+         self.normal_embeddings = self.normal_embeddings / self.normal_embeddings.norm(dim=-1, keepdim=True)
+ 
+     def _load_regressor(self) -> None:
+         """Load trained bilirubin regression head if available.
+ 
+         Tries the new 3-layer architecture first, falls back to V1 (2-layer).
+         """
+         model_paths = [
+             Path(__file__).parent.parent.parent / "models" / "linear_probes" / "bilirubin_regressor.pt",
+             Path("models/linear_probes/bilirubin_regressor.pt"),
+         ]
+ 
+         for model_path in model_paths:
+             if model_path.exists():
+                 try:
+                     checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
+                     input_dim = checkpoint.get("input_dim", 1152)
+                     hidden_dim = checkpoint.get("hidden_dim", 256)
+ 
+                     # Try new 3-layer architecture first, then fall back to V1
+                     for RegClass in [_BilirubinRegressor, _BilirubinRegressorV1]:
+                         try:
+                             regressor = RegClass(input_dim, hidden_dim)
+                             regressor.load_state_dict(checkpoint["model_state_dict"])
+                             regressor.to(self.device)
+                             regressor.eval()
+                             self.regressor = regressor
+                             arch = "v2 (3-layer)" if RegClass is _BilirubinRegressor else "v1 (2-layer)"
+                             print(f"Bilirubin regressor ({arch}) loaded from {model_path}")
+                             return
+                         except (RuntimeError, KeyError):
+                             continue
+ 
+                     print(f"Warning: Regressor checkpoint incompatible at {model_path}")
+                 except Exception as e:
+                     print(f"Warning: Could not load regressor from {model_path}: {e}")
+         self.regressor = None
+ 
+     def preprocess_image(self, image: Union[str, Path, Image.Image]) -> Image.Image:
+         """Preprocess image for analysis.
+ 
+         Args:
+             image: Path to image file or PIL Image object.
+ 
+         Returns:
+             PIL Image in RGB mode.
+ 
+         Raises:
+             ValueError: If the input type is unsupported.
+             FileNotFoundError: If the image file does not exist.
+         """
+         if isinstance(image, (str, Path)):
+             path = Path(image)
+             if not path.exists():
+                 raise FileNotFoundError(f"Image file not found: {path}")
+             image = Image.open(path).convert("RGB")
+         elif isinstance(image, Image.Image):
+             image = image.convert("RGB")
+         else:
+             raise ValueError(f"Expected str, Path, or PIL Image, got {type(image)}")
+         return image
+ 
+     def estimate_bilirubin(self, image: Union[str, Path, Image.Image]) -> float:
+         """
+         Estimate bilirubin level from image color analysis.
+ 
+         This uses the yellow-blue ratio, which correlates with
+         transcutaneous bilirubin measurements.
+ 
+         Args:
+             image: Neonatal skin/sclera image
+ 
+         Returns:
+             Estimated bilirubin in mg/dL
+         """
+         pil_image = self.preprocess_image(image)
+         img_array = np.array(pil_image).astype(float)
+ 
+         # Ensure 3-channel RGB
+         if img_array.ndim == 2:
+             img_array = np.stack([img_array, img_array, img_array], axis=-1)
+         elif img_array.shape[-1] == 1:
+             img_array = np.concatenate([img_array] * 3, axis=-1)
+ 
+         # Extract color channels
+         r = img_array[:, :, 0]
+         g = img_array[:, :, 1]
+         b = img_array[:, :, 2]
+ 
+         # Calculate yellow index (R+G-B correlation with bilirubin)
+         # Higher values indicate more yellow (jaundiced)
+         yellow_index = (r + g - b) / (r + g + b + 1e-6)
+         mean_yellow = np.mean(yellow_index)
+ 
+         # Convert to bilirubin estimate
+         # Calibrated based on medical literature
+         # Normal yellow_index ~ 0.2-0.3, jaundiced ~ 0.4-0.6
+         bilirubin_estimate = max(0, (mean_yellow - 0.2) * 50)
+ 
+         return round(bilirubin_estimate, 1)
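The color heuristic above boils down to two formulas: a normalized yellow index `(R + G - B) / (R + G + B)` and the linear map `max(0, (index - 0.2) * 50)` to mg/dL. A pure-Python sketch mirroring them (hypothetical helper names, single pixel instead of a channel mean):

```python
def yellow_index(r, g, b):
    """Normalized yellow-blue balance for one pixel (or per-channel means)."""
    return (r + g - b) / (r + g + b + 1e-6)

def bilirubin_from_yellow(mean_yellow):
    """Linear calibration: 0 below index 0.2, then slope 50 mg/dL per unit."""
    return max(0.0, (mean_yellow - 0.2) * 50.0)

# A strongly yellow pixel (high R+G, low B) scores higher than neutral gray:
print(yellow_index(200, 200, 50) > yellow_index(100, 100, 100))  # True
```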
+ 
+     def detect(self, image: Union[str, Path, Image.Image]) -> Dict:
+         """
+         Detect jaundice from neonatal image.
+ 
+         Uses trained classifier if available, otherwise falls back to
+         zero-shot classification with MedSigLIP.
+ 
+         Args:
+             image: Neonatal skin/sclera image
+ 
+         Returns:
+             Dictionary containing:
+                 - has_jaundice: Boolean indicating jaundice detection
+                 - confidence: Confidence score
+                 - jaundice_score: Raw jaundice probability
+                 - estimated_bilirubin: Estimated bilirubin (mg/dL)
+                 - severity: "none", "mild", "moderate", "severe", "critical"
+                 - needs_phototherapy: Boolean
+                 - recommendation: Clinical recommendation
+         """
+         pil_image = self.preprocess_image(image)
+ 
+         # Get image embedding using SigLIP
+         with torch.no_grad():
+             inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
+ 
+             # Get image embeddings - support multiple output APIs
+             if hasattr(self.model, 'get_image_features'):
+                 image_embedding = self.model.get_image_features(**inputs)
+             else:
+                 outputs = self.model(**inputs)
+                 if hasattr(outputs, 'image_embeds'):
+                     image_embedding = outputs.image_embeds
+                 elif hasattr(outputs, 'vision_model_output'):
+                     image_embedding = outputs.vision_model_output.pooler_output
+                 else:
+                     vision_outputs = self.model.vision_model(**inputs)
+                     image_embedding = vision_outputs.pooler_output
+ 
+         image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
+ 
+         # Use trained classifier if available, otherwise zero-shot
+         if self.classifier is not None:
+             jaundice_prob, model_method = self._classify_with_trained_model(image_embedding)
+         else:
+             jaundice_prob, model_method = self._classify_zero_shot(image_embedding)
+ 
+         # Color-based bilirubin estimate (always available)
+         estimated_bilirubin = self.estimate_bilirubin(pil_image)
+ 
+         # ML-based bilirubin estimate from trained regressor on MedSigLIP embeddings
+         estimated_bilirubin_ml = None
+         if self.regressor is not None:
+             with torch.no_grad():
+                 bilirubin_pred = self.regressor(image_embedding)
+             raw_value = float(bilirubin_pred.item())
+             # Clamp to physiologically valid range (0-35 mg/dL)
+             clamped_value = max(0.0, min(35.0, raw_value))
+             estimated_bilirubin_ml = round(clamped_value, 1)
+ 
+         # Use ML estimate for severity when available, otherwise color-based
+         bilirubin_for_severity = (
+             estimated_bilirubin_ml if estimated_bilirubin_ml is not None else estimated_bilirubin
+         )
+ 
+         # Determine severity based on bilirubin level
+         if bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["low"]:
+             severity = "none"
+             needs_phototherapy = False
+             recommendation = "No jaundice detected. Continue routine care."
+         elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["moderate"]:
+             severity = "mild"
+             needs_phototherapy = False
+             recommendation = "Mild jaundice. Monitor closely and ensure adequate feeding."
+         elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["high"]:
+             severity = "moderate"
+             needs_phototherapy = False
+             recommendation = "Moderate jaundice. Recheck in 12-24 hours. Consider phototherapy if rising."
+         elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["critical"]:
+             severity = "severe"
+             needs_phototherapy = True
+             recommendation = "URGENT: Start phototherapy. Refer for serum bilirubin confirmation."
+         else:
+             severity = "critical"
+             needs_phototherapy = True
+             recommendation = "CRITICAL: Immediate phototherapy required. Consider exchange transfusion."
+ 
+         is_medsiglip = "medsiglip" in self.model_name
+         base_model = "MedSigLIP (HAI-DEF)" if is_medsiglip else "SigLIP (fallback)"
+ 
+         result = {
+             "has_jaundice": jaundice_prob > self.threshold,
+             "confidence": max(jaundice_prob, 1 - jaundice_prob),
+             "jaundice_score": jaundice_prob,
+             "estimated_bilirubin": estimated_bilirubin,
+             "severity": severity,
+             "needs_phototherapy": needs_phototherapy,
+             "recommendation": recommendation,
+             "model": self.model_name,
+             "model_type": f"{base_model} + {model_method}",
+         }
+ 
+         if estimated_bilirubin_ml is not None:
+             result["estimated_bilirubin_ml"] = estimated_bilirubin_ml
+             result["bilirubin_method"] = "MedSigLIP Regressor"
+         else:
+             result["bilirubin_method"] = "Color Analysis"
+ 
+         return result
+ 
+     def _classify_with_trained_model(self, image_embedding: torch.Tensor) -> Tuple[float, str]:
+         """
+         Classify using trained classifier on embeddings.
+ 
+         Args:
+             image_embedding: Normalized image embedding from MedSigLIP
+ 
+         Returns:
+             Tuple of (jaundice_prob, method_name)
+         """
+         # Convert embedding to numpy for sklearn classifiers
+         embedding_np = image_embedding.cpu().numpy().reshape(1, -1)
+ 
+         # Handle different classifier types
+         if hasattr(self.classifier, 'predict_proba'):
+             # Sklearn classifier with probability support
+             proba = self.classifier.predict_proba(embedding_np)
+             # Assume binary: [normal, jaundice] or [jaundice, normal]
+             if proba.shape[1] >= 2:
+                 # Check classifier classes to determine order
+                 if hasattr(self.classifier, 'classes_'):
+                     classes = list(self.classifier.classes_)
+                     jaundice_idx = classes.index(1) if 1 in classes else 1
+                 else:
+                     jaundice_idx = 1
+                 jaundice_prob = float(proba[0, jaundice_idx])
+             else:
+                 jaundice_prob = float(proba[0, 0])
+             return jaundice_prob, "Trained Classifier"
+ 
+         elif hasattr(self.classifier, 'predict'):
+             # Classifier without probability - use binary prediction
+             prediction = self.classifier.predict(embedding_np)
+             jaundice_prob = float(prediction[0])
+             return jaundice_prob, "Trained Classifier (binary)"
+ 
+         elif isinstance(self.classifier, nn.Module):
+             # PyTorch classifier
+             self.classifier.eval()
+             with torch.no_grad():
+                 logits = self.classifier(image_embedding)
+                 probs = torch.softmax(logits, dim=-1)
+             if probs.shape[-1] >= 2:
+                 jaundice_prob = probs[0, 1].item()
+             else:
+                 jaundice_prob = probs[0, 0].item()
+             return jaundice_prob, "Trained Classifier (PyTorch)"
+ 
+         else:
+             # Unknown classifier type - fall back to zero-shot
+             print(f"Warning: Unknown classifier type {type(self.classifier)}, using zero-shot")
+             return self._classify_zero_shot(image_embedding)
+ 
+     def _classify_zero_shot(self, image_embedding: torch.Tensor) -> Tuple[float, str]:
+         """
+         Classify using zero-shot with max-similarity scoring.
+ 
+         Uses the maximum cosine similarity across all prompts per class
+         for better discrimination.
+ 
+         Args:
+             image_embedding: Normalized image embedding from MedSigLIP
+ 
+         Returns:
+             Tuple of (jaundice_prob, method_name)
+         """
+         # Max-similarity: best-matching prompt per class
+         jaundice_sims = (image_embedding @ self.jaundice_embeddings_all.T).squeeze(0)
+         normal_sims = (image_embedding @ self.normal_embeddings_all.T).squeeze(0)
+ 
+         # Ensure at least 1-D for .max() to work on single-image inputs
+         if jaundice_sims.dim() == 0:
+             jaundice_sims = jaundice_sims.unsqueeze(0)
+         if normal_sims.dim() == 0:
+             normal_sims = normal_sims.unsqueeze(0)
+ 
+         jaundice_sim = jaundice_sims.max().item()
+         normal_sim = normal_sims.max().item()
+ 
+         # Convert to probabilities with tuned temperature
+         logits = torch.tensor([jaundice_sim, normal_sim]) * self.LOGIT_SCALE
+         probs = torch.softmax(logits, dim=0)
+         jaundice_prob = probs[0].item()
+ 
+         return jaundice_prob, "Zero-Shot"
+ 
+     def detect_batch(
+         self,
+         images: List[Union[str, Path, Image.Image]],
+         batch_size: int = 8,
+     ) -> List[Dict]:
+         """Detect jaundice from multiple images."""
+         results = []
+ 
+         for i in range(0, len(images), batch_size):
+             batch = images[i:i + batch_size]
+             pil_images = [self.preprocess_image(img) for img in batch]
+ 
+             with torch.no_grad():
+                 inputs = self.processor(images=pil_images, return_tensors="pt", padding=True).to(self.device)
+ 
+                 # Get image embeddings from SigLIP vision encoder
+                 if hasattr(self.model, 'get_image_features'):
+                     image_embeddings = self.model.get_image_features(**inputs)
+                 else:
+                     vision_outputs = self.model.vision_model(**inputs)
+                     image_embeddings = vision_outputs.pooler_output
+ 
+             image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
+ 
+             for img_emb, pil_img in zip(image_embeddings, pil_images):
+                 img_emb = img_emb.unsqueeze(0)
+ 
+                 # Use trained classifier if available, otherwise zero-shot
+                 if self.classifier is not None:
+                     jaundice_prob, model_method = self._classify_with_trained_model(img_emb)
+                 else:
+                     jaundice_prob, model_method = self._classify_zero_shot(img_emb)
+ 
+                 # Color-based bilirubin
+                 estimated_bilirubin = self.estimate_bilirubin(pil_img)
+ 
+                 # ML bilirubin from regressor (consistent with detect())
+                 estimated_bilirubin_ml = None
+                 if self.regressor is not None:
+                     with torch.no_grad():
+                         bilirubin_pred = self.regressor(img_emb)
+                     raw_value = float(bilirubin_pred.item())
+                     estimated_bilirubin_ml = round(max(0.0, min(35.0, raw_value)), 1)
+ 
+                 bilirubin_for_severity = (
+                     estimated_bilirubin_ml if estimated_bilirubin_ml is not None else estimated_bilirubin
+                 )
+ 
+                 if bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["low"]:
+                     severity, needs_phototherapy = "none", False
+                 elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["moderate"]:
+                     severity, needs_phototherapy = "mild", False
+                 elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["high"]:
+                     severity, needs_phototherapy = "moderate", False
+                 elif bilirubin_for_severity < self.BILIRUBIN_THRESHOLDS["critical"]:
+                     severity, needs_phototherapy = "severe", True
+                 else:
+                     severity, needs_phototherapy = "critical", True
+ 
+                 result_item = {
+                     "has_jaundice": jaundice_prob > self.threshold,
+                     "confidence": max(jaundice_prob, 1 - jaundice_prob),
+                     "jaundice_score": jaundice_prob,
+                     "estimated_bilirubin": estimated_bilirubin,
+                     "severity": severity,
+                     "needs_phototherapy": needs_phototherapy,
+                 }
+                 if estimated_bilirubin_ml is not None:
+                     result_item["estimated_bilirubin_ml"] = estimated_bilirubin_ml
+                 results.append(result_item)
+ 
+         return results
+ 
+     def analyze_kramer_zones(self, image: Union[str, Path, Image.Image]) -> Dict:
+         """
+         Analyze jaundice using Kramer's zones concept.
+ 
+         Kramer's zones estimate bilirubin based on cephalocaudal progression:
+             - Zone 1 (face): ~5-6 mg/dL
+             - Zone 2 (chest): ~9 mg/dL
+             - Zone 3 (abdomen): ~12 mg/dL
+             - Zone 4 (arms/legs): ~15 mg/dL
+             - Zone 5 (hands/feet): ~20+ mg/dL
+ 
+         Args:
+             image: Full body or partial neonatal image
+ 
+         Returns:
+             Dictionary with zone analysis
+         """
+         pil_image = self.preprocess_image(image)
+         img_array = np.array(pil_image).astype(float)
+ 
+         # Simple color-based zone estimation
+         r = img_array[:, :, 0]
+         g = img_array[:, :, 1]
+         b = img_array[:, :, 2]
+ 
+         yellow_index = np.mean((r + g - b) / (r + g + b + 1e-6))
+ 
+         # Map yellow index to Kramer zone
+         if yellow_index < 0.25:
+             zone = 0
+             zone_bilirubin = 3
+         elif yellow_index < 0.30:
+             zone = 1
+             zone_bilirubin = 6
+         elif yellow_index < 0.35:
+             zone = 2
+             zone_bilirubin = 9
+         elif yellow_index < 0.40:
+             zone = 3
+             zone_bilirubin = 12
+         elif yellow_index < 0.45:
+             zone = 4
+             zone_bilirubin = 15
+         else:
+             zone = 5
+             zone_bilirubin = 20
+ 
+         return {
+             "kramer_zone": zone,
+             "zone_description": self._get_zone_description(zone),
+             "estimated_bilirubin_by_zone": zone_bilirubin,
+             "yellow_index": round(yellow_index, 3),
+         }
+ 
+     def _get_zone_description(self, zone: int) -> str:
+         """Get description for Kramer zone."""
+         descriptions = {
+             0: "No visible jaundice",
+             1: "Face and neck (Zone 1)",
+             2: "Upper trunk (Zone 2)",
+             3: "Lower trunk and thighs (Zone 3)",
+             4: "Arms and lower legs (Zone 4)",
+             5: "Hands and feet (Zone 5) - Severe",
+         }
+         return descriptions.get(zone, "Unknown")
689
+
690
+
691
+ def test_detector():
692
+ """Test the jaundice detector with sample images."""
693
+ print("Testing Jaundice Detector...")
694
+
695
+ detector = JaundiceDetector()
696
+
697
+ data_dir = Path(__file__).parent.parent.parent / "data" / "raw" / "neojaundice" / "images"
698
+
699
+ if data_dir.exists():
700
+ sample_images = list(data_dir.glob("*.jpg"))[:3]
701
+
702
+ for img_path in sample_images:
703
+ print(f"\nAnalyzing: {img_path.name}")
704
+ result = detector.detect(img_path)
705
+ print(f" Jaundice detected: {result['has_jaundice']}")
706
+ print(f" Confidence: {result['confidence']:.2%}")
707
+ print(f" Estimated bilirubin: {result['estimated_bilirubin']} mg/dL")
708
+ print(f" Severity: {result['severity']}")
709
+ print(f" Needs phototherapy: {result['needs_phototherapy']}")
710
+ print(f" Recommendation: {result['recommendation']}")
711
+ else:
712
+ print(f"Dataset not found at {data_dir}")
713
+
714
+
715
+ if __name__ == "__main__":
716
+ test_detector()
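The yellow-index-to-zone mapping in `analyze_kramer_zones` is self-contained arithmetic, so it can be checked in isolation. Below is a minimal sketch of the same thresholding outside the detector class; the helper name `yellow_index_to_zone` is illustrative and not part of this commit.

```python
import numpy as np

def yellow_index_to_zone(img_array: np.ndarray) -> tuple:
    """Map an RGB array to (kramer_zone, zone_bilirubin_mg_dl) using the
    same thresholds as JaundiceDetector.analyze_kramer_zones."""
    r, g, b = img_array[:, :, 0], img_array[:, :, 1], img_array[:, :, 2]
    yellow_index = float(np.mean((r + g - b) / (r + g + b + 1e-6)))
    # (upper bound on yellow index, zone, estimated bilirubin in mg/dL)
    bands = [(0.25, 0, 3), (0.30, 1, 6), (0.35, 2, 9),
             (0.40, 3, 12), (0.45, 4, 15)]
    for upper, zone, bilirubin in bands:
        if yellow_index < upper:
            return zone, bilirubin
    return 5, 20

# A strongly yellow patch (high R and G, low B) should land in a high zone.
yellow = np.zeros((8, 8, 3))
yellow[:, :, 0] = 200.0  # R
yellow[:, :, 1] = 180.0  # G
yellow[:, :, 2] = 40.0   # B
print(yellow_index_to_zone(yellow))  # → (5, 20)
```

A neutral gray patch (R = G = B) yields a yellow index of about 1/3, which falls in zone 2 under these thresholds.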
src/nexus/pipeline.py ADDED
@@ -0,0 +1,663 @@
+"""
+NEXUS Pipeline Module
+
+Integrates all detection modules into a unified diagnostic pipeline
+for maternal-neonatal care.
+"""
+
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+from dataclasses import dataclass, field
+from datetime import datetime
+import json
+
+
+@dataclass
+class PatientInfo:
+    """Patient information for context."""
+    patient_id: str
+    age_days: Optional[int] = None  # For neonates
+    gestational_age: Optional[int] = None  # Weeks
+    birth_weight: Optional[int] = None  # Grams
+    gender: Optional[str] = None
+    is_maternal: bool = False  # True for mother, False for neonate
+
+
+@dataclass
+class AssessmentResult:
+    """Complete assessment result."""
+    patient: PatientInfo
+    timestamp: str
+    anemia_result: Optional[Dict] = None
+    jaundice_result: Optional[Dict] = None
+    cry_result: Optional[Dict] = None
+    overall_risk: str = "unknown"
+    priority_actions: List[str] = field(default_factory=list)
+    referral_needed: bool = False
+
+
+class NEXUSPipeline:
+    """
+    NEXUS Integrated Diagnostic Pipeline
+
+    Combines anemia, jaundice, and cry analysis into a unified
+    assessment workflow for maternal-neonatal care.
+    """
+
+    # Default paths for trained model checkpoints
+    DEFAULT_CHECKPOINT_DIR = Path(__file__).parent.parent.parent / "models" / "checkpoints"
+    DEFAULT_LINEAR_PROBE_DIR = Path(__file__).parent.parent.parent / "models" / "linear_probes"
+
+    def __init__(
+        self,
+        device: Optional[str] = None,
+        lazy_load: bool = True,
+        anemia_checkpoint: Optional[Union[str, Path]] = None,
+        jaundice_checkpoint: Optional[Union[str, Path]] = None,
+        cry_checkpoint: Optional[Union[str, Path]] = None,
+        use_linear_probes: bool = True,
+    ):
+        """
+        Initialize NEXUS Pipeline.
+
+        Args:
+            device: Device for model inference
+            lazy_load: If True, load models only when needed
+            anemia_checkpoint: Path to trained anemia classifier checkpoint
+            jaundice_checkpoint: Path to trained jaundice classifier checkpoint
+            cry_checkpoint: Path to trained cry classifier checkpoint
+            use_linear_probes: If True, auto-load linear probes from default dir
+        """
+        self.device = device
+        self.lazy_load = lazy_load
+
+        # Store checkpoint paths
+        self.anemia_checkpoint = anemia_checkpoint
+        self.jaundice_checkpoint = jaundice_checkpoint
+        self.cry_checkpoint = cry_checkpoint
+
+        # Auto-detect checkpoints from default locations
+        if use_linear_probes:
+            self._auto_detect_checkpoints()
+
+        self._anemia_detector = None
+        self._jaundice_detector = None
+        self._cry_analyzer = None
+
+        if not lazy_load:
+            self._load_all_models()
+
+        print("NEXUS Pipeline initialized")
+
+    def verify_hai_def_compliance(self) -> Dict:
+        """
+        Verify which HAI-DEF models are loaded and report compliance.
+
+        Returns:
+            Dictionary with model status and compliance flag.
+        """
+        from .anemia_detector import MEDSIGLIP_MODEL_IDS
+        from .cry_analyzer import CryAnalyzer
+
+        status = {
+            "medsiglip": {
+                "expected": "google/medsiglip-448",
+                "configured_models": MEDSIGLIP_MODEL_IDS,
+                "anemia_loaded": self._anemia_detector is not None,
+                "jaundice_loaded": self._jaundice_detector is not None,
+            },
+            "hear": {
+                "expected": CryAnalyzer.HEAR_MODEL_ID,
+                "cry_loaded": self._cry_analyzer is not None,
+                "hear_active": getattr(self._cry_analyzer, '_hear_available', False) if self._cry_analyzer else False,
+            },
+            "medgemma": {
+                "expected": "google/medgemma-4b-it",
+            },
+        }
+
+        # Check loaded model names
+        if self._anemia_detector:
+            status["medsiglip"]["anemia_model"] = getattr(self._anemia_detector, 'model_name', 'unknown')
+        if self._jaundice_detector:
+            status["medsiglip"]["jaundice_model"] = getattr(self._jaundice_detector, 'model_name', 'unknown')
+
+        # Overall compliance
+        anemia_ok = "medsiglip" in status["medsiglip"].get("anemia_model", "")
+        jaundice_ok = "medsiglip" in status["medsiglip"].get("jaundice_model", "")
+        hear_ok = status["hear"]["hear_active"]
+
+        status["compliant"] = anemia_ok or jaundice_ok or hear_ok
+        status["all_hai_def"] = anemia_ok and jaundice_ok and hear_ok
+
+        return status
+
+    def _auto_detect_checkpoints(self) -> None:
+        """Auto-detect trained checkpoints from default directories."""
+        # Check for linear probes (.joblib sklearn models)
+        if self.anemia_checkpoint is None:
+            anemia_probe = self.DEFAULT_LINEAR_PROBE_DIR / "anemia_linear_probe.joblib"
+            if anemia_probe.exists():
+                self.anemia_checkpoint = anemia_probe
+                print(f"Auto-detected anemia probe: {anemia_probe}")
+
+        if self.jaundice_checkpoint is None:
+            jaundice_probe = self.DEFAULT_LINEAR_PROBE_DIR / "jaundice_linear_probe.joblib"
+            if jaundice_probe.exists():
+                self.jaundice_checkpoint = jaundice_probe
+                print(f"Auto-detected jaundice probe: {jaundice_probe}")
+
+        if self.cry_checkpoint is None:
+            cry_probe = self.DEFAULT_LINEAR_PROBE_DIR / "cry_linear_probe.joblib"
+            if cry_probe.exists():
+                self.cry_checkpoint = cry_probe
+                print(f"Auto-detected cry probe: {cry_probe}")
+
+        # Also check checkpoint dir for full fine-tuned models
+        if self.anemia_checkpoint is None:
+            anemia_best = self.DEFAULT_CHECKPOINT_DIR / "anemia_best.pt"
+            if anemia_best.exists():
+                self.anemia_checkpoint = anemia_best
+                print(f"Auto-detected anemia checkpoint: {anemia_best}")
+
+    def _load_all_models(self) -> None:
+        """Load all detection models."""
+        self._get_anemia_detector()
+        self._get_jaundice_detector()
+        self._get_cry_analyzer()
+
+    def _get_anemia_detector(self):
+        """Get or create anemia detector with optional trained classifier."""
+        if self._anemia_detector is None:
+            from .anemia_detector import AnemiaDetector
+
+            # Initialize detector
+            self._anemia_detector = AnemiaDetector(device=self.device)
+
+            # Load trained classifier if available
+            if self.anemia_checkpoint:
+                self._load_classifier_checkpoint(
+                    self._anemia_detector,
+                    self.anemia_checkpoint,
+                    "anemia",
+                )
+
+        return self._anemia_detector
+
+    def _get_jaundice_detector(self):
+        """Get or create jaundice detector with optional trained classifier."""
+        if self._jaundice_detector is None:
+            from .jaundice_detector import JaundiceDetector
+
+            self._jaundice_detector = JaundiceDetector(device=self.device)
+
+            # Load trained classifier if available
+            if self.jaundice_checkpoint:
+                self._load_classifier_checkpoint(
+                    self._jaundice_detector,
+                    self.jaundice_checkpoint,
+                    "jaundice",
+                )
+
+        return self._jaundice_detector
+
+    def _get_cry_analyzer(self):
+        """Get or create cry analyzer with optional trained classifier."""
+        if self._cry_analyzer is None:
+            from .cry_analyzer import CryAnalyzer
+
+            # Cry analyzer supports classifier_path directly
+            classifier_path = str(self.cry_checkpoint) if self.cry_checkpoint else None
+            self._cry_analyzer = CryAnalyzer(
+                device=self.device,
+                classifier_path=classifier_path,
+            )
+
+        return self._cry_analyzer
+
+    def _load_classifier_checkpoint(
+        self,
+        detector,
+        checkpoint_path: Union[str, Path],
+        model_type: str,
+    ) -> None:
+        """
+        Load a trained classifier checkpoint into a detector.
+
+        Supports both linear probes (sklearn) and PyTorch checkpoints.
+        """
+        import torch
+
+        checkpoint_path = Path(checkpoint_path)
+        if not checkpoint_path.exists():
+            print(f"Warning: {model_type} checkpoint not found: {checkpoint_path}")
+            return
+
+        try:
+            # Check if it's a sklearn model (joblib)
+            if checkpoint_path.suffix in ['.pkl', '.joblib']:
+                import joblib
+                classifier = joblib.load(checkpoint_path)
+                detector.classifier = classifier
+                print(f"Loaded sklearn classifier for {model_type}")
+
+            # Check if it's a PyTorch model
+            elif checkpoint_path.suffix == '.pt':
+                checkpoint = torch.load(checkpoint_path, map_location=self.device or 'cpu')
+
+                # Handle different checkpoint formats
+                if 'classifier' in checkpoint:
+                    # Linear probe format
+                    detector.classifier = checkpoint['classifier']
+                    print(f"Loaded linear probe for {model_type}")
+                elif 'model_state_dict' in checkpoint:
+                    # Full model checkpoint - would need separate handling
+                    print(f"Note: Full model checkpoint for {model_type} - using zero-shot")
+                else:
+                    print(f"Unknown checkpoint format for {model_type}")
+
+        except Exception as e:
+            print(f"Warning: Could not load {model_type} checkpoint: {e}")
+
+    def assess_maternal(
+        self,
+        patient: PatientInfo,
+        conjunctiva_image: Optional[Union[str, Path]] = None,
+    ) -> AssessmentResult:
+        """
+        Perform maternal health assessment.
+
+        Currently focuses on anemia detection via conjunctiva imaging.
+
+        Args:
+            patient: Patient information
+            conjunctiva_image: Path to conjunctiva image
+
+        Returns:
+            AssessmentResult with findings
+        """
+        result = AssessmentResult(
+            patient=patient,
+            timestamp=datetime.now().isoformat(),
+            priority_actions=[],
+        )
+
+        # Anemia detection
+        if conjunctiva_image:
+            detector = self._get_anemia_detector()
+            result.anemia_result = detector.detect(conjunctiva_image)
+
+            # Add color analysis
+            color_info = detector.analyze_color_features(conjunctiva_image)
+            result.anemia_result["color_analysis"] = color_info
+
+            # Determine actions
+            if result.anemia_result["risk_level"] == "high":
+                result.priority_actions.append("URGENT: Refer for blood test - suspected severe anemia")
+                result.referral_needed = True
+                result.overall_risk = "high"
+            elif result.anemia_result["risk_level"] == "medium":
+                result.priority_actions.append("Schedule blood test within 48 hours")
+                result.overall_risk = "medium"
+            else:
+                result.overall_risk = "low"
+
+        return result
+
+    def assess_neonate(
+        self,
+        patient: PatientInfo,
+        skin_image: Optional[Union[str, Path]] = None,
+        cry_audio: Optional[Union[str, Path]] = None,
+    ) -> AssessmentResult:
+        """
+        Perform neonatal health assessment.
+
+        Includes jaundice detection and cry analysis.
+
+        Args:
+            patient: Patient information
+            skin_image: Path to skin/sclera image for jaundice
+            cry_audio: Path to cry audio file
+
+        Returns:
+            AssessmentResult with findings
+        """
+        result = AssessmentResult(
+            patient=patient,
+            timestamp=datetime.now().isoformat(),
+            priority_actions=[],
+        )
+
+        risk_scores = []
+
+        # Jaundice detection
+        if skin_image:
+            detector = self._get_jaundice_detector()
+            result.jaundice_result = detector.detect(skin_image)
+
+            # Add zone analysis
+            zone_info = detector.analyze_kramer_zones(skin_image)
+            result.jaundice_result["zone_analysis"] = zone_info
+
+            if result.jaundice_result["severity"] == "critical":
+                result.priority_actions.insert(0, "CRITICAL: Immediate phototherapy required")
+                result.referral_needed = True
+                risk_scores.append(1.0)
+            elif result.jaundice_result["severity"] == "severe":
+                result.priority_actions.append("URGENT: Start phototherapy")
+                result.referral_needed = True
+                risk_scores.append(0.8)
+            elif result.jaundice_result["severity"] == "moderate":
+                result.priority_actions.append("Monitor closely, recheck in 12-24 hours")
+                risk_scores.append(0.5)
+            else:
+                risk_scores.append(0.2)
+
+        # Cry analysis
+        if cry_audio:
+            analyzer = self._get_cry_analyzer()
+            result.cry_result = analyzer.analyze(cry_audio)
+
+            if result.cry_result["risk_level"] == "high":
+                result.priority_actions.insert(0, "URGENT: Abnormal cry - assess for birth asphyxia")
+                result.referral_needed = True
+                risk_scores.append(1.0)
+            elif result.cry_result["risk_level"] == "medium":
+                result.priority_actions.append("Monitor cry patterns, reassess in 30 minutes")
+                risk_scores.append(0.5)
+            else:
+                risk_scores.append(0.2)
+
+        # Determine overall risk
+        if risk_scores:
+            max_risk = max(risk_scores)
+            if max_risk >= 0.8:
+                result.overall_risk = "high"
+            elif max_risk >= 0.5:
+                result.overall_risk = "medium"
+            else:
+                result.overall_risk = "low"
+
+        return result
+
+    def agentic_assessment(
+        self,
+        patient_type: str = "newborn",
+        conjunctiva_image: Optional[Union[str, Path]] = None,
+        skin_image: Optional[Union[str, Path]] = None,
+        cry_audio: Optional[Union[str, Path]] = None,
+        danger_signs: Optional[List[Dict]] = None,
+        patient_info: Optional[Dict] = None,
+    ) -> Dict:
+        """
+        Run the full agentic clinical workflow with 6 specialized agents.
+
+        This provides richer output than full_assessment(): each agent emits
+        step-by-step reasoning traces forming a complete audit trail.
+
+        Args:
+            patient_type: "pregnant" or "newborn"
+            conjunctiva_image: Path to conjunctiva image for anemia screening
+            skin_image: Path to skin image for jaundice detection
+            cry_audio: Path to cry audio for asphyxia detection
+            danger_signs: List of danger sign dicts with keys: id, label, severity, present
+            patient_info: Patient information dict
+
+        Returns:
+            Dict with workflow result including agent_traces list
+        """
+        from .agentic_workflow import (
+            AgenticWorkflowEngine,
+            AgentPatientInfo,
+            DangerSign,
+            WorkflowInput,
+        )
+
+        # Build patient info
+        info = AgentPatientInfo(patient_type=patient_type)
+        if patient_info:
+            info.patient_id = patient_info.get("patient_id", "")
+            info.gestational_weeks = patient_info.get("gestational_weeks")
+            info.birth_weight = patient_info.get("birth_weight")
+            info.apgar_score = patient_info.get("apgar_score")
+            info.age_hours = patient_info.get("age_hours")
+
+        # Build danger signs
+        signs = []
+        if danger_signs:
+            for s in danger_signs:
+                signs.append(DangerSign(
+                    id=s.get("id", ""),
+                    label=s.get("label", ""),
+                    severity=s.get("severity", "medium"),
+                    present=s.get("present", True),
+                ))
+
+        workflow_input = WorkflowInput(
+            patient_type=patient_type,
+            patient_info=info,
+            danger_signs=signs,
+            conjunctiva_image=conjunctiva_image,
+            skin_image=skin_image,
+            cry_audio=cry_audio,
+        )
+
+        # Create engine with existing model instances to avoid reloading
+        engine = AgenticWorkflowEngine(
+            anemia_detector=self._anemia_detector,
+            jaundice_detector=self._jaundice_detector,
+            cry_analyzer=self._cry_analyzer,
+        )
+
+        result = engine.execute(workflow_input)
+
+        # Serialize to dict
+        return {
+            "success": result.success,
+            "patient_type": result.patient_type,
+            "who_classification": result.who_classification,
+            "clinical_synthesis": result.clinical_synthesis,
+            "recommendation": result.recommendation,
+            "immediate_actions": result.immediate_actions,
+            "processing_time_ms": result.processing_time_ms,
+            "timestamp": result.timestamp,
+            "triage": {
+                "risk_level": result.triage_result.risk_level,
+                "score": result.triage_result.score,
+                "critical_signs": result.triage_result.critical_signs,
+                "immediate_referral": result.triage_result.immediate_referral_needed,
+            } if result.triage_result else None,
+            "referral": {
+                "referral_needed": result.referral_result.referral_needed,
+                "urgency": result.referral_result.urgency,
+                "facility_level": result.referral_result.facility_level,
+                "reason": result.referral_result.reason,
+                "timeframe": result.referral_result.timeframe,
+            } if result.referral_result else None,
+            "protocol": {
+                "classification": result.protocol_result.classification,
+                "applicable_protocols": result.protocol_result.applicable_protocols,
+                "treatment_recommendations": result.protocol_result.treatment_recommendations,
+                "follow_up_schedule": result.protocol_result.follow_up_schedule,
+            } if result.protocol_result else None,
+            "agent_traces": [
+                {
+                    "agent_name": t.agent_name,
+                    "status": t.status,
+                    "reasoning": t.reasoning,
+                    "findings": t.findings,
+                    "confidence": t.confidence,
+                    "processing_time_ms": t.processing_time_ms,
+                }
+                for t in result.agent_traces
+            ],
+        }
+
+    def full_assessment(
+        self,
+        patient: PatientInfo,
+        conjunctiva_image: Optional[Union[str, Path]] = None,
+        skin_image: Optional[Union[str, Path]] = None,
+        cry_audio: Optional[Union[str, Path]] = None,
+    ) -> AssessmentResult:
+        """
+        Perform full assessment (maternal or neonatal based on patient info).
+
+        Args:
+            patient: Patient information
+            conjunctiva_image: For maternal anemia screening
+            skin_image: For neonatal jaundice detection
+            cry_audio: For neonatal cry analysis
+
+        Returns:
+            Complete AssessmentResult
+        """
+        if patient.is_maternal:
+            return self.assess_maternal(patient, conjunctiva_image)
+        else:
+            return self.assess_neonate(patient, skin_image, cry_audio)
+
+    def generate_report(self, result: AssessmentResult) -> str:
+        """
+        Generate a text report from assessment result.
+
+        Args:
+            result: AssessmentResult from assessment
+
+        Returns:
+            Formatted report string
+        """
+        lines = [
+            "=" * 60,
+            "NEXUS HEALTH ASSESSMENT REPORT",
+            "=" * 60,
+            "",
+            f"Patient ID: {result.patient.patient_id}",
+            f"Assessment Time: {result.timestamp}",
+            f"Patient Type: {'Maternal' if result.patient.is_maternal else 'Neonatal'}",
+            "",
+        ]
+
+        if result.patient.age_days is not None:
+            lines.append(f"Age: {result.patient.age_days} days")
+        if result.patient.gestational_age is not None:
+            lines.append(f"Gestational Age: {result.patient.gestational_age} weeks")
+        if result.patient.birth_weight is not None:
+            lines.append(f"Birth Weight: {result.patient.birth_weight} grams")
+
+        lines.extend(["", "-" * 60, "FINDINGS", "-" * 60, ""])
+
+        # Anemia findings
+        if result.anemia_result:
+            lines.extend([
+                "ANEMIA SCREENING:",
+                f"  Status: {'ANEMIC' if result.anemia_result['is_anemic'] else 'Normal'}",
+                f"  Confidence: {result.anemia_result['confidence']:.1%}",
+                f"  Risk Level: {result.anemia_result['risk_level'].upper()}",
+                "",
+            ])
+
+        # Jaundice findings
+        if result.jaundice_result:
+            lines.extend([
+                "JAUNDICE ASSESSMENT:",
+                f"  Status: {'JAUNDICE DETECTED' if result.jaundice_result['has_jaundice'] else 'Normal'}",
+                f"  Estimated Bilirubin: {result.jaundice_result['estimated_bilirubin']} mg/dL",
+                f"  Severity: {result.jaundice_result['severity'].upper()}",
+                f"  Phototherapy Needed: {'YES' if result.jaundice_result['needs_phototherapy'] else 'No'}",
+                "",
+            ])
+
+        # Cry analysis findings
+        if result.cry_result:
+            lines.extend([
+                "CRY ANALYSIS:",
+                f"  Status: {'ABNORMAL' if result.cry_result['is_abnormal'] else 'Normal'}",
+                f"  Asphyxia Risk: {result.cry_result['asphyxia_risk']:.1%}",
+                f"  Cry Type: {result.cry_result['cry_type']}",
+                f"  Risk Level: {result.cry_result['risk_level'].upper()}",
+                "",
+            ])
+
+        lines.extend(["-" * 60, "OVERALL ASSESSMENT", "-" * 60, ""])
+        lines.append(f"Overall Risk Level: {result.overall_risk.upper()}")
+        lines.append(f"Referral Needed: {'YES' if result.referral_needed else 'No'}")
+
+        if result.priority_actions:
+            lines.extend(["", "PRIORITY ACTIONS:"])
+            for i, action in enumerate(result.priority_actions, 1):
+                lines.append(f"  {i}. {action}")
+
+        lines.extend(["", "=" * 60])
+
+        return "\n".join(lines)
+
+    def to_json(self, result: AssessmentResult) -> str:
+        """Convert assessment result to JSON string."""
+        data = {
+            "patient": {
+                "patient_id": result.patient.patient_id,
+                "age_days": result.patient.age_days,
+                "gestational_age": result.patient.gestational_age,
+                "birth_weight": result.patient.birth_weight,
+                "gender": result.patient.gender,
+                "is_maternal": result.patient.is_maternal,
+            },
+            "timestamp": result.timestamp,
+            "anemia_result": result.anemia_result,
+            "jaundice_result": result.jaundice_result,
+            "cry_result": result.cry_result,
+            "overall_risk": result.overall_risk,
+            "priority_actions": result.priority_actions,
+            "referral_needed": result.referral_needed,
+        }
+        return json.dumps(data, indent=2)
+
+
+def demo():
+    """Demo the NEXUS pipeline."""
+    print("NEXUS Pipeline Demo")
+    print("=" * 60)
+
+    # Initialize pipeline
+    pipeline = NEXUSPipeline(lazy_load=True)
+
+    # Demo maternal assessment
+    print("\n--- Maternal Assessment Demo ---")
+    maternal_patient = PatientInfo(
+        patient_id="M001",
+        is_maternal=True,
+    )
+
+    data_dir = Path(__file__).parent.parent.parent / "data" / "raw"
+    anemia_images = list((data_dir / "eyes-defy-anemia").rglob("*.jpg"))[:1]
+
+    if anemia_images:
+        result = pipeline.assess_maternal(maternal_patient, anemia_images[0])
+        print(pipeline.generate_report(result))
+
+    # Demo neonatal assessment
+    print("\n--- Neonatal Assessment Demo ---")
+    neonatal_patient = PatientInfo(
+        patient_id="N001",
+        age_days=3,
+        gestational_age=38,
+        birth_weight=3200,
+        gender="M",
+        is_maternal=False,
+    )
+
+    jaundice_images = list((data_dir / "neojaundice" / "images").glob("*.jpg"))[:1]
+    cry_files = list((data_dir / "donate-a-cry").rglob("*.wav"))[:1]
+
+    skin_image = jaundice_images[0] if jaundice_images else None
+    cry_audio = cry_files[0] if cry_files else None
+
+    if skin_image or cry_audio:
+        result = pipeline.assess_neonate(neonatal_patient, skin_image, cry_audio)
+        print(pipeline.generate_report(result))
+
+
+if __name__ == "__main__":
+    demo()
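The overall-risk logic in `assess_neonate` reduces the per-modality risk scores to a single level by taking their maximum, so the worst finding always dominates. A standalone sketch of that reduction (the function name `aggregate_risk` is illustrative, not part of this commit):

```python
from typing import List

def aggregate_risk(risk_scores: List[float]) -> str:
    """Reduce per-modality risk scores to one level, mirroring
    NEXUSPipeline.assess_neonate: the worst finding dominates."""
    if not risk_scores:
        return "unknown"
    max_risk = max(risk_scores)
    if max_risk >= 0.8:
        return "high"
    if max_risk >= 0.5:
        return "medium"
    return "low"

# A critical jaundice finding (1.0) outweighs a normal cry (0.2).
print(aggregate_risk([1.0, 0.2]))  # → high
```

Using max rather than an average means a single critical modality triggers referral even when every other modality looks normal, which matches the priority-action inserts in the pipeline code above.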