---
language: en
license: apache-2.0
app_type: gradio
space: Tumo505/SSL-ECG-Classification
datasets:
- ptb-xl
metrics:
- auroc
- accuracy
tags:
- ecg
- medical
- time-series
- classification
- self-supervised-learning
- ssl
- cardiac
- healthcare
model-index:
- name: SSL-ECG-Classifier
results:
- task:
name: Time Series Classification
type: tabular-classification
dataset:
name: PTB-XL
type: ptb-xl
split: test
args:
fold: 10
metrics:
- name: AUROC
type: auroc
value: 0.8717
- name: Accuracy
type: accuracy
value: 0.8234
inference: true
widget:
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_normal.csv
example_title: "Normal ECG (NORM)"
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_mi.csv
example_title: "Myocardial Infarction (MI)"
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_sttc.csv
example_title: "ST/T Changes (STTC)"
---
# SSL-ECG-Classifier: Self-Supervised Learning for ECG Classification
**Self-Supervised Learning (SSL)** pre-trained model for ECG cardiovascular disease classification.
## Model Overview
| Property | Value |
|----------|-------|
| **Framework** | SimCLR |
| **Test AUROC** | 0.8717 |
| **Test Accuracy** | 0.8234 |
| **Dataset** | PTB-XL (21.8K ECGs) |
| **Fine-tuning** | 10% labeled data (1,747 samples) |
| **Input** | 12-lead ECG @ 100 Hz (5,000 samples) |
| **Output** | 5-class classification |
## Classes Predicted
- **NORM**: Normal ECG
- **MI**: Myocardial Infarction
- **STTC**: ST/T Changes
- **HYP**: Hypertrophy (LVH)
- **CD**: Conduction Disturbances
## Quick Start
### Python (Transformers)
```python
import torch
from transformers import AutoModel

# Load the model (trust_remote_code is required for the custom architecture)
model = AutoModel.from_pretrained("Tumo505/SSL-ECG-Classification", trust_remote_code=True)
model.eval()

# Prepare a 12-lead ECG: (batch_size, 12 leads, 5000 samples)
ecg = torch.randn(1, 12, 5000)

# Predict
with torch.no_grad():
    output = model(ecg)
    logits = output["logits"]
    probs = torch.softmax(logits, dim=-1)

classes = ["NORM", "MI", "STTC", "HYP", "CD"]
prediction = classes[probs.argmax(dim=-1)[0]]
confidence = probs.max().item()
print(f"Prediction: {prediction} ({confidence:.1%})")
```
### Try Online
Click the **"Use this model"** button above to try the model in the Gradio Space.
### API Endpoint (Deploy)
Click the **"Deploy"** button to get a live inference endpoint:
```bash
curl -X POST https://your-api-url.hf.space/api/predict \
-H "Authorization: Bearer YOUR_HF_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"inputs": [[[... 12-lead ECG array ...]]]
}'
```
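The same request can be built from Python. A minimal sketch of the payload construction (the endpoint URL and token in the `curl` example above are placeholders, and the exact payload schema accepted by your deployed endpoint may differ):

```python
import json

# Build the request payload for a single 12-lead ECG.
# Shape convention follows the model card: (batch, 12 leads, 5000 samples).
num_leads, num_samples = 12, 5000
ecg = [[0.0] * num_samples for _ in range(num_leads)]  # placeholder signal

payload = {"inputs": [ecg]}  # one-element batch
body = json.dumps(payload)

# The POST itself (URL and token are placeholders, as in the curl example):
# import requests
# r = requests.post("https://your-api-url.hf.space/api/predict",
#                   headers={"Authorization": "Bearer YOUR_HF_TOKEN",
#                            "Content-Type": "application/json"},
#                   data=body)
```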
## Model Architecture
```
Input (B × 12 × 5000)
        ↓
1D CNN Encoder
  - Conv1d(12 → 32) + BatchNorm + ReLU + MaxPool
  - Conv1d(32 → 64) + BatchNorm + ReLU + MaxPool
  - Conv1d(64 → 128) + BatchNorm + ReLU
  - AdaptiveAvgPool1d(1) + Flatten
        ↓
Projection Head (128-dim embedding)
        ↓
Classification Head (5 classes)
        ↓
Output (B × 5) logits
```
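The stack above can be sketched in PyTorch. This is an illustration only: the card gives the channel progression (12 → 32 → 64 → 128 → 5) but not the kernel or pool sizes, so those are assumptions, and the SSL projection head (used during pre-training) is omitted in favor of the inference path:

```python
import torch
import torch.nn as nn

class ECGClassifier(nn.Module):
    """Sketch of the encoder + classification head from the diagram above.
    Kernel/pool sizes are assumptions; only the channel widths are given."""
    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(12, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(),  # -> (B, 128)
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.encoder(x))  # (B, num_classes) logits

model = ECGClassifier()
logits = model(torch.randn(2, 12, 5000))
```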
## Performance Metrics
### Test Set Results (PTB-XL Fold 10: 3,044 recordings; labels are multi-label, so per-class supports sum to 4,044)
```
Class | Precision | Recall | F1-Score | Support
----------|-----------|--------|----------|----------
NORM | 0.897 | 0.882 | 0.889 | 1,275
MI | 0.856 | 0.834 | 0.845 | 904
STTC | 0.871 | 0.859 | 0.865 | 776
HYP | 0.812 | 0.798 | 0.805 | 356
CD | 0.843 | 0.866 | 0.854 | 733
----------|-----------|--------|----------|----------
Macro Avg | 0.856 | 0.848 | 0.852 | 4,044
```
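The macro averages in the table are the unweighted means of the per-class rows, which can be checked directly:

```python
# Per-class (precision, recall, F1) from the table above.
rows = {
    "NORM": (0.897, 0.882, 0.889),
    "MI":   (0.856, 0.834, 0.845),
    "STTC": (0.871, 0.859, 0.865),
    "HYP":  (0.812, 0.798, 0.805),
    "CD":   (0.843, 0.866, 0.854),
}

# Unweighted mean over classes for each column.
macro = [round(sum(col) / len(rows), 3) for col in zip(*rows.values())]
# macro -> [0.856, 0.848, 0.852], matching the Macro Avg row
```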
### Comparison to Baselines
| Model | Framework | AUROC | Accuracy | Notes |
|-------|-----------|-------|----------|--------|
| **SimCLR (This)** | **SSL + Supervised** | **0.8717** | **0.8234** | **Recommended** |
| BYOL SSL | SSL momentum | 0.8565 | 0.8134 | Alternative |
| Supervised CNN | None | 0.8606 | 0.8193 | Baseline |
## Training Details
### Pre-training (Unsupervised SSL)
- **Framework:** SimCLR
- **Epochs:** 20
- **Batch Size:** 128
- **Optimizer:** Adam (lr=1e-3)
- **Loss:** Contrastive (NT-Xent with τ=0.07)
- **Data:** All PTB-XL training folds (no labels used)
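The NT-Xent objective named above can be sketched as follows. This is a minimal generic SimCLR-style implementation, not the repository's exact code; only the temperature τ=0.07 comes from the card:

```python
import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """Minimal NT-Xent (SimCLR) loss: z1[i] and z2[i] are embeddings of two
    augmented views of the same ECG; all other pairs act as negatives."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit norm
    sim = z @ z.t() / tau                               # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))                   # exclude self-similarity
    # The positive for sample i is its other view at index i+N (and vice versa)
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return F.cross_entropy(sim, targets)

loss = nt_xent(torch.randn(8, 128), torch.randn(8, 128))
```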
### Fine-tuning (Supervised)
- **Labeled Data:** 1,747 samples (10% of fold 1-8)
- **Epochs:** 20 with early stopping (patience=5)
- **Batch Size:** 32
- **Optimizer:** Adam (lr=5e-4)
- **Loss:** Focal Loss with class weights
- **Augmentations:** Training-time augmentations (same as pre-training)
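The focal loss mentioned above can be sketched in its standard multi-class form (Lin et al.). The card does not state the focusing parameter, so `gamma=2.0` is an assumption; `weight` carries the per-class weights:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, weight=None, gamma=2.0):
    """Standard multi-class focal loss with optional class weights.
    gamma=2.0 is the common default; the value used here is not stated."""
    logp = F.log_softmax(logits, dim=-1)
    pt = logp.gather(1, targets.unsqueeze(1)).squeeze(1).exp()  # true-class prob
    ce = F.nll_loss(logp, targets, weight=weight, reduction="none")
    return ((1.0 - pt) ** gamma * ce).mean()

logits = torch.randn(4, 5)
targets = torch.tensor([0, 2, 1, 4])
# With gamma=0 and no weights, focal loss reduces to plain cross-entropy.
plain_ce = F.cross_entropy(logits, targets)
```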
### Domain-Adaptive Augmentations
Applied during SSL pre-training:
1. **Frequency warping** (±5% heart rate variation)
2. **Medical mixup** (ECG-aware blending of two signals)
3. **Bandpass filtering** (physiologically grounded)
4. **Segment CutMix** (temporal masking)
5. **Motion artifacts** (baseline wander simulation)
6. **Per-channel noise** (independent Gaussian)
7. **Temporal dropout** (with interpolation)
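Two of the simpler augmentations above (per-channel noise and temporal dropout with interpolation) might look roughly like this; the noise scale, window width, and the exact interpolation scheme are illustrative assumptions, not the repository's parameters:

```python
import numpy as np

def per_channel_noise(ecg: np.ndarray, sigma: float = 0.01) -> np.ndarray:
    """Augmentation 6: independent Gaussian noise per lead. ecg is (12, T)."""
    return ecg + np.random.normal(0.0, sigma, size=ecg.shape)

def temporal_dropout(ecg: np.ndarray, width: int = 100) -> np.ndarray:
    """Augmentation 7: drop a random time window and linearly interpolate
    across it instead of zeroing (illustrative interpretation)."""
    out = ecg.copy()
    T = ecg.shape[1]
    start = np.random.randint(0, T - width)
    end = start + width
    for lead in range(ecg.shape[0]):
        out[lead, start:end] = np.linspace(
            ecg[lead, start], ecg[lead, end - 1], width)
    return out

ecg = np.random.randn(12, 1000)
aug = temporal_dropout(per_channel_noise(ecg))
```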
## Dataset
### PTB-XL v1.0.3
**Source:** https://www.physionet.org/content/ptb-xl/1.0.3/
- **Total ECGs:** 21,799
- **Unique Patients:** 18,869
- **Sampling Rate:** 500 Hz → downsampled to 100 Hz
- **Leads:** 12-lead standard
- **Duration:** ~10 seconds per recording
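The 500 Hz → 100 Hz step is a factor-5 decimation. A naive numpy sketch (a real pipeline should apply an anti-aliasing low-pass filter first, e.g. `scipy.signal.decimate`):

```python
import numpy as np

def downsample(ecg: np.ndarray, factor: int = 5) -> np.ndarray:
    """Naive 500 Hz -> 100 Hz decimation by striding every `factor`-th sample.
    (No anti-aliasing; for illustration only.)"""
    return ecg[:, ::factor]

ecg_500hz = np.random.randn(12, 5000)   # 10 s of 12-lead ECG at 500 Hz
ecg_100hz = downsample(ecg_500hz)       # -> (12, 1000) at 100 Hz
```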
**Class Distribution:**
| Class | Count | Percentage |
|-------|-------|-----------|
| NORM | 9,514 | 43.7% |
| MI | 5,469 | 25.1% |
| STTC | 5,235 | 24.0% |
| CD | 4,898 | 22.5% |
| HYP | 2,649 | 12.2% |
*Note: Recordings can carry multiple labels, so counts and percentages sum to more than the dataset total.*
**Splits Used:**
- **Training**: Folds 1-8 (17,536 samples)
- **Validation**: Fold 9 (1,791 samples)
- **Test**: Fold 10 (3,044 samples)
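PTB-XL ships a stratified fold assignment per record (the `strat_fold` column of `ptbxl_database.csv`), and the splits above follow it. A sketch of the split logic, with a toy fold list standing in for the metadata table:

```python
def split_by_fold(records):
    """records: list of (record_id, fold) pairs, folds 1-10 as in PTB-XL.
    Train = folds 1-8, validation = fold 9, test = fold 10."""
    train = [r for r, f in records if f <= 8]
    val   = [r for r, f in records if f == 9]
    test  = [r for r, f in records if f == 10]
    return train, val, test

# Toy stand-in for the metadata: 100 records cycling through folds 1-10.
records = [(i, (i % 10) + 1) for i in range(100)]
train, val, test = split_by_fold(records)
```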
## Limitations & Biases
### Limitations
**Not validated for clinical use** - Research purposes only
- Trained exclusively on PTB-XL; generalization to other datasets unknown
- 12-lead ECG format required; doesn't work with 6-lead or converted signals
- 10% labeled data regime may not reflect full model capacity
- Works only for the 5 trained classes
### Potential Biases
- **Geographic bias:** Primarily European patient population (PTB-XL)
- **Hospital bias:** Data from hospital patients (not general population)
- **Class imbalance:** NORM over-represented, HYP under-represented
- **Demographic:** Skew toward older patients; male/female ratio not controlled
## Environmental Impact
- **Training:** ~12 GPU hours on RTX 5070 Ti
- **CO2 Emissions:** ~0.5 kg (estimated)
- **Inference:** ~50 ms per 10-second ECG on GPU
## License
Apache 2.0 - See LICENSE file in repository
## Acknowledgments
- PTB-XL Dataset: Physionet, Wagner et al. (2020)
- SimCLR Framework: Chen et al. (2020)
- Implementation: Built with PyTorch & Hugging Face
## Model Card Contact
- **Author:** Tumo Kgabeng
- **GitHub:** https://github.com/Tumo505/SSL-for-ECG-classification
## Changelog
### v1.0 (2026-04-18)
- Initial release
- SimCLR pre-training + supervised fine-tuning
- 10% labeled data regime
- Test AUROC: 0.8717
---
**Questions?** Open an issue on [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)