Ashkan Taghipour (The University of Western Australia) Claude Opus 4.5 committed on
Commit
cd846d7
·
1 Parent(s): 01070c9

Initial HeartWatch AI demo release

Browse files

- Added app.py: Gradio interface with upload and sample gallery tabs
- Added inference.py: DeepECG inference engine with 4 models
- Added visualization.py: ECG waveform, diagnosis bars, risk gauges
- Added class_names.json: 77 ECG diagnosis class names
- Added sample ECG data from MIT-BIH database
- Added requirements.txt and .gitignore

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ ENV/
26
+ env/
27
+ .venv/
28
+
29
+ # IDE
30
+ .idea/
31
+ .vscode/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # Jupyter
37
+ .ipynb_checkpoints/
38
+
39
+ # Model weights (downloaded at runtime)
40
+ weights/
41
+ *.pt
42
+ *.pth
43
+ *.ckpt
44
+
45
+ # Local development
46
+ .env
47
+ .env.local
48
+
49
+ # OS
50
+ .DS_Store
51
+ Thumbs.db
52
+
53
+ # Logs
54
+ *.log
55
+ logs/
56
+
57
+ # Temporary files
58
+ *.tmp
59
+ *.temp
README.md CHANGED
@@ -1,13 +1,43 @@
1
  ---
2
  title: HeartWatchAI
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
- short_description: AI-based ECG heart Analysis
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: HeartWatchAI
3
+ emoji: ❤️
4
+ colorFrom: red
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.12.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: AI-powered 12-Lead ECG Analysis
11
+ hf_oauth: false
12
  ---
13
 
14
+ # HeartWatch AI
15
+
16
+ AI-powered 12-Lead ECG analysis using deep learning models.
17
+
18
+ ## Features
19
+
20
+ - **77-Class ECG Diagnosis**: Detect 77 different cardiac conditions
21
+ - **LVEF Prediction**: Predict left ventricular ejection fraction < 40% and < 50%
22
+ - **AFib Risk**: 5-year atrial fibrillation risk prediction
23
+ - **Interactive Visualization**: Clinical 4x3 lead layout with ECG paper grid
24
+
25
+ ## Models
26
+
27
+ This demo uses EfficientNetV2 models from the DeepECG project:
28
+
29
+ - `heartwise/EfficientNetV2_77_Classes`
30
+ - `heartwise/EfficientNetV2_LVEF_40`
31
+ - `heartwise/EfficientNetV2_LVEF_50`
32
+ - `heartwise/EfficientNetV2_AFIB_5y`
33
+
34
+ ## Input Format
35
+
36
+ - NumPy array (.npy file)
37
+ - Shape: (2500, 12) or (12, 2500)
38
+ - 12 standard leads: I, II, III, aVR, aVL, aVF, V1-V6
39
+ - 10 seconds at 250 Hz sampling rate
40
+
41
+ ## Disclaimer
42
+
43
+ This is a research demonstration tool. Predictions should NOT be used for clinical decision-making. Always consult qualified healthcare professionals for medical advice.
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HeartWatch AI - ECG Analysis Demo
3
+ ==================================
4
+
5
+ A Gradio-based web application for AI-powered ECG analysis using DeepECG models.
6
+
7
+ Features:
8
+ - 77-class ECG diagnosis
9
+ - LVEF < 40% prediction
10
+ - LVEF < 50% prediction
11
+ - 5-year AFib risk assessment
12
+ - Interactive 12-lead ECG visualization
13
+ """
14
+
15
+ import os
16
+ import logging
17
+ import numpy as np
18
+ import gradio as gr
19
+ from pathlib import Path
20
+
21
+ # Local imports
22
+ from inference import DeepECGInference
23
+ from visualization import (
24
+ plot_ecg_waveform,
25
+ plot_diagnosis_bars,
26
+ plot_risk_gauges,
27
+ generate_thumbnail
28
+ )
29
+
30
+ # Configure logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Global inference engine
35
+ inference_engine = None
36
+
37
+
38
def load_inference_engine():
    """Lazily create and cache the module-level DeepECG inference engine.

    Returns:
        The shared DeepECGInference instance with its models loaded.
    """
    global inference_engine
    # Guard clause: reuse the already-initialized engine if present.
    if inference_engine is not None:
        return inference_engine
    logger.info("Loading DeepECG inference engine...")
    engine = DeepECGInference()
    engine.load_models()
    inference_engine = engine
    logger.info("Inference engine loaded successfully")
    return inference_engine
47
+
48
+
49
def get_sample_ecgs():
    """Return metadata for bundled sample ECGs in demo_data/samples.

    Returns:
        List of dicts with "path" (str path to the .npy file) and "name"
        (display title derived from the file stem). Empty if the sample
        directory does not exist.
    """
    sample_dir = Path(__file__).parent / "demo_data" / "samples"
    if not sample_dir.exists():
        return []

    return [
        {
            "path": str(npy_file),
            "name": npy_file.stem.replace("_", " ").title(),
        }
        for npy_file in sorted(sample_dir.glob("*.npy"))
    ]
62
+
63
+
64
def analyze_ecg(ecg_signal: np.ndarray, filename: str = "Uploaded ECG"):
    """
    Analyze an ECG signal and return all visualizations.

    Args:
        ecg_signal: ECG signal array
        filename: Name to display

    Returns:
        Tuple of (ecg_plot, diagnosis_plot, risk_plot, summary_text)
    """
    engine = load_inference_engine()

    # Run all four models on the signal.
    results = engine.predict(ecg_signal)

    # 12-lead waveform figure.
    ecg_fig = plot_ecg_waveform(ecg_signal, sample_rate=250, title=filename)

    # Diagnosis bar chart (only when the 77-class model produced output).
    diag = results.get("diagnosis_77")
    if diag is not None:
        diagnosis_fig = plot_diagnosis_bars(
            dict(zip(diag["class_names"], diag["probabilities"])), top_n=10
        )
    else:
        diagnosis_fig = None

    # Risk gauges; missing model outputs default to 0.0.
    lvef_40 = results.get("lvef_40", 0.0)
    lvef_50 = results.get("lvef_50", 0.0)
    afib_5y = results.get("afib_5y", 0.0)
    risk_fig = plot_risk_gauges(lvef_40, lvef_50, afib_5y)

    # Markdown summary shown next to the plots.
    inference_time = results.get("inference_time_ms", 0)
    summary = f"""## Analysis Summary

**Inference Time:** {inference_time:.1f} ms

### Risk Predictions
- **LVEF < 40%:** {lvef_40*100:.1f}% probability
- **LVEF < 50%:** {lvef_50*100:.1f}% probability
- **5-year AFib Risk:** {afib_5y*100:.1f}% probability

### Top Diagnoses
"""
    if diag is not None:
        probs = diag["probabilities"]
        class_names = diag["class_names"]
        for rank, idx in enumerate(np.argsort(probs)[::-1][:5], 1):
            summary += f"{rank}. {class_names[idx]}: {probs[idx]*100:.1f}%\n"

    return ecg_fig, diagnosis_fig, risk_fig, summary
119
+
120
+
121
def analyze_uploaded_file(file):
    """Handle an uploaded .npy ECG file from the Gradio upload tab.

    Bug fix: the file input is declared with ``gr.File(type="filepath")``,
    which makes Gradio pass a plain ``str`` path to this handler — the
    original ``file.name`` access raised ``AttributeError`` for every
    upload. We now accept both a string path and a tempfile-like object
    exposing ``.name``.

    Args:
        file: str path (type="filepath") or an object with a ``.name``
            attribute, or None when nothing was uploaded.

    Returns:
        Tuple of (ecg_plot, diagnosis_plot, risk_plot, summary_text);
        the three plots are None and the text is an error message when
        the input is missing or cannot be loaded.
    """
    if file is None:
        return None, None, None, "Please upload a .npy file containing ECG data."

    try:
        # Support both Gradio payload shapes; keep inside the try so an
        # unexpected payload type is reported instead of crashing the UI.
        path = file if isinstance(file, str) else file.name
        ecg_signal = np.load(path)
        filename = Path(path).stem
        return analyze_ecg(ecg_signal, filename)
    except Exception as e:
        logger.error(f"Error loading file: {e}")
        return None, None, None, f"Error loading file: {str(e)}"
133
+
134
+
135
def analyze_sample(sample_name: str):
    """Analyze a sample ECG from the gallery.

    Args:
        sample_name: Display name of the sample as shown in the dropdown.

    Returns:
        The analyze_ecg() result tuple, or a Nones-plus-message tuple
        when no sample matches the given name.
    """
    matches = [s for s in get_sample_ecgs() if s["name"] == sample_name]
    if not matches:
        return None, None, None, "Sample not found."

    sample = matches[0]
    return analyze_ecg(np.load(sample["path"]), sample["name"])
145
+
146
+
147
def create_demo_interface():
    """Create the Gradio interface.

    Builds a three-tab Blocks app (Upload ECG / Sample Gallery / About)
    with a styled header and footer.

    Returns:
        The assembled gr.Blocks instance (not yet launched).
    """

    # Custom CSS for styling the header card and base font.
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .main-header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .main-header h1 {
        margin: 0;
        font-size: 2.5em;
    }
    .main-header p {
        margin: 10px 0 0 0;
        opacity: 0.9;
    }
    """

    with gr.Blocks(css=custom_css, title="HeartWatch AI") as demo:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>HeartWatch AI</h1>
            <p>AI-Powered 12-Lead ECG Analysis</p>
        </div>
        """)

        gr.Markdown("""
        This demo analyzes 12-lead ECG signals using deep learning models trained on large clinical datasets.

        **Models:**
        - 77-class ECG diagnosis classifier
        - LVEF < 40% prediction
        - LVEF < 50% prediction
        - 5-year Atrial Fibrillation risk

        **Note:** This is a research demo. Results should not be used for clinical decision-making.
        """)

        with gr.Tabs():
            # Tab 1: Upload ECG
            with gr.TabItem("Upload ECG"):
                with gr.Row():
                    with gr.Column(scale=1):
                        # NOTE(review): with type="filepath" Gradio passes a
                        # plain str path to the click handler.
                        file_input = gr.File(
                            label="Upload ECG (.npy file)",
                            file_types=[".npy"],
                            type="filepath"
                        )
                        analyze_btn = gr.Button("Analyze ECG", variant="primary")

                        gr.Markdown("""
                        **Expected Format:**
                        - NumPy array shape: (2500, 12) or (12, 2500)
                        - 12 leads: I, II, III, aVR, aVL, aVF, V1-V6
                        - 2500 samples (10 seconds at 250 Hz)
                        """)

                    with gr.Column(scale=2):
                        summary_output = gr.Markdown(label="Summary")

                with gr.Row():
                    ecg_plot = gr.Plot(label="12-Lead ECG")

                with gr.Row():
                    with gr.Column():
                        diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
                    with gr.Column():
                        risk_plot = gr.Plot(label="Risk Assessment")

                # Wire the button to the upload handler; outputs map 1:1 to
                # the (ecg, diagnosis, risk, summary) tuple it returns.
                analyze_btn.click(
                    fn=analyze_uploaded_file,
                    inputs=[file_input],
                    outputs=[ecg_plot, diagnosis_plot, risk_plot, summary_output]
                )

            # Tab 2: Sample Gallery
            with gr.TabItem("Sample Gallery"):
                gr.Markdown("### Select a sample ECG to analyze")

                # Samples are discovered once at interface-build time; files
                # added later won't appear until the app restarts.
                samples = get_sample_ecgs()
                if samples:
                    sample_names = [s["name"] for s in samples]
                    sample_dropdown = gr.Dropdown(
                        choices=sample_names,
                        label="Select Sample",
                        value=sample_names[0] if sample_names else None
                    )
                    analyze_sample_btn = gr.Button("Analyze Sample", variant="primary")

                    with gr.Row():
                        sample_summary = gr.Markdown(label="Summary")

                    with gr.Row():
                        sample_ecg_plot = gr.Plot(label="12-Lead ECG")

                    with gr.Row():
                        with gr.Column():
                            sample_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
                        with gr.Column():
                            sample_risk_plot = gr.Plot(label="Risk Assessment")

                    analyze_sample_btn.click(
                        fn=analyze_sample,
                        inputs=[sample_dropdown],
                        outputs=[sample_ecg_plot, sample_diagnosis_plot, sample_risk_plot, sample_summary]
                    )
                else:
                    gr.Markdown("*No sample ECGs available. Upload your own in the Upload tab.*")

            # Tab 3: About — static documentation, no event handlers.
            with gr.TabItem("About"):
                gr.Markdown("""
                ## About HeartWatch AI

                HeartWatch AI is a deep learning-based ECG analysis system that can:

                ### Models

                1. **77-Class Diagnosis Model**
                   - Trained to detect 77 different ECG patterns and conditions
                   - Based on EfficientNetV2 architecture
                   - Outputs probability for each condition

                2. **LVEF Prediction Models**
                   - LVEF < 40%: Identifies patients with reduced ejection fraction
                   - LVEF < 50%: Identifies patients with moderately reduced ejection fraction

                3. **AFib Risk Model**
                   - Predicts 5-year risk of developing Atrial Fibrillation

                ### Technical Details

                - **Input:** 12-lead ECG, 10 seconds, 250 Hz sampling rate
                - **Architecture:** EfficientNetV2 (TorchScript optimized)
                - **Inference:** CPU-optimized for accessibility

                ### Disclaimer

                This is a research demonstration tool. The predictions provided should NOT be used
                for clinical decision-making. Always consult qualified healthcare professionals
                for medical advice and diagnosis.

                ### Contact

                For questions or issues, please visit our GitHub repository.
                """)

        # Footer
        gr.Markdown("""
        ---
        *Built with Gradio and PyTorch. Models from DeepECG project.*
        """)

    return demo
310
+
311
+
312
# Create and launch the demo
if __name__ == "__main__":
    # Pre-load the inference engine so the first request doesn't pay the
    # model-download/load cost. Failure here is non-fatal: the engine is
    # created lazily again on the first analysis request.
    try:
        load_inference_engine()
    except Exception as e:
        logger.warning(f"Could not pre-load models: {e}")

    # Create and launch demo. 0.0.0.0:7860 is the bind address/port
    # expected by HuggingFace Spaces.
    demo = create_demo_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
class_names.json ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Sinusal",
3
+ "Regular",
4
+ "Monomorph",
5
+ "QS complex in V1-V2-V3",
6
+ "R complex in V5-V6",
7
+ "T wave inversion (inferior - II, III, aVF)",
8
+ "Left bundle branch block",
9
+ "RaVL > 11 mm",
10
+ "SV1 + RV5 or RV6 > 35 mm",
11
+ "T wave inversion (lateral -I, aVL, V5-V6)",
12
+ "T wave inversion (anterior - V3-V4)",
13
+ "Left axis deviation",
14
+ "Left ventricular hypertrophy",
15
+ "Bradycardia",
16
+ "Q wave (inferior - II, III, aVF)",
17
+ "Afib",
18
+ "Irregularly irregular",
19
+ "Atrial tachycardia (>= 100 BPM)",
20
+ "Nonspecific intraventricular conduction delay",
21
+ "Premature ventricular complex",
22
+ "Polymorph",
23
+ "T wave inversion (septal- V1-V2)",
24
+ "Right bundle branch block",
25
+ "Ventricular paced",
26
+ "ST elevation (anterior - V3-V4)",
27
+ "ST elevation (septal - V1-V2)",
28
+ "1st degree AV block",
29
+ "Premature atrial complex",
30
+ "Atrial flutter",
31
+ "rSR' in V1-V2",
32
+ "qRS in V5-V6-I, aVL",
33
+ "Left anterior fascicular block",
34
+ "Right axis deviation",
35
+ "2nd degree AV block - mobitz 1",
36
+ "ST depression (inferior - II, III, aVF)",
37
+ "Acute pericarditis",
38
+ "ST elevation (inferior - II, III, aVF)",
39
+ "Low voltage",
40
+ "Regularly irregular",
41
+ "Junctional rhythm",
42
+ "Left atrial enlargement",
43
+ "ST elevation (lateral - I, aVL, V5-V6)",
44
+ "Atrial paced",
45
+ "Right ventricular hypertrophy",
46
+ "Delta wave",
47
+ "Wolff-Parkinson-White (Pre-excitation syndrome)",
48
+ "Prolonged QT",
49
+ "ST depression (anterior - V3-V4)",
50
+ "QRS complex negative in III",
51
+ "Q wave (lateral- I, aVL, V5-V6)",
52
+ "Supraventricular tachycardia",
53
+ "ST downslopping",
54
+ "ST depression (lateral - I, avL, V5-V6)",
55
+ "2nd degree AV block - mobitz 2",
56
+ "U wave",
57
+ "R/S ratio in V1-V2 >1",
58
+ "RV1 + SV6 > 11 mm",
59
+ "Left posterior fascicular block",
60
+ "Right atrial enlargement",
61
+ "ST depression (septal- V1-V2)",
62
+ "Q wave (septal- V1-V2)",
63
+ "Q wave (anterior - V3-V4)",
64
+ "ST upslopping",
65
+ "Right superior axis",
66
+ "Ventricular tachycardia",
67
+ "ST elevation (posterior - V7-V8-V9)",
68
+ "Ectopic atrial rhythm (< 100 BPM)",
69
+ "Lead misplacement",
70
+ "Third Degree AV Block",
71
+ "Acute MI",
72
+ "Early repolarization",
73
+ "Q wave (posterior - V7-V9)",
74
+ "Bi-atrial enlargement",
75
+ "LV pacing",
76
+ "Brugada",
77
+ "Ventricular Rhythm",
78
+ "no_qrs"
79
+ ]
demo_data/samples/Atrial_Flutter.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8704d5fb100d5a9080640587aef6cb69c725b82b4dbed4ecf72510f67acd92c9
3
+ size 120128
demo_data/samples/Normal_Sinus_Rhythm.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a1191aa4d9a7dc79b67fcd2eac378b86456d7b1ed035a443819ed88b0c7065f
3
+ size 120128
demo_data/samples/Ventricular_Tachycardia.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a43e7555616d2601daf958140b9770a69522cce0126ed1402f546829646263dc
3
+ size 120128
inference.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepECG Inference Module for HeartWatch AI
3
+ ===========================================
4
+
5
+ This module provides CPU-optimized inference for 4 EfficientNetV2 models:
6
+ - 77-class ECG diagnosis
7
+ - LVEF <= 40% prediction
8
+ - LVEF < 50% prediction
9
+ - 5-year AFib risk prediction
10
+
11
+ The preprocessing exactly replicates DeepECG's pipeline:
12
+ 1. Load signal as (samples, leads) = (2500, 12)
13
+ 2. Transpose to (leads, samples) = (12, 2500)
14
+ 3. Apply MHI factor scaling: signal *= (1/0.0048)
15
+ 4. Apply sigmoid to model logits
16
+
17
+ Models are downloaded from HuggingFace Hub using HF_TOKEN from environment.
18
+ """
19
+
20
+ import os
21
+ import json
22
+ import time
23
+ import logging
24
+ from typing import Dict, Optional, Any, Union
25
+ from pathlib import Path
26
+
27
+ import numpy as np
28
+ import torch
29
+ from huggingface_hub import snapshot_download
30
+
31
+ # Configure logging
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger = logging.getLogger(__name__)
34
+
35
+ # CPU optimizations for HuggingFace Spaces (no GPU)
36
+ torch.set_num_threads(2)
37
+ torch.set_flush_denormal(True)
38
+
39
+
40
class DeepECGInference:
    """
    CPU-optimized inference engine for DeepECG EfficientNetV2 models.

    Loads 4 models from HuggingFace Hub:
    - heartwise/EfficientNetV2_77_Classes: 77-class ECG diagnosis
    - heartwise/EfficientNetV2_LVEF_40: LVEF <= 40% prediction
    - heartwise/EfficientNetV2_LVEF_50: LVEF < 50% prediction
    - heartwise/EfficientNetV2_AFIB_5y: 5-year AFib risk prediction

    Attributes:
        device: Always CPU for HF Spaces
        models: Dict containing loaded TorchScript models
        class_names: List of 77 ECG diagnosis class names
        mhi_factor: Scaling factor for signal preprocessing (1/0.0048)
    """

    # Model repository mappings: internal key -> HuggingFace repo id.
    MODEL_REPOS = {
        "diagnosis_77": "heartwise/EfficientNetV2_77_Classes",
        "lvef_40": "heartwise/EfficientNetV2_LVEF_40",
        "lvef_50": "heartwise/EfficientNetV2_LVEF_50",
        "afib_5y": "heartwise/EfficientNetV2_AFIB_5y",
    }

    # Expected input specifications
    EXPECTED_LEADS = 12
    EXPECTED_SAMPLES = 2500  # 10 seconds at 250 Hz
    SAMPLING_RATE = 250  # Hz

    # Preprocessing constants from DeepECG
    MHI_FACTOR = 1 / 0.0048  # ~208.33

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the inference engine.

        Class names are loaded eagerly; models are NOT loaded here —
        call load_models() before predict().

        Args:
            cache_dir: Directory to cache downloaded models.
                Defaults to ./weights
        """
        self.device = torch.device("cpu")
        self.cache_dir = cache_dir or os.path.join(os.getcwd(), "weights")
        self.models: Dict[str, torch.jit.ScriptModule] = {}
        self.class_names: list = []
        self._load_class_names()

    def _load_class_names(self) -> None:
        """Load the 77 ECG class names from class_names.json.

        Missing file is tolerated (class_names stays []) so the engine can
        still serve the three binary-risk models.
        """
        class_names_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "class_names.json"
        )
        try:
            with open(class_names_path, "r") as f:
                self.class_names = json.load(f)
            logger.info(f"Loaded {len(self.class_names)} class names")
        except FileNotFoundError:
            logger.warning(f"class_names.json not found at {class_names_path}")
            self.class_names = []

    def _get_hf_token(self) -> Optional[str]:
        """Get HuggingFace token from environment variable (None if unset)."""
        token = os.environ.get("HF_TOKEN")
        if not token:
            logger.warning("HF_TOKEN environment variable not set")
        return token

    def _download_model(self, repo_id: str, model_name: str) -> str:
        """
        Download model from HuggingFace Hub.

        Args:
            repo_id: HuggingFace repository ID
            model_name: Local name for the model

        Returns:
            Path to the downloaded model directory
        """
        local_dir = os.path.join(self.cache_dir, model_name)

        # NOTE(review): a bare existence check means a partially-downloaded
        # directory from an interrupted run is treated as a valid cache —
        # delete the directory manually to force a re-download.
        if os.path.exists(local_dir):
            logger.info(f"Model {model_name} already cached at {local_dir}")
            return local_dir

        logger.info(f"Downloading {repo_id} to {local_dir}")
        os.makedirs(local_dir, exist_ok=True)

        hf_token = self._get_hf_token()
        local_dir = snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="model",
            token=hf_token
        )

        logger.info(f"Downloaded {repo_id} to {local_dir}")
        return local_dir

    def _load_model_from_dir(self, model_dir: str) -> torch.jit.ScriptModule:
        """
        Load a TorchScript model from a directory.

        Picks the first .pt file found (directory order is not guaranteed;
        each repo is expected to contain exactly one).

        Args:
            model_dir: Directory containing the .pt file

        Returns:
            Loaded TorchScript model

        Raises:
            ValueError: If no .pt file is found in the directory
        """
        pt_file = next(
            (f for f in os.listdir(model_dir) if f.endswith('.pt')),
            None
        )
        if not pt_file:
            raise ValueError(f"No .pt file found in {model_dir}")

        model_path = os.path.join(model_dir, pt_file)
        model = torch.jit.load(model_path, map_location=self.device)
        model.eval()

        return model

    def load_models(self) -> None:
        """
        Download and load all 4 models from HuggingFace Hub.

        Uses HF_TOKEN from os.environ for authentication.
        Models are loaded in eval mode on CPU. Any single failure aborts
        the whole load (re-raises), so models is all-or-nothing.
        """
        logger.info("Loading DeepECG models...")

        for model_key, repo_id in self.MODEL_REPOS.items():
            try:
                model_dir = self._download_model(repo_id, model_key)
                self.models[model_key] = self._load_model_from_dir(model_dir)
                logger.info(f"Loaded model: {model_key} from {repo_id}")
            except Exception as e:
                logger.error(f"Failed to load {model_key}: {e}")
                raise

        logger.info(f"Successfully loaded {len(self.models)} models")

    def preprocess_ecg(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor]
    ) -> torch.Tensor:
        """
        Preprocess ECG signal to match DeepECG's exact preprocessing.

        The preprocessing pipeline:
        1. Ensure signal is numpy array with correct shape
        2. Handle shape: expect (samples, leads) = (2500, 12) or (12, 2500)
        3. Transpose to (leads, samples) = (12, 2500) if needed
        4. Convert to float32 tensor
        5. Add batch dimension: (1, 12, 2500)
        6. Apply MHI factor scaling: signal *= (1/0.0048)

        Args:
            ecg_signal: Raw ECG signal, shape (samples, leads) or (leads, samples)
                Expected: 12 leads, 2500 samples (10s at 250Hz)

        Returns:
            Preprocessed tensor ready for model inference, shape (1, 12, 2500)

        Raises:
            ValueError: If signal shape is invalid
        """
        # Convert to numpy if tensor
        if isinstance(ecg_signal, torch.Tensor):
            ecg_signal = ecg_signal.numpy()

        # Ensure float32
        ecg_signal = ecg_signal.astype(np.float32)

        # Handle shape - expect (samples, leads) = (2500, 12) or (12, 2500)
        if ecg_signal.ndim != 2:
            raise ValueError(
                f"Expected 2D signal, got shape {ecg_signal.shape}"
            )

        # Determine orientation and transpose if needed
        # If shape is (samples, leads) = (2500, 12), transpose to (12, 2500)
        # If shape is (12, 2500), it's already correct
        if ecg_signal.shape[0] == self.EXPECTED_SAMPLES and ecg_signal.shape[1] == self.EXPECTED_LEADS:
            # Shape is (2500, 12) -> transpose to (12, 2500)
            ecg_signal = ecg_signal.T
        elif ecg_signal.shape[0] == self.EXPECTED_LEADS and ecg_signal.shape[1] == self.EXPECTED_SAMPLES:
            # Shape is already (12, 2500)
            pass
        else:
            # Try to infer orientation for non-2500-sample inputs: any
            # (N, 12) array is transposed; anything without a 12 axis fails.
            if ecg_signal.shape[1] == self.EXPECTED_LEADS:
                ecg_signal = ecg_signal.T
            elif ecg_signal.shape[0] != self.EXPECTED_LEADS:
                raise ValueError(
                    f"Invalid signal shape {ecg_signal.shape}. "
                    f"Expected (2500, 12) or (12, 2500)"
                )

        # Verify final shape
        if ecg_signal.shape[0] != self.EXPECTED_LEADS:
            raise ValueError(
                f"Signal must have {self.EXPECTED_LEADS} leads, "
                f"got {ecg_signal.shape[0]}"
            )

        # Convert to tensor and add batch dimension
        signal_tensor = torch.from_numpy(ecg_signal).float()
        signal_tensor = signal_tensor.unsqueeze(0)  # (1, 12, samples)

        # Move to device (CPU)
        signal_tensor = signal_tensor.to(self.device)

        # Apply MHI factor scaling (this is done in model __call__ in DeepECG)
        signal_tensor = signal_tensor * self.MHI_FACTOR

        return signal_tensor

    def predict(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor]
    ) -> Dict[str, Any]:
        """
        Run inference on an ECG signal using all 4 models.

        Each model outputs logits; sigmoid converts them to per-class
        (multi-label) or scalar probabilities.

        Args:
            ecg_signal: Raw ECG signal, shape (samples, leads) or (leads, samples)
                Expected: 12 leads, 2500 samples (10s at 250Hz)

        Returns:
            Dictionary containing:
            - diagnosis_77: Dict with 'probabilities' (77 floats) and 'class_names'
            - lvef_40: Probability of LVEF <= 40%
            - lvef_50: Probability of LVEF < 50%
            - afib_5y: Probability of AFib within 5 years
            - inference_time_ms: Total inference time in milliseconds

        Raises:
            RuntimeError: If load_models() has not been called.
        """
        if not self.models:
            raise RuntimeError("Models not loaded. Call load_models() first.")

        start_time = time.time()

        # Preprocess the signal
        signal_tensor = self.preprocess_ecg(ecg_signal)

        results = {}

        with torch.no_grad():
            # 77-class diagnosis
            if "diagnosis_77" in self.models:
                logits = self.models["diagnosis_77"](signal_tensor)
                probs = torch.sigmoid(logits)
                probs_list = probs.squeeze().cpu().numpy().tolist()
                results["diagnosis_77"] = {
                    "probabilities": probs_list,
                    # None when class_names.json was missing at init.
                    "class_names": self.class_names if self.class_names else None,
                }

            # LVEF <= 40%
            if "lvef_40" in self.models:
                logits = self.models["lvef_40"](signal_tensor)
                prob = torch.sigmoid(logits)
                results["lvef_40"] = float(prob.squeeze().cpu().numpy())

            # LVEF < 50%
            if "lvef_50" in self.models:
                logits = self.models["lvef_50"](signal_tensor)
                prob = torch.sigmoid(logits)
                results["lvef_50"] = float(prob.squeeze().cpu().numpy())

            # 5-year AFib risk
            if "afib_5y" in self.models:
                logits = self.models["afib_5y"](signal_tensor)
                prob = torch.sigmoid(logits)
                results["afib_5y"] = float(prob.squeeze().cpu().numpy())

        end_time = time.time()
        results["inference_time_ms"] = (end_time - start_time) * 1000

        return results

    def predict_diagnosis_top_k(
        self,
        ecg_signal: Union[np.ndarray, torch.Tensor],
        k: int = 5
    ) -> Dict[str, Any]:
        """
        Get top-k diagnoses from the 77-class model.

        Runs a full predict() pass, then ranks the 77-class probabilities.

        Args:
            ecg_signal: Raw ECG signal
            k: Number of top predictions to return

        Returns:
            Dictionary with top-k predictions sorted by probability

        Raises:
            RuntimeError: If the 77-class model is not loaded.
        """
        results = self.predict(ecg_signal)

        if "diagnosis_77" not in results:
            raise RuntimeError("77-class diagnosis model not loaded")

        probs = results["diagnosis_77"]["probabilities"]
        # Fall back to synthetic names when class_names.json was missing.
        class_names = results["diagnosis_77"]["class_names"] or [f"Class_{i}" for i in range(77)]

        # Get top-k indices (descending probability)
        top_k_indices = np.argsort(probs)[::-1][:k]

        top_k_predictions = [
            {
                "class_name": class_names[idx],
                "probability": probs[idx],
                "class_index": int(idx)
            }
            for idx in top_k_indices
        ]

        return {
            "top_k_predictions": top_k_predictions,
            "inference_time_ms": results["inference_time_ms"]
        }
+ }
363
+
364
+
365
def get_inference_engine(cache_dir: Optional[str] = None) -> DeepECGInference:
    """
    Factory function: build a DeepECGInference with all models loaded.

    Args:
        cache_dir: Optional directory to cache models

    Returns:
        Initialized DeepECGInference with models loaded
    """
    inference = DeepECGInference(cache_dir=cache_dir)
    inference.load_models()
    return inference
+ return engine
378
+
379
+
380
if __name__ == "__main__":
    # Example usage / testing (smoke test with a random signal; requires
    # network access and an HF_TOKEN to download the model weights).
    print("DeepECG Inference Module")
    print("=" * 50)

    # Create inference engine
    engine = DeepECGInference()

    # Load models (requires HF_TOKEN environment variable)
    try:
        engine.load_models()
        print("Models loaded successfully!")

        # Create dummy signal for testing: (samples, leads) orientation,
        # which preprocess_ecg transposes internally.
        dummy_signal = np.random.randn(2500, 12).astype(np.float32)

        # Run inference
        results = engine.predict(dummy_signal)

        print(f"\nInference time: {results['inference_time_ms']:.2f} ms")
        print(f"LVEF <= 40%: {results['lvef_40']:.4f}")
        print(f"LVEF < 50%: {results['lvef_50']:.4f}")
        print(f"5-year AFib risk: {results['afib_5y']:.4f}")
        print(f"77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} classes")

        # Get top-5 diagnoses
        top_5 = engine.predict_diagnosis_top_k(dummy_signal, k=5)
        print("\nTop 5 diagnoses:")
        for pred in top_5["top_k_predictions"]:
            print(f"  {pred['class_name']}: {pred['probability']:.4f}")

    except Exception as e:
        print(f"Error: {e}")
        print("\nMake sure HF_TOKEN environment variable is set:")
        print("  export HF_TOKEN='your_huggingface_token'")
+ print(" export HF_TOKEN='your_huggingface_token'")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ numpy>=1.21.0
4
+ matplotlib>=3.5.0
5
+ Pillow>=9.0.0
6
+ huggingface_hub>=0.16.0
7
+ scipy>=1.7.0
test_inference.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Test script for DeepECG Inference Module
=========================================

Run this script to verify the inference engine works correctly.

Usage:
    # Set HF_TOKEN environment variable first
    export HF_TOKEN='your_huggingface_token'

    # Run the test
    python test_inference.py

Expected output:
    - Models download from HuggingFace Hub
    - Dummy signal inference completes
    - Results for all 4 models are printed
"""

import os
import sys
import numpy as np

# Ensure HF_TOKEN is set before doing anything else — model download below
# will fail without it, so fail fast with an actionable message.
if not os.environ.get("HF_TOKEN"):
    print("ERROR: HF_TOKEN environment variable not set")
    print("Please run: export HF_TOKEN='your_token'")
    sys.exit(1)

print("=" * 60)
print("DeepECG Inference Test")
print("=" * 60)

# Import the inference module. Done inside try so a missing/broken
# inference.py produces a clear [FAIL] line rather than a raw traceback.
try:
    from inference import DeepECGInference
    print("[OK] Import successful")
except ImportError as e:
    print(f"[FAIL] Import failed: {e}")
    sys.exit(1)

# Create inference engine (constructor is expected to be offline/cheap;
# it also loads the 77-entry class-name list).
try:
    engine = DeepECGInference()
    print(f"[OK] Engine created with {len(engine.class_names)} class names")
except Exception as e:
    print(f"[FAIL] Engine creation failed: {e}")
    sys.exit(1)

# Load models (network access: downloads weights from HuggingFace Hub).
print("\nLoading models from HuggingFace Hub...")
try:
    engine.load_models()
    print(f"[OK] Loaded {len(engine.models)} models")
    for name in engine.models:
        print(f"  - {name}")
except Exception as e:
    print(f"[FAIL] Model loading failed: {e}")
    sys.exit(1)

# Test with dummy signal
print("\nTesting inference with dummy signal...")
try:
    # Create dummy 10-second ECG (2500 samples at 250Hz, 12 leads)
    dummy_signal = np.random.randn(2500, 12).astype(np.float32)

    # Run inference
    results = engine.predict(dummy_signal)

    print(f"[OK] Inference completed in {results['inference_time_ms']:.2f} ms")
    print(f"\nResults:")
    print(f"  - LVEF <= 40%: {results['lvef_40']:.4f}")
    print(f"  - LVEF < 50%: {results['lvef_50']:.4f}")
    print(f"  - 5-year AFib: {results['afib_5y']:.4f}")
    print(f"  - 77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} probabilities")

except Exception as e:
    print(f"[FAIL] Inference failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# Test top-k predictions.
# NOTE(review): reuses `dummy_signal` defined in the previous try block —
# valid in Python (try does not scope), but fragile if sections are reordered.
print("\nTop 5 diagnoses:")
try:
    top_5 = engine.predict_diagnosis_top_k(dummy_signal, k=5)
    for pred in top_5["top_k_predictions"]:
        print(f"  {pred['class_name']}: {pred['probability']:.4f}")
except Exception as e:
    print(f"[FAIL] Top-k prediction failed: {e}")
    sys.exit(1)

# Test preprocessing edge cases: both accepted input orientations must
# normalize to the model's (batch, leads, samples) layout.
print("\nTesting preprocessing with different input shapes...")
try:
    # Test (2500, 12) shape
    signal_1 = np.random.randn(2500, 12).astype(np.float32)
    tensor_1 = engine.preprocess_ecg(signal_1)
    assert tensor_1.shape == (1, 12, 2500), f"Expected (1, 12, 2500), got {tensor_1.shape}"
    print(f"[OK] Shape (2500, 12) -> {tuple(tensor_1.shape)}")

    # Test (12, 2500) shape
    signal_2 = np.random.randn(12, 2500).astype(np.float32)
    tensor_2 = engine.preprocess_ecg(signal_2)
    assert tensor_2.shape == (1, 12, 2500), f"Expected (1, 12, 2500), got {tensor_2.shape}"
    print(f"[OK] Shape (12, 2500) -> {tuple(tensor_2.shape)}")

except Exception as e:
    print(f"[FAIL] Preprocessing test failed: {e}")
    sys.exit(1)

print("\n" + "=" * 60)
print("ALL TESTS PASSED!")
print("=" * 60)
visualization.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HeartWatch AI Visualization Module
3
+
4
+ This module provides visualization functions for ECG analysis including:
5
+ - 12-lead ECG waveform plotting with clinical layout
6
+ - Diagnosis probability bar charts
7
+ - Risk assessment gauges
8
+ - ECG thumbnail generation for galleries
9
+ """
10
+
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ from matplotlib.patches import Wedge
15
+ from PIL import Image
16
+ import io
17
+
18
+
19
# Standard 12-lead ECG names in clinical order.
# Index in this list is the channel index expected in the signal arrays
# passed to plot_ecg_waveform / generate_thumbnail.
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Clinical layout: 4 columns x 3 rows
# Col 1: I, II, III | Col 2: aVR, aVL, aVF | Col 3: V1, V2, V3 | Col 4: V4, V5, V6
# Grid of lead names used by plot_ecg_waveform; each name is resolved back
# to its channel index via LEAD_NAMES.
LEAD_LAYOUT = [
    ['I', 'aVR', 'V1', 'V4'],
    ['II', 'aVL', 'V2', 'V5'],
    ['III', 'aVF', 'V3', 'V6']
]
29
+
30
+
31
def plot_ecg_waveform(ecg_signal: np.ndarray, sample_rate: int = 250,
                      title: str = "12-Lead ECG") -> plt.Figure:
    """
    Plot a 12-lead ECG waveform in clinical layout format.

    Each of the 4 columns shows a *consecutive* 2.5-second strip of the
    recording (columns 1-4 cover 0-2.5 s, 2.5-5 s, 5-7.5 s, 7.5-10 s),
    matching the standard clinical 12-lead printout. A 10 s recording at
    250 Hz (2500 samples) therefore fills the grid exactly.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12).
        Each row/column represents one of the 12 standard leads.
    sample_rate : int, optional
        Sampling rate in Hz, default 250.
    title : str, optional
        Figure title, default "12-Lead ECG".

    Returns
    -------
    plt.Figure
        Matplotlib figure with the 3x4 ECG layout.

    Raises
    ------
    ValueError
        If neither axis of ``ecg_signal`` has length 12.
    """
    # Ensure correct shape (12, n_samples)
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    n_samples = ecg_signal.shape[1]

    # 2.5 seconds per column
    samples_per_col = int(2.5 * sample_rate)

    # Create figure with clinical dimensions
    fig, axes = plt.subplots(3, 4, figsize=(14, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    # Create lead index mapping
    lead_to_idx = {name: i for i, name in enumerate(LEAD_NAMES)}

    for row in range(3):
        for col in range(4):
            ax = axes[row, col]
            lead_name = LEAD_LAYOUT[row][col]
            lead_idx = lead_to_idx[lead_name]

            # Get the consecutive 2.5 s strip for this column.
            # BUGFIX: previously every column started at sample 0, so all
            # four columns showed the same first 2.5 s of the recording.
            start_sample = col * samples_per_col
            end_sample = min(start_sample + samples_per_col, n_samples)

            if start_sample < end_sample:
                signal_segment = ecg_signal[lead_idx, start_sample:end_sample]
            else:
                # Recording is shorter than this column's window: draw a
                # flat placeholder so the grid still renders.
                signal_segment = np.zeros(1, dtype=ecg_signal.dtype)
            time_segment = np.arange(len(signal_segment)) / sample_rate

            # Set up ECG paper grid background (pink/red)
            ax.set_facecolor('#fff5f5')

            # Major grid (0.5 sec, 0.5 mV equivalent)
            ax.set_axisbelow(True)
            ax.grid(True, which='major', color='#ffcccc', linewidth=0.8, linestyle='-')
            ax.grid(True, which='minor', color='#ffe6e6', linewidth=0.4, linestyle='-')

            # Set tick spacing for major/minor grids
            ax.set_xticks(np.arange(0, 2.6, 0.5))
            ax.set_xticks(np.arange(0, 2.6, 0.1), minor=True)

            # Calculate y-limits based on signal range
            signal_min, signal_max = signal_segment.min(), signal_segment.max()
            signal_range = signal_max - signal_min
            if signal_range < 0.1:
                signal_range = 2.0  # Default range if signal is flat
            padding = signal_range * 0.1
            y_min = signal_min - padding
            y_max = signal_max + padding

            # Set y-ticks for grid
            y_tick_spacing = signal_range / 4
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing, y_tick_spacing))
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing/5, y_tick_spacing/5), minor=True)

            # Plot ECG waveform
            ax.plot(time_segment, signal_segment, color='black', linewidth=0.8)

            # Add lead label
            ax.text(0.02, 0.98, lead_name, transform=ax.transAxes,
                    fontsize=10, fontweight='bold', verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              edgecolor='none', alpha=0.7))

            # Set axis limits
            ax.set_xlim(0, 2.5)
            ax.set_ylim(y_min, y_max)

            # Remove tick labels for cleaner look (except bottom row and left column)
            if row < 2:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel('Time (s)', fontsize=8)

            if col > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel('Amplitude (mV)', fontsize=8)

            ax.tick_params(axis='both', which='both', labelsize=6)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    return fig
137
+
138
+
139
def plot_diagnosis_bars(diagnosis_77: dict, top_n: int = 10,
                        ground_truth: list = None) -> plt.Figure:
    """
    Render the highest diagnosis probabilities as a horizontal bar chart.

    Bars are colour-coded by confidence band (green >= 70%, yellow
    30-70%, grey < 30%); any label present in ``ground_truth`` is
    flagged with a star.

    Parameters
    ----------
    diagnosis_77 : dict
        Mapping of diagnosis name -> probability in [0, 1].
    top_n : int, optional
        How many of the highest-probability diagnoses to show, default 10.
    ground_truth : list, optional
        Diagnosis names known to be correct; marked with a star.

    Returns
    -------
    plt.Figure
        Matplotlib figure containing the bar chart.
    """
    truth_labels = ground_truth if ground_truth is not None else []

    # Rank by probability (highest first) and keep the requested count.
    ranked = sorted(diagnosis_77.items(), key=lambda item: item[1], reverse=True)[:top_n]
    names = [label for label, _ in ranked]
    probs = [score for _, score in ranked]

    # Confidence band -> bar colour.
    colors = ['#2ecc71' if p >= 0.7 else '#f1c40f' if p >= 0.3 else '#95a5a6'
              for p in probs]

    fig, ax = plt.subplots(figsize=(8, 6))

    positions = np.arange(len(names))
    bars = ax.barh(positions, probs, color=colors, edgecolor='black', linewidth=0.5)

    # Annotate each bar with its probability; when the bar nearly spans
    # the axis, move the label inside the bar and switch to white text.
    for bar, prob in zip(bars, probs):
        width = bar.get_width()
        inside = width >= 0.85
        ax.text(width - 0.08 if inside else width + 0.02,
                bar.get_y() + bar.get_height() / 2,
                f'{prob:.1%}', va='center', fontsize=9,
                color='white' if inside else 'black')

    # Star any label confirmed by the ground truth.
    display_names = [f'{label} \u2605' if label in truth_labels else label
                     for label in names]

    ax.set_yticks(positions)
    ax.set_yticklabels(display_names, fontsize=9)

    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability', fontsize=11)
    ax.set_title('Diagnosis Probabilities (Top {})'.format(top_n),
                 fontsize=12, fontweight='bold', pad=10)

    # Colour-band legend, plus the star key when ground truth is shown.
    legend_elements = [
        mpatches.Patch(facecolor='#2ecc71', edgecolor='black', label='High (\u2265 70%)'),
        mpatches.Patch(facecolor='#f1c40f', edgecolor='black', label='Moderate (30-70%)'),
        mpatches.Patch(facecolor='#95a5a6', edgecolor='black', label='Low (< 30%)')
    ]
    if truth_labels:
        legend_elements.append(mpatches.Patch(facecolor='white', edgecolor='white',
                                              label='\u2605 = Ground Truth'))
    ax.legend(handles=legend_elements, loc='lower right', fontsize=8)

    # Light vertical grid behind the bars for readability.
    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # Highest probability at the top.
    ax.invert_yaxis()

    plt.tight_layout()
    return fig
232
+
233
+
234
def _draw_gauge(ax, value: float, title: str):
    """
    Draw a semicircular risk gauge on the given axes.

    The dial is a 180-degree arc shaded green -> yellow -> red, with a
    needle pointing at ``value`` (0 maps to the left/green end, 1 to the
    right/red end).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; any previous contents are cleared.
    value : float
        Value between 0 and 1 to display.
    title : str
        Caption shown above the gauge.
    """
    # Reset the axes to a fixed, unit-less canvas.
    ax.clear()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 1.3)
    ax.set_aspect('equal')
    ax.axis('off')

    def segment_color(frac):
        """Gradient colour for a dial position ``frac`` in [0, 1)."""
        if frac < 0.3:
            # Solid green zone.
            return '#2ecc71'
        if frac < 0.6:
            # Blend green -> yellow.
            blend = (frac - 0.3) / 0.3
            red = int(46 + blend * (241 - 46))
            green = int(204 + blend * (196 - 204))
            blue = int(113 + blend * (15 - 113))
        else:
            # Blend yellow -> red.
            blend = (frac - 0.6) / 0.4
            red = int(241 + blend * (231 - 241))
            green = int(196 - blend * 196)
            blue = int(15 - blend * 15)
        return f'#{red:02x}{green:02x}{blue:02x}'

    # Paint the dial as 100 thin wedges so the gradient looks continuous.
    n_segments = 100
    sweep = 180 / n_segments
    for seg in range(n_segments):
        upper = 180 - seg * sweep
        lower = 180 - (seg + 1) * sweep
        ax.add_patch(Wedge((0, 0), 1.0, lower, upper, width=0.3,
                           facecolor=segment_color(seg / n_segments),
                           edgecolor='white', linewidth=0.5))

    # Needle: angle 180 deg = value 0 (left), 0 deg = value 1 (right).
    needle_rad = np.radians(180 - value * 180)
    tip = (0.85 * np.cos(needle_rad), 0.85 * np.sin(needle_rad))
    ax.annotate('', xy=tip, xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='#2c3e50', lw=2))

    # Hub disc over the needle base.
    ax.add_patch(plt.Circle((0, 0), 0.1, color='#2c3e50', zorder=5))

    # Numeric read-out below the hub and caption above the dial.
    ax.text(0, -0.15, f'{value*100:.0f}%', ha='center', va='top',
            fontsize=14, fontweight='bold', color='#2c3e50')
    ax.text(0, 1.2, title, ha='center', va='bottom',
            fontsize=11, fontweight='bold', color='#2c3e50')

    # Qualitative zone labels around the dial.
    ax.text(-1.1, -0.05, 'Low', ha='center', va='top', fontsize=8, color='#27ae60')
    ax.text(0, 1.05, 'Moderate', ha='center', va='bottom', fontsize=8, color='#f39c12')
    ax.text(1.1, -0.05, 'High', ha='center', va='top', fontsize=8, color='#c0392b')

    # Tick labels at the colour-band boundaries.
    for frac, caption in [(0.3, '30%'), (0.6, '60%')]:
        tick_rad = np.radians(180 - frac * 180)
        ax.text(1.05 * np.cos(tick_rad), 1.05 * np.sin(tick_rad), caption,
                ha='center', va='center', fontsize=7, color='#7f8c8d')
318
+
319
+
320
def plot_risk_gauges(lvef_40: float, lvef_50: float, afib_5y: float) -> plt.Figure:
    """
    Plot risk assessment gauges for the LVEF and AFib predictions.

    Parameters
    ----------
    lvef_40 : float
        Probability (0-1) of LVEF < 40%.
    lvef_50 : float
        Probability (0-1) of LVEF < 50%.
    afib_5y : float
        Probability (0-1) of AFib within 5 years.

    Returns
    -------
    plt.Figure
        Figure holding three semicircular gauges, one per risk score.
    """
    fig, gauge_axes = plt.subplots(1, 3, figsize=(14, 4))
    fig.suptitle('Risk Assessment', fontsize=14, fontweight='bold', y=0.98)

    # Out-of-range inputs are clamped into [0, 1] before drawing.
    captions = ('LVEF < 40%', 'LVEF < 50%', 'AFib (5-year)')
    for axis, score, caption in zip(gauge_axes, (lvef_40, lvef_50, afib_5y), captions):
        _draw_gauge(axis, np.clip(score, 0, 1), caption)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
354
+
355
+
356
def generate_thumbnail(ecg_signal: np.ndarray, label: str,
                       sample_rate: int = 250) -> Image.Image:
    """
    Generate a thumbnail preview image of Lead II for gallery display.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12)
    label : str
        Label text to display on thumbnail
    sample_rate : int, optional
        Sampling rate in Hz, default 250

    Returns
    -------
    PIL.Image.Image
        Thumbnail image exactly 300x150 pixels

    Raises
    ------
    ValueError
        If neither axis of ``ecg_signal`` has length 12.
    """
    # Ensure correct shape (12, n_samples)
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    # Extract Lead II (index 1 in LEAD_NAMES order)
    lead_ii = ecg_signal[1, :]
    n_samples = len(lead_ii)
    time = np.arange(n_samples) / sample_rate

    # Create figure with appropriate DPI for ~300x150 pixel output
    fig, ax = plt.subplots(figsize=(3, 1.5), dpi=100)

    # Clean, minimal design
    ax.plot(time, lead_ii, color='#e74c3c', linewidth=1.0)

    # Set background
    ax.set_facecolor('#fafafa')
    fig.patch.set_facecolor('#fafafa')

    # Remove axes for clean look
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Add label
    ax.text(0.02, 0.98, label, transform=ax.transAxes,
            fontsize=8, fontweight='bold', verticalalignment='top',
            color='#2c3e50')

    # Add "Lead II" indicator
    ax.text(0.98, 0.02, 'Lead II', transform=ax.transAxes,
            fontsize=6, verticalalignment='bottom', horizontalalignment='right',
            color='#7f8c8d')

    plt.tight_layout(pad=0.2)

    # Render the figure to an in-memory PNG and hand it to PIL
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=fig.get_facecolor(),
                edgecolor='none', bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)

    # BUGFIX: Image.Resampling only exists in Pillow >= 9.1, but
    # requirements.txt allows Pillow >= 9.0. Fall back to the legacy
    # module-level LANCZOS constant on older Pillow releases.
    resample = getattr(Image, "Resampling", Image).LANCZOS

    # Resize to exactly 300x150 pixels. resize() decodes the buffer into
    # a new image, so the buffer can be released afterwards.
    img = img.resize((300, 150), resample)
    buf.close()

    return img
428
+
429
+
430
if __name__ == '__main__':
    # Smoke check when the module is executed directly.
    print("Visualization module loaded successfully.")
    print("Available functions: plot_ecg_waveform, plot_diagnosis_bars, "
          "plot_risk_gauges, generate_thumbnail")