Spaces:
Runtime error
Runtime error
Upload 8 files
Browse files- Dockerfile +22 -0
- README.md +277 -10
- app.py +257 -0
- binary_threshold.json +3 -0
- mi_best.pth +3 -0
- multilabel_best.pth +3 -0
- multilabel_thresholds.json +7 -0
- requirements.txt +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
build-essential \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Copy requirements first for better caching
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
# Copy project files
|
| 15 |
+
COPY . .
|
| 16 |
+
|
| 17 |
+
# Expose Streamlit default port (HF Spaces expects 7860)
|
| 18 |
+
EXPOSE 7860
|
| 19 |
+
|
| 20 |
+
HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
|
| 21 |
+
|
| 22 |
+
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
|
README.md
CHANGED
|
@@ -1,10 +1,277 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
pinned: false
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: HMT-ECGNet
|
| 3 |
+
colorFrom: red
|
| 4 |
+
colorTo: blue
|
| 5 |
+
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
---
|
| 10 |
+
# HMT-ECGNet
|
| 11 |
+
**Lightweight Hierarchical Multi-Lead ECG Classification on PTB-XL**
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Overview
|
| 16 |
+
|
| 17 |
+
**HMT-ECGNet** is a **lightweight, hierarchical deep learning system** for automatic ECG interpretation, designed and evaluated on the **PTB-XL** dataset under **strict, leakage-free conditions**.
|
| 18 |
+
|
| 19 |
+
The project demonstrates that **carefully designed, parameter-efficient neural architectures** can achieve **competitive diagnostic performance** compared to large CNNs (e.g., ResNet) while remaining **deployable in real-world clinical and edge environments**.
|
| 20 |
+
|
| 21 |
+
This repository represents an **end-to-end ML system** — from data preprocessing and training to evaluation, inference API, and interactive visualization.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Key Contributions
|
| 26 |
+
|
| 27 |
+
- **Hierarchical multi-lead ECG modeling** (lead-wise → global aggregation)
|
| 28 |
+
- **Sub-million parameter architecture** (~0.34M params)
|
| 29 |
+
- **Strict PTB-XL official splits** (no patient leakage)
|
| 30 |
+
- **Honest evaluation** (no test-set threshold tuning)
|
| 31 |
+
- **End-to-end deployment demo** (FastAPI + Streamlit)
|
| 32 |
+
- **Baseline comparison with ResNet**
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## Problem Statement
|
| 37 |
+
|
| 38 |
+
ECG classification is typically addressed using:
|
| 39 |
+
- very large CNNs (10–60M parameters), or
|
| 40 |
+
- Transformer-based architectures with heavy compute requirements.
|
| 41 |
+
|
| 42 |
+
However, such models:
|
| 43 |
+
- are difficult to deploy on **edge / wearable devices**,
|
| 44 |
+
- often over-report performance due to **data leakage**,
|
| 45 |
+
- ignore **realistic performance ceilings** caused by label ambiguity.
|
| 46 |
+
|
| 47 |
+
> **Goal:**
|
| 48 |
+
> Can a **lightweight, hierarchical neural network** achieve strong diagnostic performance on PTB-XL when evaluated correctly?
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Architecture: HMT-ECGNet
|
| 53 |
+
|
| 54 |
+
### High-Level Design
|
| 55 |
+
|
| 56 |
+
``` markdown
|
| 57 |
+
|
| 58 |
+
├─ 12-Lead ECG (10s)
|
| 59 |
+
│
|
| 60 |
+
├─ Shared per-lead temporal encoder
|
| 61 |
+
│
|
| 62 |
+
├─ Lead-wise feature tokens
|
| 63 |
+
│
|
| 64 |
+
├─ Hierarchical cross-lead aggregation
|
| 65 |
+
│
|
| 66 |
+
├─ Global temporal pooling
|
| 67 |
+
│
|
| 68 |
+
└─ Classification head
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+

|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
### Design Principles
|
| 78 |
+
|
| 79 |
+
- **Per-lead temporal modeling** with shared weights
|
| 80 |
+
- **Hierarchical aggregation** instead of heavy attention
|
| 81 |
+
- **Explicit separation of temporal and spatial modeling**
|
| 82 |
+
- **Parameter efficiency first**, accuracy second
|
| 83 |
+
|
| 84 |
+
**Total parameters:** ~**338K**
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Training Protocol
|
| 89 |
+
|
| 90 |
+
- Optimizer: **AdamW**
|
| 91 |
+
- Learning rate schedule: **Cosine Annealing**
|
| 92 |
+
- Loss:
|
| 93 |
+
- Multi-label: `AsymmetricFocalLoss` with class balancing
|
| 94 |
+
- Binary: `BCEWithLogitsLoss`
|
| 95 |
+
- Regularization:
|
| 96 |
+
- Signal preprocessing
|
| 97 |
+
- Early stopping
|
| 98 |
+
- Reproducibility:
|
| 99 |
+
- Fixed random seeds
|
| 100 |
+
- Deterministic splits
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## Results
|
| 105 |
+
|
| 106 |
+
### Multi-Label Classification (Test Set)
|
| 107 |
+
|
| 108 |
+
| Metric | HMT-ECGNet |
|
| 109 |
+
|------|-----------|
|
| 110 |
+
| AUROC (macro) | **≈ 0.92** |
|
| 111 |
+
| AUPRC (macro) | ≈ 0.78 |
|
| 112 |
+
| F1 (macro) | ≈ **0.73** |
|
| 113 |
+
| Parameters | **0.34M** |
|
| 114 |
+
|
| 115 |
+
---
|
| 116 |
+
|
| 117 |
+
### Binary Classification — MI vs Normal (Test Set)
|
| 118 |
+
|
| 119 |
+
| Metric | HMT-ECGNet |
|
| 120 |
+
|------|-----------|
|
| 121 |
+
| AUROC | **≈ 0.98** |
|
| 122 |
+
| Accuracy | ≈ 0.92–0.93 |
|
| 123 |
+
| F1 | ≈ **0.89** |
|
| 124 |
+
|
| 125 |
+
**Observation:**
|
| 126 |
+
Accuracy saturates due to ambiguous ECGs, while AUROC remains high — indicating strong class separability under realistic conditions.
|
| 127 |
+
|
| 128 |
+
---
|
| 129 |
+
|
| 130 |
+
## Baseline Comparison
|
| 131 |
+
|
| 132 |
+
| Model | Params | AUROC (Multi) | F1 (Multi) |
|
| 133 |
+
|------|--------|--------------|------------|
|
| 134 |
+
| **ResNet-1D** | ~8.7M | ≈ 0.90 | ≈ 0.70 |
|
| 135 |
+
| **HMT-ECGNet (ours)** | **0.34M** | **≈ 0.92** | **≈ 0.73** |
|
| 136 |
+
|
| 137 |
+
**HMT-ECGNet outperforms ResNet while using ~25× fewer parameters**
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## Deployment Demo
|
| 142 |
+
|
| 143 |
+
Due to dataset licensing and size constraints, this project is not deployed as a public live demo.
|
| 144 |
+
|
| 145 |
+
However, the **full inference and visualization pipeline is implemented and reproducible locally**.
|
| 146 |
+
|
| 147 |
+
To launch the interactive ECG visualization and AI diagnosis interface:
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
streamlit run app.py
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
This repository includes a **production-style demo**:
|
| 154 |
+
|
| 155 |
+
- **FastAPI** inference server
|
| 156 |
+
- **Streamlit** UI
|
| 157 |
+
- Live ECG visualization
|
| 158 |
+
- Real-time predictions
|
| 159 |
+
- MI risk screening
|
| 160 |
+
- Uses **unseen PTB-XL test ECGs**
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
[](https://github.com/MahboobAlam0/hmt_ecg_healthmonitoringsystem/issues/1#issue-3938528989)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
## Dataset
|
| 170 |
+
|
| 171 |
+
### PTB-XL (PhysioNet, 2020)
|
| 172 |
+
|
| 173 |
+
- ~21,800 ECG recordings
|
| 174 |
+
- 12 leads
|
| 175 |
+
- 10 seconds per ECG
|
| 176 |
+
- Original sampling: 500 Hz (downsampled during preprocessing)
|
| 177 |
+
- Official **patient-level splits**:
|
| 178 |
+
- Train: folds 1–8
|
| 179 |
+
- Validation: fold 9
|
| 180 |
+
- Test: fold 10
|
| 181 |
+
|
| 182 |
+
### Tasks
|
| 183 |
+
|
| 184 |
+
- **Multi-label classification (5 diagnostic superclasses)**
|
| 185 |
+
- NORM, MI, STTC, CD, HYP
|
| 186 |
+
- **Binary classification**
|
| 187 |
+
- MI vs Normal
|
| 188 |
+
- Normal vs Abnormal
|
| 189 |
+
|
| 190 |
+
⚠️ **Important:**
|
| 191 |
+
All experiments strictly follow official PTB-XL splits.
|
| 192 |
+
There is **no patient leakage**, **no test-set tuning**, and **no post-hoc threshold optimization**.
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
## Error Analysis & Insights
|
| 197 |
+
|
| 198 |
+
- Ensemble models improve **stability**, not accuracy
|
| 199 |
+
- Remaining errors are **systematic**, not variance-driven
|
| 200 |
+
- Confirms a **performance ceiling** on PTB-XL due to:
|
| 201 |
+
- label ambiguity,
|
| 202 |
+
- inter-observer disagreement,
|
| 203 |
+
- borderline ECG patterns
|
| 204 |
+
|
| 205 |
+
---
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
## Project Structure
|
| 209 |
+
|
| 210 |
+
```markdown
|
| 211 |
+
├── hmt_ecgnet/
|
| 212 |
+
├── artifacts/
|
| 213 |
+
│ ├── mi_best.pth
|
| 214 |
+
│ ├── multilabel_best.pth
|
| 215 |
+
│ ├── multilabel_thresholds.json
|
| 216 |
+
│ └── resnet_baseline.pth
|
| 217 |
+
│
|
| 218 |
+
├── .gitignore
|
| 219 |
+
├── models/
|
| 220 |
+
│ ├── hmt_ecgnet.py
|
| 221 |
+
│ └── resnet1d.py
|
| 222 |
+
│
|
| 223 |
+
├── api.py
|
| 224 |
+
├── app.py
|
| 225 |
+
├── dataset.py
|
| 226 |
+
├── train_multilabel.py
|
| 227 |
+
├── train_binary.py
|
| 228 |
+
├── eval_multilabel.py
|
| 229 |
+
├── eval_binary.py
|
| 230 |
+
├── threshold_search.py
|
| 231 |
+
├── threshold_search_multilabel.py
|
| 232 |
+
├── config.py
|
| 233 |
+
└── README.md
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
---
|
| 237 |
+
|
| 238 |
+
## References
|
| 239 |
+
|
| 240 |
+
1. Wagner et al.
|
| 241 |
+
**PTB-XL: A Large Publicly Available Electrocardiography Dataset**
|
| 242 |
+
*PhysioNet, 2020*
|
| 243 |
+
|
| 244 |
+
2. Ribeiro et al.
|
| 245 |
+
**Automatic diagnosis of the 12-lead ECG using deep neural networks**
|
| 246 |
+
*Nature Communications, 2020*
|
| 247 |
+
|
| 248 |
+
3. Hannun et al.
|
| 249 |
+
**Cardiologist-Level Arrhythmia Detection with Deep Neural Networks**
|
| 250 |
+
*Nature Medicine, 2019*
|
| 251 |
+
|
| 252 |
+
4. Rajpurkar et al.
|
| 253 |
+
**Cardiologist-Level Arrhythmia Detection Using Deep Neural Networks**
|
| 254 |
+
*arXiv:1707.01836*
|
| 255 |
+
|
| 256 |
+
5. Tan & Le
|
| 257 |
+
**EfficientNet: Rethinking Model Scaling for CNNs**
|
| 258 |
+
*ICML, 2019*
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
## Disclaimer
|
| 263 |
+
|
| 264 |
+
This system is **for research and demonstration purposes only**
|
| 265 |
+
and **not intended for clinical diagnosis or treatment**.
|
| 266 |
+
|
| 267 |
+
---
|
| 268 |
+
|
| 269 |
+
## Author Note
|
| 270 |
+
|
| 271 |
+
This project emphasizes:
|
| 272 |
+
- **engineering discipline**
|
| 273 |
+
- **honest evaluation**
|
| 274 |
+
- **deployment realism**
|
| 275 |
+
- and **model efficiency**
|
| 276 |
+
|
| 277 |
+
rather than leaderboard chasing.
|
app.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py — Self-contained Streamlit ECG Diagnostic App (HF Spaces compatible)
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import wfdb
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import streamlit as st
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from scipy.signal import butter, filtfilt, resample
|
| 11 |
+
|
| 12 |
+
# Local imports
|
| 13 |
+
from models.hmt_ecgnet import HMT_ECGNet
|
| 14 |
+
from transforms import preprocess_signal
|
| 15 |
+
from config import N_LEADS
|
| 16 |
+
|
| 17 |
+
# Constants
|
| 18 |
+
DIAG_CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]
|
| 19 |
+
MI_BINARY_THRESHOLD = 0.05
|
| 20 |
+
|
| 21 |
+
LEAD_NAMES = [
|
| 22 |
+
"I", "II", "III", "aVR", "aVL", "aVF",
|
| 23 |
+
"V1", "V2", "V3", "V4", "V5", "V6",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
FS_ORIG = 500
|
| 27 |
+
FS_TARGET = 100
|
| 28 |
+
DURATION_SEC = 10
|
| 29 |
+
TARGET_LEN = FS_TARGET * DURATION_SEC
|
| 30 |
+
|
| 31 |
+
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "EcgDataset")
|
| 32 |
+
ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts")
|
| 33 |
+
|
| 34 |
+
# Page Config
|
| 35 |
+
st.set_page_config(
|
| 36 |
+
page_title="ECG AI Diagnostic System",
|
| 37 |
+
page_icon="",
|
| 38 |
+
layout="wide",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# MODEL LOADING (cached — runs once)
|
| 43 |
+
|
| 44 |
+
@st.cache_resource(show_spinner="Loading AI model...")
|
| 45 |
+
def load_model():
|
| 46 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 47 |
+
|
| 48 |
+
model = HMT_ECGNet(num_classes=5, num_leads=N_LEADS).to(device)
|
| 49 |
+
|
| 50 |
+
ckpt_path = os.path.join(ARTIFACTS_DIR, "multilabel_best.pth")
|
| 51 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 52 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 53 |
+
model.eval()
|
| 54 |
+
|
| 55 |
+
return model, device
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@st.cache_data(show_spinner=False)
|
| 59 |
+
def load_thresholds():
|
| 60 |
+
path = os.path.join(ARTIFACTS_DIR, "multilabel_thresholds.json")
|
| 61 |
+
with open(path) as f:
|
| 62 |
+
return json.load(f)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# DATA DOWNLOADING & LOADING
|
| 66 |
+
|
| 67 |
+
@st.cache_data(show_spinner="Downloading PTB-XL sample data...")
|
| 68 |
+
def download_ptbxl_data():
|
| 69 |
+
"""Download PTB-XL database CSV + a subset of high-res records."""
|
| 70 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 71 |
+
|
| 72 |
+
csv_path = os.path.join(DATA_DIR, "ptbxl_database.csv")
|
| 73 |
+
if not os.path.exists(csv_path):
|
| 74 |
+
# Download the metadata CSV and SCP statements
|
| 75 |
+
wfdb.dl_database("ptb-xl", dl_dir=DATA_DIR, records="all", annotators=None)
|
| 76 |
+
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@st.cache_data(show_spinner=False)
|
| 81 |
+
def load_test_metadata():
|
| 82 |
+
csv_path = os.path.join(DATA_DIR, "ptbxl_database.csv")
|
| 83 |
+
df = pd.read_csv(csv_path)
|
| 84 |
+
df_test = df[df["strat_fold"] == 10].reset_index(drop=True)
|
| 85 |
+
return df_test
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@st.cache_data(show_spinner=False)
|
| 89 |
+
def load_and_preprocess_ecg(filename):
|
| 90 |
+
"""Load a single ECG record from PTB-XL and preprocess for display."""
|
| 91 |
+
filepath = os.path.join(DATA_DIR, filename)
|
| 92 |
+
|
| 93 |
+
# Download this specific record if not present
|
| 94 |
+
record_dir = os.path.dirname(filepath)
|
| 95 |
+
os.makedirs(record_dir, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
if not os.path.exists(filepath + ".hea"):
|
| 98 |
+
# Download just this record
|
| 99 |
+
rel_path = filename.replace("\\", "/")
|
| 100 |
+
try:
|
| 101 |
+
wfdb.dl_database(
|
| 102 |
+
"ptb-xl",
|
| 103 |
+
dl_dir=DATA_DIR,
|
| 104 |
+
records=[rel_path],
|
| 105 |
+
)
|
| 106 |
+
except Exception:
|
| 107 |
+
st.error(f"Could not download record: {filename}")
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
sig, _ = wfdb.rdsamp(filepath)
|
| 111 |
+
sig = sig.T # (12, T)
|
| 112 |
+
|
| 113 |
+
# Bandpass filter for display
|
| 114 |
+
nyq = 0.5 * FS_ORIG
|
| 115 |
+
b, a = butter(4, [0.5 / nyq, 40.0 / nyq], btype="band")
|
| 116 |
+
for i in range(12):
|
| 117 |
+
sig[i] = filtfilt(b, a, sig[i])
|
| 118 |
+
|
| 119 |
+
# Resample for display
|
| 120 |
+
sig = resample(sig, TARGET_LEN, axis=1)
|
| 121 |
+
|
| 122 |
+
return sig.astype(np.float32)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# INFERENCE (runs directly — no FastAPI needed)
|
| 126 |
+
|
| 127 |
+
def run_inference(ecg_display, model, device, thresholds):
|
| 128 |
+
"""Run model inference on preprocessed ECG data."""
|
| 129 |
+
# Re-preprocess from display signal for model input
|
| 130 |
+
ecg_for_model = preprocess_signal(ecg_display.copy())
|
| 131 |
+
|
| 132 |
+
x = torch.tensor(ecg_for_model, dtype=torch.float32).unsqueeze(0).to(device)
|
| 133 |
+
|
| 134 |
+
with torch.no_grad():
|
| 135 |
+
probs = torch.sigmoid(model(x)).cpu().numpy()[0]
|
| 136 |
+
|
| 137 |
+
result = {}
|
| 138 |
+
predicted = []
|
| 139 |
+
|
| 140 |
+
for cls, p in zip(DIAG_CLASSES, probs):
|
| 141 |
+
thr = thresholds[cls]
|
| 142 |
+
result[cls] = float(p)
|
| 143 |
+
if p >= thr:
|
| 144 |
+
predicted.append(cls)
|
| 145 |
+
|
| 146 |
+
mi_prob = float(probs[1])
|
| 147 |
+
|
| 148 |
+
return {
|
| 149 |
+
"probabilities": result,
|
| 150 |
+
"predicted_classes": predicted,
|
| 151 |
+
"mi_probability": mi_prob,
|
| 152 |
+
"mi_risk": mi_prob >= MI_BINARY_THRESHOLD,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ECG PLOTTING
|
| 157 |
+
|
| 158 |
+
def plot_ecg(ecg):
|
| 159 |
+
"""Plot full 12-lead ECG as a static figure."""
|
| 160 |
+
fig, axes = plt.subplots(12, 1, figsize=(24, 14), sharex=True)
|
| 161 |
+
|
| 162 |
+
x = np.arange(ecg.shape[1])
|
| 163 |
+
|
| 164 |
+
for i in range(12):
|
| 165 |
+
axes[i].plot(x, ecg[i], lw=1.1, color="#1f77b4")
|
| 166 |
+
axes[i].set_ylabel(
|
| 167 |
+
LEAD_NAMES[i],
|
| 168 |
+
rotation=0,
|
| 169 |
+
labelpad=28,
|
| 170 |
+
fontsize=10,
|
| 171 |
+
)
|
| 172 |
+
axes[i].grid(True, alpha=0.3)
|
| 173 |
+
|
| 174 |
+
axes[-1].set_xlabel("Time (samples)")
|
| 175 |
+
plt.tight_layout()
|
| 176 |
+
return fig
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# MAIN APP
|
| 180 |
+
|
| 181 |
+
st.title(" 12-Lead ECG AI Diagnostic System")
|
| 182 |
+
st.markdown(
|
| 183 |
+
"**Live demo on unseen PTB-XL TEST ECGs** \n"
|
| 184 |
+
"Lightweight hierarchical model • No data leakage • Realistic evaluation"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Load model & data
|
| 188 |
+
model, device = load_model()
|
| 189 |
+
thresholds = load_thresholds()
|
| 190 |
+
|
| 191 |
+
with st.spinner("Preparing PTB-XL test data..."):
|
| 192 |
+
download_ptbxl_data()
|
| 193 |
+
df_test = load_test_metadata()
|
| 194 |
+
|
| 195 |
+
if len(df_test) == 0:
|
| 196 |
+
st.error("No test data found. Please check the PTB-XL dataset.")
|
| 197 |
+
st.stop()
|
| 198 |
+
|
| 199 |
+
# Sidebar
|
| 200 |
+
st.sidebar.header("ECG Sample Selector")
|
| 201 |
+
|
| 202 |
+
sample_idx = st.sidebar.slider(
|
| 203 |
+
"Select ECG from TEST set",
|
| 204 |
+
0,
|
| 205 |
+
len(df_test) - 1,
|
| 206 |
+
0,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
row = df_test.iloc[sample_idx]
|
| 210 |
+
ecg = load_and_preprocess_ecg(row["filename_hr"])
|
| 211 |
+
|
| 212 |
+
if ecg is None:
|
| 213 |
+
st.warning("Could not load this ECG record. Try another sample.")
|
| 214 |
+
st.stop()
|
| 215 |
+
|
| 216 |
+
# ECG Display
|
| 217 |
+
st.subheader(f"ECG Sample #{sample_idx}")
|
| 218 |
+
|
| 219 |
+
fig = plot_ecg(ecg)
|
| 220 |
+
st.pyplot(fig)
|
| 221 |
+
plt.close(fig)
|
| 222 |
+
|
| 223 |
+
# AI Inference
|
| 224 |
+
st.subheader("AI Diagnosis")
|
| 225 |
+
|
| 226 |
+
with st.spinner("Running inference..."):
|
| 227 |
+
result = run_inference(ecg, model, device, thresholds)
|
| 228 |
+
|
| 229 |
+
# Results
|
| 230 |
+
st.markdown("### Per-class Probabilities")
|
| 231 |
+
|
| 232 |
+
cols = st.columns(5)
|
| 233 |
+
for col, (cls, prob) in zip(cols, result["probabilities"].items()):
|
| 234 |
+
col.metric(cls, f"{prob:.3f}")
|
| 235 |
+
|
| 236 |
+
st.markdown("### Final Predicted Classes")
|
| 237 |
+
|
| 238 |
+
if result["predicted_classes"]:
|
| 239 |
+
st.error(", ".join(result["predicted_classes"]))
|
| 240 |
+
else:
|
| 241 |
+
st.success("Normal ECG — No pathology detected")
|
| 242 |
+
|
| 243 |
+
st.markdown("### Myocardial Infarction Screening")
|
| 244 |
+
|
| 245 |
+
st.metric("MI Probability", f"{result['mi_probability']:.3f}")
|
| 246 |
+
|
| 247 |
+
if result["mi_risk"]:
|
| 248 |
+
st.error("⚠️ High likelihood of Myocardial Infarction")
|
| 249 |
+
else:
|
| 250 |
+
st.success(" No strong MI indication")
|
| 251 |
+
|
| 252 |
+
# Footer
|
| 253 |
+
st.markdown("---")
|
| 254 |
+
st.caption(
|
| 255 |
+
"⚕️ **Disclaimer:** This system is for research and demonstration only. "
|
| 256 |
+
"Not intended for clinical diagnosis or treatment."
|
| 257 |
+
)
|
binary_threshold.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mi_vs_norm": 0.05
|
| 3 |
+
}
|
mi_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1d554e14d3d5a81bbdd19b067f349d58c1eea5719915b44607345d2bd563ccb
|
| 3 |
+
size 1368226
|
multilabel_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:598ffa81fcba7d94e1b229c6c81b4fe4479b53d01a15eac90e50348bff7a4cc7
|
| 3 |
+
size 1369958
|
multilabel_thresholds.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"NORM": 0.145,
|
| 3 |
+
"MI": 0.195,
|
| 4 |
+
"STTC": 0.8749999999999999,
|
| 5 |
+
"CD": 0.055,
|
| 6 |
+
"HYP": 0.31499999999999995
|
| 7 |
+
}
|
requirements.txt
ADDED
|
Binary file (534 Bytes). View file
|
|
|