Ashkan Taghipour (The University of Western Australia) and Claude Opus 4.5 committed
Commit · cd846d7
Parent(s): 01070c9
Initial HeartWatch AI demo release
- Added app.py: Gradio interface with upload and sample gallery tabs
- Added inference.py: DeepECG inference engine with 4 models
- Added visualization.py: ECG waveform, diagnosis bars, risk gauges
- Added class_names.json: 77 ECG diagnosis class names
- Added sample ECG data from MIT-BIH database
- Added requirements.txt and .gitignore
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- .gitignore +59 -0
- README.md +36 -6
- app.py +326 -0
- class_names.json +79 -0
- demo_data/samples/Atrial_Flutter.npy +3 -0
- demo_data/samples/Normal_Sinus_Rhythm.npy +3 -0
- demo_data/samples/Ventricular_Tachycardia.npy +3 -0
- inference.py +414 -0
- requirements.txt +7 -0
- test_inference.py +115 -0
- visualization.py +433 -0
.gitignore
ADDED
@@ -0,0 +1,59 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+ENV/
+env/
+.venv/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# Jupyter
+.ipynb_checkpoints/
+
+# Model weights (downloaded at runtime)
+weights/
+*.pt
+*.pth
+*.ckpt
+
+# Local development
+.env
+.env.local
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+logs/
+
+# Temporary files
+*.tmp
+*.temp
README.md
CHANGED
@@ -1,13 +1,43 @@
 ---
 title: HeartWatchAI
-emoji:
-colorFrom:
-colorTo:
+emoji: ❤️
+colorFrom: red
+colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 5.12.0
 app_file: app.py
 pinned: false
-short_description: AI-
+short_description: AI-powered 12-Lead ECG Analysis
+hf_oauth: false
 ---
 
-
+# HeartWatch AI
+
+AI-powered 12-Lead ECG analysis using deep learning models.
+
+## Features
+
+- **77-Class ECG Diagnosis**: Detect 77 different cardiac conditions
+- **LVEF Prediction**: Predict left ventricular ejection fraction < 40% and < 50%
+- **AFib Risk**: 5-year atrial fibrillation risk prediction
+- **Interactive Visualization**: Clinical 4x3 lead layout with ECG paper grid
+
+## Models
+
+This demo uses EfficientNetV2 models from the DeepECG project:
+
+- `heartwise/EfficientNetV2_77_Classes`
+- `heartwise/EfficientNetV2_LVEF_40`
+- `heartwise/EfficientNetV2_LVEF_50`
+- `heartwise/EfficientNetV2_AFIB_5y`
+
+## Input Format
+
+- NumPy array (.npy file)
+- Shape: (2500, 12) or (12, 2500)
+- 12 standard leads: I, II, III, aVR, aVL, aVF, V1-V6
+- 10 seconds at 250 Hz sampling rate
+
+## Disclaimer
+
+This is a research demonstration tool. Predictions should NOT be used for clinical decision-making. Always consult qualified healthcare professionals for medical advice.
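A minimal sketch of producing a file in the Input Format documented above. The zero array is a stand-in for real millivolt ECG voltages; only the shape and dtype are meant to illustrate the expected layout.

```python
import tempfile

import numpy as np

# A placeholder 12-lead recording: 10 s at 250 Hz -> 2500 samples per lead.
# Zeros stand in for real millivolt voltages; only shape/dtype matter here.
ecg = np.zeros((2500, 12), dtype=np.float32)
assert ecg.shape in [(2500, 12), (12, 2500)]  # the two accepted orientations

# Round-trip through .npy, the format the Upload tab expects.
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as f:
    np.save(f, ecg)
    path = f.name
loaded = np.load(path)
print(loaded.shape)  # (2500, 12)
```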
app.py
ADDED
@@ -0,0 +1,326 @@
+"""
+HeartWatch AI - ECG Analysis Demo
+==================================
+
+A Gradio-based web application for AI-powered ECG analysis using DeepECG models.
+
+Features:
+- 77-class ECG diagnosis
+- LVEF < 40% prediction
+- LVEF < 50% prediction
+- 5-year AFib risk assessment
+- Interactive 12-lead ECG visualization
+"""
+
+import os
+import logging
+import numpy as np
+import gradio as gr
+from pathlib import Path
+
+# Local imports
+from inference import DeepECGInference
+from visualization import (
+    plot_ecg_waveform,
+    plot_diagnosis_bars,
+    plot_risk_gauges,
+    generate_thumbnail
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Global inference engine
+inference_engine = None
+
+
+def load_inference_engine():
+    """Load the inference engine on startup."""
+    global inference_engine
+    if inference_engine is None:
+        logger.info("Loading DeepECG inference engine...")
+        inference_engine = DeepECGInference()
+        inference_engine.load_models()
+        logger.info("Inference engine loaded successfully")
+    return inference_engine
+
+
+def get_sample_ecgs():
+    """Get list of sample ECG files from demo_data directory."""
+    sample_dir = Path(__file__).parent / "demo_data" / "samples"
+    if not sample_dir.exists():
+        return []
+
+    samples = []
+    for npy_file in sorted(sample_dir.glob("*.npy")):
+        samples.append({
+            "path": str(npy_file),
+            "name": npy_file.stem.replace("_", " ").title()
+        })
+    return samples
+
+
+def analyze_ecg(ecg_signal: np.ndarray, filename: str = "Uploaded ECG"):
+    """
+    Analyze an ECG signal and return all visualizations.
+
+    Args:
+        ecg_signal: ECG signal array
+        filename: Name to display
+
+    Returns:
+        Tuple of (ecg_plot, diagnosis_plot, risk_plot, summary_text)
+    """
+    engine = load_inference_engine()
+
+    # Run inference
+    results = engine.predict(ecg_signal)
+
+    # Generate ECG waveform plot
+    ecg_fig = plot_ecg_waveform(ecg_signal, sample_rate=250, title=filename)
+
+    # Generate diagnosis bar chart
+    if "diagnosis_77" in results:
+        probs = results["diagnosis_77"]["probabilities"]
+        class_names = results["diagnosis_77"]["class_names"]
+        diagnosis_dict = dict(zip(class_names, probs))
+        diagnosis_fig = plot_diagnosis_bars(diagnosis_dict, top_n=10)
+    else:
+        diagnosis_fig = None
+
+    # Generate risk gauges
+    lvef_40 = results.get("lvef_40", 0.0)
+    lvef_50 = results.get("lvef_50", 0.0)
+    afib_5y = results.get("afib_5y", 0.0)
+    risk_fig = plot_risk_gauges(lvef_40, lvef_50, afib_5y)
+
+    # Generate summary text
+    inference_time = results.get("inference_time_ms", 0)
+    summary = f"""## Analysis Summary
+
+**Inference Time:** {inference_time:.1f} ms
+
+### Risk Predictions
+- **LVEF < 40%:** {lvef_40*100:.1f}% probability
+- **LVEF < 50%:** {lvef_50*100:.1f}% probability
+- **5-year AFib Risk:** {afib_5y*100:.1f}% probability
+
+### Top Diagnoses
+"""
+    if "diagnosis_77" in results:
+        probs = results["diagnosis_77"]["probabilities"]
+        class_names = results["diagnosis_77"]["class_names"]
+        top_indices = np.argsort(probs)[::-1][:5]
+        for i, idx in enumerate(top_indices, 1):
+            summary += f"{i}. {class_names[idx]}: {probs[idx]*100:.1f}%\n"
+
+    return ecg_fig, diagnosis_fig, risk_fig, summary
+
+
+def analyze_uploaded_file(file):
+    """Handle uploaded .npy file."""
+    if file is None:
+        return None, None, None, "Please upload a .npy file containing ECG data."
+
+    try:
+        ecg_signal = np.load(file.name)
+        filename = Path(file.name).stem
+        return analyze_ecg(ecg_signal, filename)
+    except Exception as e:
+        logger.error(f"Error loading file: {e}")
+        return None, None, None, f"Error loading file: {str(e)}"
+
+
+def analyze_sample(sample_name: str):
+    """Analyze a sample ECG from the gallery."""
+    samples = get_sample_ecgs()
+
+    for sample in samples:
+        if sample["name"] == sample_name:
+            ecg_signal = np.load(sample["path"])
+            return analyze_ecg(ecg_signal, sample["name"])
+
+    return None, None, None, "Sample not found."
+
+
+def create_demo_interface():
+    """Create the Gradio interface."""
+
+    # Custom CSS for styling
+    custom_css = """
+    .gradio-container {
+        font-family: 'Inter', sans-serif;
+    }
+    .main-header {
+        text-align: center;
+        padding: 20px;
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+        color: white;
+        border-radius: 10px;
+        margin-bottom: 20px;
+    }
+    .main-header h1 {
+        margin: 0;
+        font-size: 2.5em;
+    }
+    .main-header p {
+        margin: 10px 0 0 0;
+        opacity: 0.9;
+    }
+    """
+
+    with gr.Blocks(css=custom_css, title="HeartWatch AI") as demo:
+        # Header
+        gr.HTML("""
+        <div class="main-header">
+            <h1>HeartWatch AI</h1>
+            <p>AI-Powered 12-Lead ECG Analysis</p>
+        </div>
+        """)
+
+        gr.Markdown("""
+        This demo analyzes 12-lead ECG signals using deep learning models trained on large clinical datasets.
+
+        **Models:**
+        - 77-class ECG diagnosis classifier
+        - LVEF < 40% prediction
+        - LVEF < 50% prediction
+        - 5-year Atrial Fibrillation risk
+
+        **Note:** This is a research demo. Results should not be used for clinical decision-making.
+        """)
+
+        with gr.Tabs():
+            # Tab 1: Upload ECG
+            with gr.TabItem("Upload ECG"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        file_input = gr.File(
+                            label="Upload ECG (.npy file)",
+                            file_types=[".npy"],
+                            type="filepath"
+                        )
+                        analyze_btn = gr.Button("Analyze ECG", variant="primary")
+
+                        gr.Markdown("""
+                        **Expected Format:**
+                        - NumPy array shape: (2500, 12) or (12, 2500)
+                        - 12 leads: I, II, III, aVR, aVL, aVF, V1-V6
+                        - 2500 samples (10 seconds at 250 Hz)
+                        """)
+
+                    with gr.Column(scale=2):
+                        summary_output = gr.Markdown(label="Summary")
+
+                with gr.Row():
+                    ecg_plot = gr.Plot(label="12-Lead ECG")
+
+                with gr.Row():
+                    with gr.Column():
+                        diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
+                    with gr.Column():
+                        risk_plot = gr.Plot(label="Risk Assessment")
+
+                analyze_btn.click(
+                    fn=analyze_uploaded_file,
+                    inputs=[file_input],
+                    outputs=[ecg_plot, diagnosis_plot, risk_plot, summary_output]
+                )
+
+            # Tab 2: Sample Gallery
+            with gr.TabItem("Sample Gallery"):
+                gr.Markdown("### Select a sample ECG to analyze")
+
+                samples = get_sample_ecgs()
+                if samples:
+                    sample_names = [s["name"] for s in samples]
+                    sample_dropdown = gr.Dropdown(
+                        choices=sample_names,
+                        label="Select Sample",
+                        value=sample_names[0] if sample_names else None
+                    )
+                    analyze_sample_btn = gr.Button("Analyze Sample", variant="primary")
+
+                    with gr.Row():
+                        sample_summary = gr.Markdown(label="Summary")
+
+                    with gr.Row():
+                        sample_ecg_plot = gr.Plot(label="12-Lead ECG")
+
+                    with gr.Row():
+                        with gr.Column():
+                            sample_diagnosis_plot = gr.Plot(label="Diagnosis Probabilities")
+                        with gr.Column():
+                            sample_risk_plot = gr.Plot(label="Risk Assessment")
+
+                    analyze_sample_btn.click(
+                        fn=analyze_sample,
+                        inputs=[sample_dropdown],
+                        outputs=[sample_ecg_plot, sample_diagnosis_plot, sample_risk_plot, sample_summary]
+                    )
+                else:
+                    gr.Markdown("*No sample ECGs available. Upload your own in the Upload tab.*")
+
+            # Tab 3: About
+            with gr.TabItem("About"):
+                gr.Markdown("""
+                ## About HeartWatch AI
+
+                HeartWatch AI is a deep learning-based ECG analysis system that can:
+
+                ### Models
+
+                1. **77-Class Diagnosis Model**
+                   - Trained to detect 77 different ECG patterns and conditions
+                   - Based on EfficientNetV2 architecture
+                   - Outputs probability for each condition
+
+                2. **LVEF Prediction Models**
+                   - LVEF < 40%: Identifies patients with reduced ejection fraction
+                   - LVEF < 50%: Identifies patients with moderately reduced ejection fraction
+
+                3. **AFib Risk Model**
+                   - Predicts 5-year risk of developing Atrial Fibrillation
+
+                ### Technical Details
+
+                - **Input:** 12-lead ECG, 10 seconds, 250 Hz sampling rate
+                - **Architecture:** EfficientNetV2 (TorchScript optimized)
+                - **Inference:** CPU-optimized for accessibility
+
+                ### Disclaimer
+
+                This is a research demonstration tool. The predictions provided should NOT be used
+                for clinical decision-making. Always consult qualified healthcare professionals
+                for medical advice and diagnosis.
+
+                ### Contact
+
+                For questions or issues, please visit our GitHub repository.
+                """)
+
+        # Footer
+        gr.Markdown("""
+        ---
+        *Built with Gradio and PyTorch. Models from DeepECG project.*
+        """)
+
+    return demo
+
+
+# Create and launch the demo
+if __name__ == "__main__":
+    # Pre-load the inference engine
+    try:
+        load_inference_engine()
+    except Exception as e:
+        logger.warning(f"Could not pre-load models: {e}")
+
+    # Create and launch demo
+    demo = create_demo_interface()
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
+    )
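The gallery tab matches dropdown choices back to files purely by name, so the display-name rule in `get_sample_ecgs()` (file stem, underscores to spaces, title case) is load-bearing. A standalone sketch of that rule:

```python
from pathlib import Path


def display_name(path_str: str) -> str:
    # Mirror of the naming rule in get_sample_ecgs(): file stem with
    # underscores replaced by spaces, then title-cased.
    return Path(path_str).stem.replace("_", " ").title()


print(display_name("demo_data/samples/Atrial_Flutter.npy"))       # Atrial Flutter
print(display_name("demo_data/samples/Normal_Sinus_Rhythm.npy"))  # Normal Sinus Rhythm
```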
class_names.json
ADDED
@@ -0,0 +1,79 @@
+[
+  "Sinusal",
+  "Regular",
+  "Monomorph",
+  "QS complex in V1-V2-V3",
+  "R complex in V5-V6",
+  "T wave inversion (inferior - II, III, aVF)",
+  "Left bundle branch block",
+  "RaVL > 11 mm",
+  "SV1 + RV5 or RV6 > 35 mm",
+  "T wave inversion (lateral -I, aVL, V5-V6)",
+  "T wave inversion (anterior - V3-V4)",
+  "Left axis deviation",
+  "Left ventricular hypertrophy",
+  "Bradycardia",
+  "Q wave (inferior - II, III, aVF)",
+  "Afib",
+  "Irregularly irregular",
+  "Atrial tachycardia (>= 100 BPM)",
+  "Nonspecific intraventricular conduction delay",
+  "Premature ventricular complex",
+  "Polymorph",
+  "T wave inversion (septal- V1-V2)",
+  "Right bundle branch block",
+  "Ventricular paced",
+  "ST elevation (anterior - V3-V4)",
+  "ST elevation (septal - V1-V2)",
+  "1st degree AV block",
+  "Premature atrial complex",
+  "Atrial flutter",
+  "rSR' in V1-V2",
+  "qRS in V5-V6-I, aVL",
+  "Left anterior fascicular block",
+  "Right axis deviation",
+  "2nd degree AV block - mobitz 1",
+  "ST depression (inferior - II, III, aVF)",
+  "Acute pericarditis",
+  "ST elevation (inferior - II, III, aVF)",
+  "Low voltage",
+  "Regularly irregular",
+  "Junctional rhythm",
+  "Left atrial enlargement",
+  "ST elevation (lateral - I, aVL, V5-V6)",
+  "Atrial paced",
+  "Right ventricular hypertrophy",
+  "Delta wave",
+  "Wolff-Parkinson-White (Pre-excitation syndrome)",
+  "Prolonged QT",
+  "ST depression (anterior - V3-V4)",
+  "QRS complex negative in III",
+  "Q wave (lateral- I, aVL, V5-V6)",
+  "Supraventricular tachycardia",
+  "ST downslopping",
+  "ST depression (lateral - I, avL, V5-V6)",
+  "2nd degree AV block - mobitz 2",
+  "U wave",
+  "R/S ratio in V1-V2 >1",
+  "RV1 + SV6 > 11 mm",
+  "Left posterior fascicular block",
+  "Right atrial enlargement",
+  "ST depression (septal- V1-V2)",
+  "Q wave (septal- V1-V2)",
+  "Q wave (anterior - V3-V4)",
+  "ST upslopping",
+  "Right superior axis",
+  "Ventricular tachycardia",
+  "ST elevation (posterior - V7-V8-V9)",
+  "Ectopic atrial rhythm (< 100 BPM)",
+  "Lead misplacement",
+  "Third Degree AV Block",
+  "Acute MI",
+  "Early repolarization",
+  "Q wave (posterior - V7-V9)",
+  "Bi-atrial enlargement",
+  "LV pacing",
+  "Brugada",
+  "Ventricular Rhythm",
+  "no_qrs"
+]
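The class names are positionally aligned with the 77-class model's output vector, and the app ranks them by probability for its "Top Diagnoses" summary. A sketch of that pairing, using a hypothetical 4-entry slice in place of the full list and probability vector:

```python
# Hypothetical slice of class_names.json and a matching probability vector;
# in the app, both have 77 entries aligned index-for-index.
class_names = ["Sinusal", "Regular", "Afib", "Atrial flutter"]
probs = [0.91, 0.85, 0.07, 0.62]

# Rank conditions by descending probability, as the summary text does.
ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
for name, p in ranked:
    print(f"{name}: {p * 100:.1f}%")
```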
demo_data/samples/Atrial_Flutter.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8704d5fb100d5a9080640587aef6cb69c725b82b4dbed4ecf72510f67acd92c9
+size 120128
demo_data/samples/Normal_Sinus_Rhythm.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a1191aa4d9a7dc79b67fcd2eac378b86456d7b1ed035a443819ed88b0c7065f
+size 120128
demo_data/samples/Ventricular_Tachycardia.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a43e7555616d2601daf958140b9770a69522cce0126ed1402f546829646263dc
+size 120128
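The three sample `.npy` files are committed as Git LFS pointers: small key/value text stubs, while the real binaries live in LFS storage. A sketch of parsing one (the `size 120128` is consistent with a (2500, 12) float32 array: 2500 × 12 × 4 = 120,000 bytes of data plus a 128-byte `.npy` header):

```python
# One of the pointer files above, verbatim; each line is "key value".
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:8704d5fb100d5a9080640587aef6cb69c725b82b4dbed4ecf72510f67acd92c9
size 120128
"""

# Split each line on the first space to get a key -> value mapping.
fields = dict(line.split(" ", 1) for line in pointer_text.strip().splitlines())
print(fields["size"])                       # 120128
print(fields["oid"].startswith("sha256:"))  # True
```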
inference.py
ADDED
@@ -0,0 +1,414 @@
+"""
+DeepECG Inference Module for HeartWatch AI
+===========================================
+
+This module provides CPU-optimized inference for 4 EfficientNetV2 models:
+- 77-class ECG diagnosis
+- LVEF <= 40% prediction
+- LVEF < 50% prediction
+- 5-year AFib risk prediction
+
+The preprocessing exactly replicates DeepECG's pipeline:
+1. Load signal as (samples, leads) = (2500, 12)
+2. Transpose to (leads, samples) = (12, 2500)
+3. Apply MHI factor scaling: signal *= (1/0.0048)
+4. Apply sigmoid to model logits
+
+Models are downloaded from HuggingFace Hub using HF_TOKEN from environment.
+"""
+
+import os
+import json
+import time
+import logging
+from typing import Dict, Optional, Any, Union
+from pathlib import Path
+
+import numpy as np
+import torch
+from huggingface_hub import snapshot_download
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# CPU optimizations for HuggingFace Spaces (no GPU)
+torch.set_num_threads(2)
+torch.set_flush_denormal(True)
+
+
+class DeepECGInference:
+    """
+    CPU-optimized inference engine for DeepECG EfficientNetV2 models.
+
+    Loads 4 models from HuggingFace Hub:
+    - heartwise/EfficientNetV2_77_Classes: 77-class ECG diagnosis
+    - heartwise/EfficientNetV2_LVEF_40: LVEF <= 40% prediction
+    - heartwise/EfficientNetV2_LVEF_50: LVEF < 50% prediction
+    - heartwise/EfficientNetV2_AFIB_5y: 5-year AFib risk prediction
+
+    Attributes:
+        device: Always CPU for HF Spaces
+        models: Dict containing loaded TorchScript models
+        class_names: List of 77 ECG diagnosis class names
+        mhi_factor: Scaling factor for signal preprocessing (1/0.0048)
+    """
+
+    # Model repository mappings
+    MODEL_REPOS = {
+        "diagnosis_77": "heartwise/EfficientNetV2_77_Classes",
+        "lvef_40": "heartwise/EfficientNetV2_LVEF_40",
+        "lvef_50": "heartwise/EfficientNetV2_LVEF_50",
+        "afib_5y": "heartwise/EfficientNetV2_AFIB_5y",
+    }
+
+    # Expected input specifications
+    EXPECTED_LEADS = 12
+    EXPECTED_SAMPLES = 2500  # 10 seconds at 250 Hz
+    SAMPLING_RATE = 250  # Hz
+
+    # Preprocessing constants from DeepECG
+    MHI_FACTOR = 1 / 0.0048  # ~208.33
+
+    def __init__(self, cache_dir: Optional[str] = None):
+        """
+        Initialize the inference engine.
+
+        Args:
+            cache_dir: Directory to cache downloaded models.
+                Defaults to ./weights
+        """
+        self.device = torch.device("cpu")
+        self.cache_dir = cache_dir or os.path.join(os.getcwd(), "weights")
+        self.models: Dict[str, torch.jit.ScriptModule] = {}
+        self.class_names: list = []
+        self._load_class_names()
+
+    def _load_class_names(self) -> None:
+        """Load the 77 ECG class names from class_names.json."""
+        class_names_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "class_names.json"
+        )
+        try:
+            with open(class_names_path, "r") as f:
+                self.class_names = json.load(f)
+            logger.info(f"Loaded {len(self.class_names)} class names")
+        except FileNotFoundError:
+            logger.warning(f"class_names.json not found at {class_names_path}")
+            self.class_names = []
+
+    def _get_hf_token(self) -> Optional[str]:
+        """Get HuggingFace token from environment variable."""
+        token = os.environ.get("HF_TOKEN")
+        if not token:
+            logger.warning("HF_TOKEN environment variable not set")
+        return token
+
+    def _download_model(self, repo_id: str, model_name: str) -> str:
+        """
+        Download model from HuggingFace Hub.
+
+        Args:
+            repo_id: HuggingFace repository ID
+            model_name: Local name for the model
+
+        Returns:
+            Path to the downloaded model directory
+        """
+        local_dir = os.path.join(self.cache_dir, model_name)
+
+        if os.path.exists(local_dir):
+            logger.info(f"Model {model_name} already cached at {local_dir}")
+            return local_dir
+
+        logger.info(f"Downloading {repo_id} to {local_dir}")
+        os.makedirs(local_dir, exist_ok=True)
+
+        hf_token = self._get_hf_token()
+        local_dir = snapshot_download(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            repo_type="model",
+            token=hf_token
+        )
+
+        logger.info(f"Downloaded {repo_id} to {local_dir}")
+        return local_dir
+
+    def _load_model_from_dir(self, model_dir: str) -> torch.jit.ScriptModule:
+        """
+        Load a TorchScript model from a directory.
+
+        Args:
+            model_dir: Directory containing the .pt file
+
+        Returns:
+            Loaded TorchScript model
+
+        Raises:
+            ValueError: If no .pt file is found in the directory
+        """
+        pt_file = next(
+            (f for f in os.listdir(model_dir) if f.endswith('.pt')),
+            None
+        )
+        if not pt_file:
+            raise ValueError(f"No .pt file found in {model_dir}")
+
+        model_path = os.path.join(model_dir, pt_file)
+        model = torch.jit.load(model_path, map_location=self.device)
+        model.eval()
+
+        return model
+
+    def load_models(self) -> None:
+        """
+        Download and load all 4 models from HuggingFace Hub.
+
+        Uses HF_TOKEN from os.environ for authentication.
+        Models are loaded in eval mode on CPU.
+        """
+        logger.info("Loading DeepECG models...")
+
+        for model_key, repo_id in self.MODEL_REPOS.items():
+            try:
+                model_dir = self._download_model(repo_id, model_key)
+                self.models[model_key] = self._load_model_from_dir(model_dir)
+                logger.info(f"Loaded model: {model_key} from {repo_id}")
+            except Exception as e:
+                logger.error(f"Failed to load {model_key}: {e}")
+                raise
+
+        logger.info(f"Successfully loaded {len(self.models)} models")
+
+    def preprocess_ecg(
+        self,
+        ecg_signal: Union[np.ndarray, torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        Preprocess ECG signal to match DeepECG's exact preprocessing.
+
+        The preprocessing pipeline:
+        1. Ensure signal is numpy array with correct shape
+        2. Handle shape: expect (samples, leads) = (2500, 12) or (12, 2500)
+        3. Transpose to (leads, samples) = (12, 2500) if needed
+        4. Convert to float32 tensor
+        5. Add batch dimension: (1, 12, 2500)
+        6. Apply MHI factor scaling: signal *= (1/0.0048)
+
+        Args:
+            ecg_signal: Raw ECG signal, shape (samples, leads) or (leads, samples)
+                Expected: 12 leads, 2500 samples (10s at 250Hz)
+
+        Returns:
+            Preprocessed tensor ready for model inference, shape (1, 12, 2500)
+
+        Raises:
+            ValueError: If signal shape is invalid
+        """
+        # Convert to numpy if tensor
+        if isinstance(ecg_signal, torch.Tensor):
+            ecg_signal = ecg_signal.numpy()
+
+        # Ensure float32
+        ecg_signal = ecg_signal.astype(np.float32)
+
+        # Handle shape - expect (samples, leads) = (2500, 12) or (12, 2500)
+        if ecg_signal.ndim != 2:
+            raise ValueError(
+                f"Expected 2D signal, got shape {ecg_signal.shape}"
+            )
+
+        # Determine orientation and transpose if needed
+        # If shape is (samples, leads) = (2500, 12), transpose to (12, 2500)
+        # If shape is (12, 2500), it's already correct
+        if ecg_signal.shape[0] == self.EXPECTED_SAMPLES and ecg_signal.shape[1] == self.EXPECTED_LEADS:
+            # Shape is (2500, 12) -> transpose to (12, 2500)
+            ecg_signal = ecg_signal.T
+        elif ecg_signal.shape[0] == self.EXPECTED_LEADS and ecg_signal.shape[1] == self.EXPECTED_SAMPLES:
+            # Shape is already (12, 2500)
+            pass
+        else:
| 233 |
+
# Try to infer orientation
|
| 234 |
+
if ecg_signal.shape[1] == self.EXPECTED_LEADS:
|
| 235 |
+
ecg_signal = ecg_signal.T
|
| 236 |
+
elif ecg_signal.shape[0] != self.EXPECTED_LEADS:
|
| 237 |
+
raise ValueError(
|
| 238 |
+
f"Invalid signal shape {ecg_signal.shape}. "
|
| 239 |
+
f"Expected (2500, 12) or (12, 2500)"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Verify final shape
|
| 243 |
+
if ecg_signal.shape[0] != self.EXPECTED_LEADS:
|
| 244 |
+
raise ValueError(
|
| 245 |
+
f"Signal must have {self.EXPECTED_LEADS} leads, "
|
| 246 |
+
f"got {ecg_signal.shape[0]}"
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Convert to tensor and add batch dimension
|
| 250 |
+
signal_tensor = torch.from_numpy(ecg_signal).float()
|
| 251 |
+
signal_tensor = signal_tensor.unsqueeze(0) # (1, 12, samples)
|
| 252 |
+
|
| 253 |
+
# Move to device (CPU)
|
| 254 |
+
signal_tensor = signal_tensor.to(self.device)
|
| 255 |
+
|
| 256 |
+
# Apply MHI factor scaling (this is done in model __call__ in DeepECG)
|
| 257 |
+
signal_tensor = signal_tensor * self.MHI_FACTOR
|
| 258 |
+
|
| 259 |
+
return signal_tensor
|
| 260 |
+
|
| 261 |
+
def predict(
|
| 262 |
+
self,
|
| 263 |
+
ecg_signal: Union[np.ndarray, torch.Tensor]
|
| 264 |
+
) -> Dict[str, Any]:
|
| 265 |
+
"""
|
| 266 |
+
Run inference on an ECG signal using all 4 models.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
ecg_signal: Raw ECG signal, shape (samples, leads) or (leads, samples)
|
| 270 |
+
Expected: 12 leads, 2500 samples (10s at 250Hz)
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
Dictionary containing:
|
| 274 |
+
- diagnosis_77: Dict with 'probabilities' (77 floats) and 'class_names'
|
| 275 |
+
- lvef_40: Probability of LVEF <= 40%
|
| 276 |
+
- lvef_50: Probability of LVEF < 50%
|
| 277 |
+
- afib_5y: Probability of AFib within 5 years
|
| 278 |
+
- inference_time_ms: Total inference time in milliseconds
|
| 279 |
+
"""
|
| 280 |
+
if not self.models:
|
| 281 |
+
raise RuntimeError("Models not loaded. Call load_models() first.")
|
| 282 |
+
|
| 283 |
+
start_time = time.time()
|
| 284 |
+
|
| 285 |
+
# Preprocess the signal
|
| 286 |
+
signal_tensor = self.preprocess_ecg(ecg_signal)
|
| 287 |
+
|
| 288 |
+
results = {}
|
| 289 |
+
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
# 77-class diagnosis
|
| 292 |
+
if "diagnosis_77" in self.models:
|
| 293 |
+
logits = self.models["diagnosis_77"](signal_tensor)
|
| 294 |
+
probs = torch.sigmoid(logits)
|
| 295 |
+
probs_list = probs.squeeze().cpu().numpy().tolist()
|
| 296 |
+
results["diagnosis_77"] = {
|
| 297 |
+
"probabilities": probs_list,
|
| 298 |
+
"class_names": self.class_names if self.class_names else None,
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
# LVEF <= 40%
|
| 302 |
+
if "lvef_40" in self.models:
|
| 303 |
+
logits = self.models["lvef_40"](signal_tensor)
|
| 304 |
+
prob = torch.sigmoid(logits)
|
| 305 |
+
results["lvef_40"] = float(prob.squeeze().cpu().numpy())
|
| 306 |
+
|
| 307 |
+
# LVEF < 50%
|
| 308 |
+
if "lvef_50" in self.models:
|
| 309 |
+
logits = self.models["lvef_50"](signal_tensor)
|
| 310 |
+
prob = torch.sigmoid(logits)
|
| 311 |
+
results["lvef_50"] = float(prob.squeeze().cpu().numpy())
|
| 312 |
+
|
| 313 |
+
# 5-year AFib risk
|
| 314 |
+
if "afib_5y" in self.models:
|
| 315 |
+
logits = self.models["afib_5y"](signal_tensor)
|
| 316 |
+
prob = torch.sigmoid(logits)
|
| 317 |
+
results["afib_5y"] = float(prob.squeeze().cpu().numpy())
|
| 318 |
+
|
| 319 |
+
end_time = time.time()
|
| 320 |
+
results["inference_time_ms"] = (end_time - start_time) * 1000
|
| 321 |
+
|
| 322 |
+
return results
|
| 323 |
+
|
| 324 |
+
def predict_diagnosis_top_k(
|
| 325 |
+
self,
|
| 326 |
+
ecg_signal: Union[np.ndarray, torch.Tensor],
|
| 327 |
+
k: int = 5
|
| 328 |
+
) -> Dict[str, Any]:
|
| 329 |
+
"""
|
| 330 |
+
Get top-k diagnoses from the 77-class model.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
ecg_signal: Raw ECG signal
|
| 334 |
+
k: Number of top predictions to return
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
Dictionary with top-k predictions sorted by probability
|
| 338 |
+
"""
|
| 339 |
+
results = self.predict(ecg_signal)
|
| 340 |
+
|
| 341 |
+
if "diagnosis_77" not in results:
|
| 342 |
+
raise RuntimeError("77-class diagnosis model not loaded")
|
| 343 |
+
|
| 344 |
+
probs = results["diagnosis_77"]["probabilities"]
|
| 345 |
+
class_names = results["diagnosis_77"]["class_names"] or [f"Class_{i}" for i in range(77)]
|
| 346 |
+
|
| 347 |
+
# Get top-k indices
|
| 348 |
+
top_k_indices = np.argsort(probs)[::-1][:k]
|
| 349 |
+
|
| 350 |
+
top_k_predictions = [
|
| 351 |
+
{
|
| 352 |
+
"class_name": class_names[idx],
|
| 353 |
+
"probability": probs[idx],
|
| 354 |
+
"class_index": int(idx)
|
| 355 |
+
}
|
| 356 |
+
for idx in top_k_indices
|
| 357 |
+
]
|
| 358 |
+
|
| 359 |
+
return {
|
| 360 |
+
"top_k_predictions": top_k_predictions,
|
| 361 |
+
"inference_time_ms": results["inference_time_ms"]
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def get_inference_engine(cache_dir: Optional[str] = None) -> DeepECGInference:
|
| 366 |
+
"""
|
| 367 |
+
Factory function to create and initialize a DeepECGInference instance.
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
cache_dir: Optional directory to cache models
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
Initialized DeepECGInference with models loaded
|
| 374 |
+
"""
|
| 375 |
+
engine = DeepECGInference(cache_dir=cache_dir)
|
| 376 |
+
engine.load_models()
|
| 377 |
+
return engine
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
if __name__ == "__main__":
|
| 381 |
+
# Example usage / testing
|
| 382 |
+
print("DeepECG Inference Module")
|
| 383 |
+
print("=" * 50)
|
| 384 |
+
|
| 385 |
+
# Create inference engine
|
| 386 |
+
engine = DeepECGInference()
|
| 387 |
+
|
| 388 |
+
# Load models (requires HF_TOKEN environment variable)
|
| 389 |
+
try:
|
| 390 |
+
engine.load_models()
|
| 391 |
+
print("Models loaded successfully!")
|
| 392 |
+
|
| 393 |
+
# Create dummy signal for testing
|
| 394 |
+
dummy_signal = np.random.randn(2500, 12).astype(np.float32)
|
| 395 |
+
|
| 396 |
+
# Run inference
|
| 397 |
+
results = engine.predict(dummy_signal)
|
| 398 |
+
|
| 399 |
+
print(f"\nInference time: {results['inference_time_ms']:.2f} ms")
|
| 400 |
+
print(f"LVEF <= 40%: {results['lvef_40']:.4f}")
|
| 401 |
+
print(f"LVEF < 50%: {results['lvef_50']:.4f}")
|
| 402 |
+
print(f"5-year AFib risk: {results['afib_5y']:.4f}")
|
| 403 |
+
print(f"77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} classes")
|
| 404 |
+
|
| 405 |
+
# Get top-5 diagnoses
|
| 406 |
+
top_5 = engine.predict_diagnosis_top_k(dummy_signal, k=5)
|
| 407 |
+
print("\nTop 5 diagnoses:")
|
| 408 |
+
for pred in top_5["top_k_predictions"]:
|
| 409 |
+
print(f" {pred['class_name']}: {pred['probability']:.4f}")
|
| 410 |
+
|
| 411 |
+
except Exception as e:
|
| 412 |
+
print(f"Error: {e}")
|
| 413 |
+
print("\nMake sure HF_TOKEN environment variable is set:")
|
| 414 |
+
print(" export HF_TOKEN='your_huggingface_token'")
|
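The shape handling and MHI scaling in `preprocess_ecg` can be exercised without the TorchScript models. A minimal numpy-only sketch (the helper `normalize_shape` is hypothetical; only the constants are copied from `inference.py`):

```python
import numpy as np

# Constants mirrored from inference.py
MHI_FACTOR = 1 / 0.0048
EXPECTED_LEADS, EXPECTED_SAMPLES = 12, 2500

def normalize_shape(sig: np.ndarray) -> np.ndarray:
    """Return the signal as (leads, samples), transposing a (samples, leads) input."""
    if sig.ndim != 2:
        raise ValueError(f"Expected 2D signal, got shape {sig.shape}")
    if sig.shape == (EXPECTED_SAMPLES, EXPECTED_LEADS):
        sig = sig.T  # (2500, 12) -> (12, 2500)
    if sig.shape[0] != EXPECTED_LEADS:
        raise ValueError(f"Invalid signal shape {sig.shape}")
    return sig

sig = np.ones((2500, 12), dtype=np.float32)
# Add the batch dimension and apply the MHI scaling, as the engine does
batched = normalize_shape(sig)[np.newaxis, ...] * MHI_FACTOR
print(batched.shape)  # (1, 12, 2500)
```

Either accepted orientation ends up as the `(1, 12, 2500)` tensor shape the models consume.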
requirements.txt
ADDED

@@ -0,0 +1,7 @@

gradio>=4.0.0
torch>=2.0.0
numpy>=1.21.0
matplotlib>=3.5.0
Pillow>=9.1.0  # Image.Resampling (used in visualization.py) was added in 9.1
huggingface_hub>=0.16.0
scipy>=1.7.0
test_inference.py
ADDED

@@ -0,0 +1,115 @@

#!/usr/bin/env python3
"""
Test script for DeepECG Inference Module
=========================================

Run this script to verify the inference engine works correctly.

Usage:
    # Set HF_TOKEN environment variable first
    export HF_TOKEN='your_huggingface_token'

    # Run the test
    python test_inference.py

Expected output:
    - Models download from HuggingFace Hub
    - Dummy signal inference completes
    - Results for all 4 models are printed
"""

import os
import sys
import numpy as np

# Ensure HF_TOKEN is set
if not os.environ.get("HF_TOKEN"):
    print("ERROR: HF_TOKEN environment variable not set")
    print("Please run: export HF_TOKEN='your_token'")
    sys.exit(1)

print("=" * 60)
print("DeepECG Inference Test")
print("=" * 60)

# Import the inference module
try:
    from inference import DeepECGInference
    print("[OK] Import successful")
except ImportError as e:
    print(f"[FAIL] Import failed: {e}")
    sys.exit(1)

# Create inference engine
try:
    engine = DeepECGInference()
    print(f"[OK] Engine created with {len(engine.class_names)} class names")
except Exception as e:
    print(f"[FAIL] Engine creation failed: {e}")
    sys.exit(1)

# Load models
print("\nLoading models from HuggingFace Hub...")
try:
    engine.load_models()
    print(f"[OK] Loaded {len(engine.models)} models")
    for name in engine.models:
        print(f"  - {name}")
except Exception as e:
    print(f"[FAIL] Model loading failed: {e}")
    sys.exit(1)

# Test with dummy signal
print("\nTesting inference with dummy signal...")
try:
    # Create dummy 10-second ECG (2500 samples at 250 Hz, 12 leads)
    dummy_signal = np.random.randn(2500, 12).astype(np.float32)

    # Run inference
    results = engine.predict(dummy_signal)

    print(f"[OK] Inference completed in {results['inference_time_ms']:.2f} ms")
    print("\nResults:")
    print(f"  - LVEF <= 40%: {results['lvef_40']:.4f}")
    print(f"  - LVEF < 50%: {results['lvef_50']:.4f}")
    print(f"  - 5-year AFib: {results['afib_5y']:.4f}")
    print(f"  - 77-class diagnosis: {len(results['diagnosis_77']['probabilities'])} probabilities")

except Exception as e:
    print(f"[FAIL] Inference failed: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# Test top-k predictions
print("\nTop 5 diagnoses:")
try:
    top_5 = engine.predict_diagnosis_top_k(dummy_signal, k=5)
    for pred in top_5["top_k_predictions"]:
        print(f"  {pred['class_name']}: {pred['probability']:.4f}")
except Exception as e:
    print(f"[FAIL] Top-k prediction failed: {e}")
    sys.exit(1)

# Test preprocessing edge cases
print("\nTesting preprocessing with different input shapes...")
try:
    # Test (2500, 12) shape
    signal_1 = np.random.randn(2500, 12).astype(np.float32)
    tensor_1 = engine.preprocess_ecg(signal_1)
    assert tensor_1.shape == (1, 12, 2500), f"Expected (1, 12, 2500), got {tensor_1.shape}"
    print(f"[OK] Shape (2500, 12) -> {tuple(tensor_1.shape)}")

    # Test (12, 2500) shape
    signal_2 = np.random.randn(12, 2500).astype(np.float32)
    tensor_2 = engine.preprocess_ecg(signal_2)
    assert tensor_2.shape == (1, 12, 2500), f"Expected (1, 12, 2500), got {tensor_2.shape}"
    print(f"[OK] Shape (12, 2500) -> {tuple(tensor_2.shape)}")

except Exception as e:
    print(f"[FAIL] Preprocessing test failed: {e}")
    sys.exit(1)

print("\n" + "=" * 60)
print("ALL TESTS PASSED!")
print("=" * 60)
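The top-k selection the test exercises via `predict_diagnosis_top_k` reduces to a single `argsort` over the sigmoid outputs. A standalone sketch with made-up probabilities and illustrative class names (not the real 77-class label set):

```python
import numpy as np

# Dummy sigmoid outputs standing in for the 77-class probabilities
probs = [0.10, 0.92, 0.31, 0.74, 0.05]
names = ["NSR", "AFib", "AFlut", "VT", "PVC"]

# Same selection as predict_diagnosis_top_k: indices of the k largest
# probabilities, in descending order
k = 3
top_idx = np.argsort(probs)[::-1][:k]
ranked = [(names[i], probs[i]) for i in top_idx]
print(ranked)  # [('AFib', 0.92), ('VT', 0.74), ('AFlut', 0.31)]
```

Because the 77-class head is multi-label (independent sigmoids), the top-k probabilities need not sum to 1.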
visualization.py
ADDED

@@ -0,0 +1,433 @@

"""
HeartWatch AI Visualization Module

This module provides visualization functions for ECG analysis including:
- 12-lead ECG waveform plotting with clinical layout
- Diagnosis probability bar charts
- Risk assessment gauges
- ECG thumbnail generation for galleries
"""

import io

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Wedge
from PIL import Image


# Standard 12-lead ECG names in clinical order
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Clinical layout: 4 columns x 3 rows
# Col 1: I, II, III | Col 2: aVR, aVL, aVF | Col 3: V1, V2, V3 | Col 4: V4, V5, V6
LEAD_LAYOUT = [
    ['I', 'aVR', 'V1', 'V4'],
    ['II', 'aVL', 'V2', 'V5'],
    ['III', 'aVF', 'V3', 'V6']
]


def plot_ecg_waveform(ecg_signal: np.ndarray, sample_rate: int = 250,
                      title: str = "12-Lead ECG") -> plt.Figure:
    """
    Plot a 12-lead ECG waveform in clinical layout format.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12).
        Each row/column represents one of the 12 standard leads.
    sample_rate : int, optional
        Sampling rate in Hz, default 250
    title : str, optional
        Figure title, default "12-Lead ECG"

    Returns
    -------
    plt.Figure
        Matplotlib figure with 4x3 ECG layout
    """
    # Ensure correct shape (12, n_samples)
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    n_samples = ecg_signal.shape[1]

    # 2.5 seconds per column
    samples_per_col = int(2.5 * sample_rate)

    # Create figure with clinical dimensions
    fig, axes = plt.subplots(3, 4, figsize=(14, 8))
    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.98)

    # Create lead index mapping
    lead_to_idx = {name: i for i, name in enumerate(LEAD_NAMES)}

    for row in range(3):
        for col in range(4):
            ax = axes[row, col]
            lead_name = LEAD_LAYOUT[row][col]
            lead_idx = lead_to_idx[lead_name]

            # Get signal segment for this column (2.5 s)
            start_sample = 0
            end_sample = min(samples_per_col, n_samples)

            signal_segment = ecg_signal[lead_idx, start_sample:end_sample]
            time_segment = np.arange(len(signal_segment)) / sample_rate

            # Set up ECG paper grid background (pink/red)
            ax.set_facecolor('#fff5f5')

            # Major grid (0.5 s, 0.5 mV equivalent)
            ax.set_axisbelow(True)
            ax.grid(True, which='major', color='#ffcccc', linewidth=0.8, linestyle='-')
            ax.grid(True, which='minor', color='#ffe6e6', linewidth=0.4, linestyle='-')

            # Set tick spacing for major/minor grids
            ax.set_xticks(np.arange(0, 2.6, 0.5))
            ax.set_xticks(np.arange(0, 2.6, 0.1), minor=True)

            # Calculate y-limits based on signal range
            signal_min, signal_max = signal_segment.min(), signal_segment.max()
            signal_range = signal_max - signal_min
            if signal_range < 0.1:
                signal_range = 2.0  # Default range if signal is flat
            padding = signal_range * 0.1
            y_min = signal_min - padding
            y_max = signal_max + padding

            # Set y-ticks for grid
            y_tick_spacing = signal_range / 4
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing, y_tick_spacing))
            ax.set_yticks(np.arange(y_min, y_max + y_tick_spacing / 5, y_tick_spacing / 5), minor=True)

            # Plot ECG waveform
            ax.plot(time_segment, signal_segment, color='black', linewidth=0.8)

            # Add lead label
            ax.text(0.02, 0.98, lead_name, transform=ax.transAxes,
                    fontsize=10, fontweight='bold', verticalalignment='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                              edgecolor='none', alpha=0.7))

            # Set axis limits
            ax.set_xlim(0, 2.5)
            ax.set_ylim(y_min, y_max)

            # Remove tick labels for a cleaner look (except bottom row and left column)
            if row < 2:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel('Time (s)', fontsize=8)

            if col > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel('Amplitude (mV)', fontsize=8)

            ax.tick_params(axis='both', which='both', labelsize=6)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    return fig


def plot_diagnosis_bars(diagnosis_77: dict, top_n: int = 10,
                        ground_truth: list = None) -> plt.Figure:
    """
    Plot horizontal bar chart of diagnosis probabilities.

    Parameters
    ----------
    diagnosis_77 : dict
        Dictionary mapping diagnosis names to probabilities (0-1)
    top_n : int, optional
        Number of top diagnoses to display, default 10
    ground_truth : list, optional
        List of ground truth diagnosis names to mark with a star

    Returns
    -------
    plt.Figure
        Matplotlib figure with horizontal bar chart
    """
    if ground_truth is None:
        ground_truth = []

    # Sort diagnoses by probability (descending)
    sorted_diagnoses = sorted(diagnosis_77.items(), key=lambda x: x[1], reverse=True)
    top_diagnoses = sorted_diagnoses[:top_n]

    # Extract names and probabilities
    names = [d[0] for d in top_diagnoses]
    probs = [d[1] for d in top_diagnoses]

    # Determine colors based on probability thresholds
    colors = []
    for p in probs:
        if p >= 0.7:
            colors.append('#2ecc71')  # Green for high confidence
        elif p >= 0.3:
            colors.append('#f1c40f')  # Yellow for moderate
        else:
            colors.append('#95a5a6')  # Gray for low confidence

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create horizontal bar chart
    y_pos = np.arange(len(names))
    bars = ax.barh(y_pos, probs, color=colors, edgecolor='black', linewidth=0.5)

    # Add probability labels on bars
    for bar, prob in zip(bars, probs):
        width = bar.get_width()
        label_x = width + 0.02 if width < 0.85 else width - 0.08
        label_color = 'black' if width < 0.85 else 'white'
        ax.text(label_x, bar.get_y() + bar.get_height() / 2,
                f'{prob:.1%}', va='center', fontsize=9, color=label_color)

    # Mark ground truth with a star
    display_names = []
    for name in names:
        if name in ground_truth:
            display_names.append(f'{name} \u2605')  # Unicode star
        else:
            display_names.append(name)

    # Set y-axis labels
    ax.set_yticks(y_pos)
    ax.set_yticklabels(display_names, fontsize=9)

    # Set axis limits and labels
    ax.set_xlim(0, 1.0)
    ax.set_xlabel('Probability', fontsize=11)
    ax.set_title(f'Diagnosis Probabilities (Top {top_n})',
                 fontsize=12, fontweight='bold', pad=10)

    # Add legend
    legend_elements = [
        mpatches.Patch(facecolor='#2ecc71', edgecolor='black', label='High (\u2265 70%)'),
        mpatches.Patch(facecolor='#f1c40f', edgecolor='black', label='Moderate (30-70%)'),
        mpatches.Patch(facecolor='#95a5a6', edgecolor='black', label='Low (< 30%)')
    ]
    if ground_truth:
        legend_elements.append(mpatches.Patch(facecolor='white', edgecolor='white',
                                              label='\u2605 = Ground Truth'))
    ax.legend(handles=legend_elements, loc='lower right', fontsize=8)

    # Add grid for readability
    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    # Invert y-axis so the highest probability is at the top
    ax.invert_yaxis()

    plt.tight_layout()
    return fig


def _draw_gauge(ax, value: float, title: str):
    """
    Draw a semicircular gauge on the given axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on
    value : float
        Value between 0 and 1 to display
    title : str
        Gauge title
    """
    # Clear axes
    ax.clear()
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-0.3, 1.3)
    ax.set_aspect('equal')
    ax.axis('off')

    # Create gradient background arc (green -> yellow -> red)
    n_segments = 100
    for i in range(n_segments):
        theta1 = 180 - i * (180 / n_segments)
        theta2 = 180 - (i + 1) * (180 / n_segments)

        # Calculate color based on position
        pos = i / n_segments
        if pos < 0.3:
            # Green zone
            color = '#2ecc71'
        elif pos < 0.6:
            # Transition from green to yellow
            t = (pos - 0.3) / 0.3
            r = int(46 + t * (241 - 46))
            g = int(204 + t * (196 - 204))
            b = int(113 + t * (15 - 113))
            color = f'#{r:02x}{g:02x}{b:02x}'
        else:
            # Transition from yellow to red
            t = (pos - 0.6) / 0.4
            r = int(241 + t * (231 - 241))
            g = int(196 - t * 196)
            b = int(15 - t * 15)
            color = f'#{r:02x}{g:02x}{b:02x}'

        wedge = Wedge((0, 0), 1.0, theta2, theta1, width=0.3, facecolor=color,
                      edgecolor='white', linewidth=0.5)
        ax.add_patch(wedge)

    # Draw needle
    needle_angle = 180 - value * 180
    needle_rad = np.radians(needle_angle)
    needle_length = 0.85
    needle_x = needle_length * np.cos(needle_rad)
    needle_y = needle_length * np.sin(needle_rad)

    ax.annotate('', xy=(needle_x, needle_y), xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='#2c3e50', lw=2))

    # Draw center circle
    center_circle = plt.Circle((0, 0), 0.1, color='#2c3e50', zorder=5)
    ax.add_patch(center_circle)

    # Add value text
    ax.text(0, -0.15, f'{value * 100:.0f}%', ha='center', va='top',
            fontsize=14, fontweight='bold', color='#2c3e50')

    # Add title
    ax.text(0, 1.2, title, ha='center', va='bottom',
            fontsize=11, fontweight='bold', color='#2c3e50')

    # Add risk labels
    ax.text(-1.1, -0.05, 'Low', ha='center', va='top', fontsize=8, color='#27ae60')
    ax.text(0, 1.05, 'Moderate', ha='center', va='bottom', fontsize=8, color='#f39c12')
    ax.text(1.1, -0.05, 'High', ha='center', va='top', fontsize=8, color='#c0392b')

    # Add threshold markers
    for pct, label in [(0.3, '30%'), (0.6, '60%')]:
        angle = 180 - pct * 180
        rad = np.radians(angle)
        x_outer = 1.05 * np.cos(rad)
        y_outer = 1.05 * np.sin(rad)
        ax.text(x_outer, y_outer, label, ha='center', va='center', fontsize=7, color='#7f8c8d')


def plot_risk_gauges(lvef_40: float, lvef_50: float, afib_5y: float) -> plt.Figure:
    """
    Plot risk assessment gauges for LVEF and AFib predictions.

    Parameters
    ----------
    lvef_40 : float
        Probability (0-1) of LVEF <= 40%
    lvef_50 : float
        Probability (0-1) of LVEF < 50%
    afib_5y : float
        Probability (0-1) of AFib within 5 years

    Returns
    -------
    plt.Figure
        Matplotlib figure with 3 semicircular gauges
    """
    # Clamp values to [0, 1]
    lvef_40 = np.clip(lvef_40, 0, 1)
    lvef_50 = np.clip(lvef_50, 0, 1)
    afib_5y = np.clip(afib_5y, 0, 1)

    # Create figure with 3 subplots
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    fig.suptitle('Risk Assessment', fontsize=14, fontweight='bold', y=0.98)

    # Draw each gauge (labels match the model definitions in inference.py)
    _draw_gauge(axes[0], lvef_40, 'LVEF <= 40%')
    _draw_gauge(axes[1], lvef_50, 'LVEF < 50%')
    _draw_gauge(axes[2], afib_5y, 'AFib (5-year)')

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


def generate_thumbnail(ecg_signal: np.ndarray, label: str,
                       sample_rate: int = 250) -> Image.Image:
    """
    Generate a thumbnail preview image of Lead II for gallery display.

    Parameters
    ----------
    ecg_signal : np.ndarray
        ECG signal array of shape (12, n_samples) or (n_samples, 12)
    label : str
        Label text to display on the thumbnail
    sample_rate : int, optional
        Sampling rate in Hz, default 250

    Returns
    -------
    PIL.Image.Image
        Thumbnail image approximately 300x150 pixels
    """
    # Ensure correct shape (12, n_samples)
    if ecg_signal.shape[0] != 12:
        if ecg_signal.shape[1] == 12:
            ecg_signal = ecg_signal.T
        else:
            raise ValueError(f"ECG signal must have 12 leads, got shape {ecg_signal.shape}")

    # Extract Lead II (index 1)
    lead_ii = ecg_signal[1, :]
    n_samples = len(lead_ii)
    time = np.arange(n_samples) / sample_rate

    # Create figure with appropriate DPI for ~300x150 pixel output
    fig, ax = plt.subplots(figsize=(3, 1.5), dpi=100)

    # Clean, minimal design
    ax.plot(time, lead_ii, color='#e74c3c', linewidth=1.0)

    # Set background
    ax.set_facecolor('#fafafa')
    fig.patch.set_facecolor('#fafafa')

    # Remove axes for a clean look
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Add label
    ax.text(0.02, 0.98, label, transform=ax.transAxes,
            fontsize=8, fontweight='bold', verticalalignment='top',
            color='#2c3e50')

    # Add "Lead II" indicator
    ax.text(0.98, 0.02, 'Lead II', transform=ax.transAxes,
            fontsize=6, verticalalignment='bottom', horizontalalignment='right',
            color='#7f8c8d')

    plt.tight_layout(pad=0.2)

    # Convert to PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=fig.get_facecolor(),
                edgecolor='none', bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)

    # Resize to exactly 300x150 pixels (Image.Resampling requires Pillow >= 9.1)
    img = img.resize((300, 150), Image.Resampling.LANCZOS)

    return img


if __name__ == '__main__':
    # Quick test
    print("Visualization module loaded successfully.")
    print("Available functions: plot_ecg_waveform, plot_diagnosis_bars, plot_risk_gauges, generate_thumbnail")
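Note that `DeepECGInference.predict()` returns the 77-class output as parallel lists, while `plot_diagnosis_bars` expects a `{name: probability}` dict; the Gradio app presumably bridges the two. A minimal sketch of that glue (the helper `to_diagnosis_dict` is hypothetical, not part of this commit):

```python
def to_diagnosis_dict(diagnosis_77: dict) -> dict:
    """Convert predict()'s diagnosis_77 entry into the dict plot_diagnosis_bars expects."""
    probs = diagnosis_77["probabilities"]
    # Fall back to placeholder names if class_names was None
    names = diagnosis_77["class_names"] or [f"Class_{i}" for i in range(len(probs))]
    return dict(zip(names, probs))

# Toy two-class example standing in for the real 77-class output
result = {"probabilities": [0.8, 0.1], "class_names": ["AFib", "NSR"]}
print(to_diagnosis_dict(result))  # {'AFib': 0.8, 'NSR': 0.1}
```

The resulting dict can be passed straight to `plot_diagnosis_bars(...)`, with the optional `ground_truth` list used for the sample gallery tab.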