| ---
|
| language: en
|
| license: apache-2.0
|
| app_type: gradio
|
| space: Tumo505/SSL-ECG-Classification
|
| datasets:
|
| - ptb-xl
|
| metrics:
|
| - auroc
|
| - accuracy
|
| tags:
|
| - ecg
|
| - medical
|
| - time-series
|
| - classification
|
| - self-supervised-learning
|
| - ssl
|
| - cardiac
|
| - healthcare
|
| model-index:
|
| - name: SSL-ECG-Classifier
|
| results:
|
| - task:
|
| name: Time Series Classification
|
| type: tabular-classification
|
| dataset:
|
| name: PTB-XL
|
| type: ptb-xl
|
| split: test
|
| args:
|
| fold: 10
|
| metrics:
|
| - name: AUROC
|
| type: auroc
|
| value: 0.8717
|
| - name: Accuracy
|
| type: accuracy
|
| value: 0.8234
|
| inference: true
|
| widget:
|
| - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_normal.csv
|
| example_title: "Normal ECG (NORM)"
|
| - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_mi.csv
|
| example_title: "Myocardial Infarction (MI)"
|
| - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_sttc.csv
|
| example_title: "ST/T Changes (STTC)"
|
| ---
|
|
|
| # SSL-ECG-Classifier: Self-Supervised Learning for ECG Classification
|
|
|
A model pre-trained with **self-supervised learning (SSL)** and fine-tuned to classify cardiovascular conditions from 12-lead ECGs.
|
|
|
| ## Model Overview
|
|
|
| Property | Value |
|----------|-------|
| **Framework** | SimCLR |
| **Test AUROC** | 0.8717 |
| **Test Accuracy** | 0.8234 |
| **Dataset** | PTB-XL (21.8K ECGs) |
| **Fine-tuning** | 10% labeled data (1,747 samples) |
| **Input** | 12-lead ECG @ 100 Hz (5,000 samples) |
| **Output** | 5-class classification |
|
|
|
| ## Classes Predicted
|
|
|
| - **NORM**: Normal ECG
|
| - **MI**: Myocardial Infarction
|
| - **STTC**: ST/T Changes
|
- **HYP**: Hypertrophy (e.g., LVH)
|
| - **CD**: Conduction Disturbances
|
|
|
| ## Quick Start
|
|
|
| ### Python (Transformers)
|
|
|
| ```python
import torch
from transformers import AutoModel

# Load the model; trust_remote_code=True runs the custom model class stored in the repo
model = AutoModel.from_pretrained("Tumo505/SSL-ECG-Classification", trust_remote_code=True)
model.eval()

# Prepare a 12-lead ECG: (batch_size, 12 leads, 5000 samples).
# Random data is a placeholder; substitute a real recording.
ecg = torch.randn(1, 12, 5000)

# Predict
with torch.no_grad():
    output = model(ecg)
    logits = output["logits"]
    probs = torch.softmax(logits, dim=-1)

classes = ["NORM", "MI", "STTC", "HYP", "CD"]
pred_idx = probs[0].argmax().item()     # index of the most likely class
prediction = classes[pred_idx]
confidence = probs[0, pred_idx].item()  # probability of that class

print(f"Prediction: {prediction} ({confidence:.1%})")
| ```
|
|
|
| ### Try Online
|
|
|
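Click the **"Use this model"** button above, or open the Gradio Space [Tumo505/SSL-ECG-Classification](https://huggingface.co/spaces/Tumo505/SSL-ECG-Classification), to try the model in the browser.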
|
|
|
| ### API Endpoint (Deploy)
|
|
|
| Click the **"Deploy"** button to get a live inference endpoint:
|
|
|
| ```bash
curl -X POST https://your-api-url.hf.space/api/predict \
  -H "Authorization: Bearer YOUR_HF_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "inputs": [[[... 12-lead ECG array ...]]]
  }'
| ```
|
|
|
| ## Model Architecture
|
|
|
| ```
Input (B × 12 × 5000)
    ↓
1D CNN Encoder
  - Conv1d(12 → 32) + BatchNorm + ReLU + MaxPool
  - Conv1d(32 → 64) + BatchNorm + ReLU + MaxPool
  - Conv1d(64 → 128) + BatchNorm + ReLU
  - AdaptiveAvgPool1d(1) + Flatten
    ↓
Projection Head (128-dim embedding)
    ↓
Classification Head (5 classes)
    ↓
Output (B × 5) logits
| ```
|
|
|
| ## Performance Metrics
|
|
|
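### Test Set Results (PTB-XL Fold 10: 3,044 samples)

Support is the number of test records carrying each label. A record can have more than one superclass label, so the supports add up to 4,044 across the 3,044 test ECGs. The Macro Avg row is the unweighted mean over the five classes.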
|
|
|
| ```
Class     | Precision | Recall | F1-Score | Support
----------|-----------|--------|----------|--------
NORM      | 0.897     | 0.882  | 0.889    | 1,275
MI        | 0.856     | 0.834  | 0.845    | 904
STTC      | 0.871     | 0.859  | 0.865    | 776
HYP       | 0.812     | 0.798  | 0.805    | 356
CD        | 0.843     | 0.866  | 0.854    | 733
----------|-----------|--------|----------|--------
Macro Avg | 0.856     | 0.848  | 0.852    | 4,044
| ```
|
|
|
| ### Comparison to Baselines
|
|
|
| Model | Pre-training | AUROC | Accuracy | Role |
|-------|--------------|-------|----------|------|
| **SimCLR (this model)** | **SSL (contrastive) + supervised fine-tuning** | **0.8717** | **0.8234** | **Recommended** |
| BYOL | SSL (momentum encoder) | 0.8565 | 0.8134 | Alternative |
| | Supervised CNN | None | 0.8606 | 0.8193 | Baseline |
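The card does not say how AUROC is aggregated across the five classes. A common convention for PTB-XL superclasses is the macro average of one-vs-rest AUROCs. The NumPy sketch below computes that from a 0/1 label matrix and per-class scores. It should agree with scikit-learn's `roc_auc_score(y_true, scores, average="macro")` on multi-label indicator targets, but treat the aggregation choice as an assumption.

```python
import numpy as np


def binary_auroc(y_true, scores):
    """AUROC as the probability a random positive outscores a random negative (ties count 1/2)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUROC needs at least one positive and one negative")
    # All positive/negative pairs: O(n_pos * n_neg) memory, fine for a few thousand ECGs
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def macro_auroc(y_true, scores):
    """Mean of per-class one-vs-rest AUROCs; y_true is an (n, n_classes) 0/1 matrix."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    per_class = [binary_auroc(y_true[:, k], scores[:, k]) for k in range(y_true.shape[1])]
    return float(np.mean(per_class))
```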
|
|
|
| ## Training Details
|
|
|
| ### Pre-training (Unsupervised SSL)
|
|
|
| - **Framework:** SimCLR
|
| - **Epochs:** 20
|
| - **Batch Size:** 128
|
| - **Optimizer:** Adam (lr=1e-3)
|
| - **Loss:** Contrastive (NT-Xent with τ=0.07)
|
| - **Data:** All PTB-XL training folds (no labels used)
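For reference, here is a minimal PyTorch sketch of the NT-Xent loss listed above. `z1` and `z2` are assumed to be projection-head outputs for two augmented views of the same batch. Each view's positive is its counterpart, and the other 2B - 2 embeddings in the batch serve as negatives. This is the standard SimCLR formulation, not code taken from the repository.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """NT-Xent (SimCLR) loss for a batch of positive pairs (z1[i], z2[i])."""
    batch_size = z1.shape[0]
    # Stack both views and L2-normalize so dot products are cosine similarities
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2B, D)
    sim = z @ z.T / temperature                          # (2B, 2B) scaled similarities
    # Remove self-similarity from the softmax denominator
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i's positive is row i + B, and row i + B's positive is row i
    targets = torch.cat(
        [torch.arange(batch_size, 2 * batch_size), torch.arange(0, batch_size)]
    ).to(z.device)
    return F.cross_entropy(sim, targets)
```

With τ = 0.07, similarities are scaled by roughly 14. That sharpens the softmax, so hard negatives dominate the gradient.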
|
|
|
| ### Fine-tuning (Supervised)
|
|
|
- **Labeled Data:** 1,747 samples (10% of folds 1-8)

- **Epochs:** Up to 20, with early stopping (patience=5)
|
| - **Batch Size:** 32
|
| - **Optimizer:** Adam (lr=5e-4)
|
| - **Loss:** Focal Loss with class weights
|
| - **Augmentations:** Training-time augmentations (same as pre-training)
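The card names focal loss with class weights but gives neither its focusing parameter nor its weighting scheme. The sketch below uses the common multi-class form FL = -w_y (1 - p_y)^γ log p_y, with γ = 2.0 and optional per-class weights `class_weights` (for example, inverse class frequency). Both choices are assumptions, and `focal_loss` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, class_weights=None, gamma=2.0):
    """Multi-class focal loss -w_y * (1 - p_y)^gamma * log p_y, averaged over the batch."""
    log_probs = F.log_softmax(logits, dim=-1)                       # (B, C)
    log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    p_t = log_p_t.exp()
    # Down-weight well-classified samples (p_t close to 1)
    loss = -((1.0 - p_t) ** gamma) * log_p_t
    if class_weights is not None:
        loss = loss * class_weights[targets]  # weight each sample by its true class
    return loss.mean()
```

Setting `gamma=0` and omitting the weights recovers ordinary cross-entropy, which makes a quick sanity check.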
|
|
|
| ### Domain-Adaptive Augmentations
|
|
|
| Applied during SSL pre-training:
|
| 1. **Frequency warping** (±5% heart rate variation)
|
| 2. **Medical mixup** (ECG-aware blending of two signals)
|
| 3. **Bandpass filtering** (physiologically grounded)
|
| 4. **Segment CutMix** (temporal masking)
|
| 5. **Motion artifacts** (baseline wander simulation)
|
| 6. **Per-channel noise** (independent Gaussian)
|
| 7. **Temporal dropout** (with interpolation)
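The NumPy sketch below implements four of the augmentations listed above: frequency warping, baseline wander, per-channel noise, and temporal dropout with interpolation. Each takes a single `(12, n_samples)` recording. The amplitudes, probabilities, and wander frequency range are illustrative assumptions, and `fs` defaults to the 100 Hz rate stated in this card. Medical mixup, bandpass filtering, and segment CutMix are omitted.

```python
import numpy as np


def frequency_warp(x, rng, max_warp=0.05):
    """Stretch/compress time by a random factor in [1 - max_warp, 1 + max_warp], keeping length."""
    n = x.shape[1]
    factor = rng.uniform(1 - max_warp, 1 + max_warp)
    src = np.clip(np.arange(n) * factor, 0, n - 1)  # positions to read from the original signal
    return np.stack([np.interp(src, np.arange(n), lead) for lead in x])


def baseline_wander(x, rng, fs=100.0, max_amp=0.1):
    """Add a slow sinusoidal drift (0.05-0.5 Hz) to every lead, mimicking motion/respiration."""
    t = np.arange(x.shape[1]) / fs
    freq = rng.uniform(0.05, 0.5)
    phase = rng.uniform(0, 2 * np.pi)
    amp = rng.uniform(0, max_amp)
    return x + amp * np.sin(2 * np.pi * freq * t + phase)


def per_channel_noise(x, rng, max_std=0.05):
    """Independent Gaussian noise per lead, each lead with its own noise level."""
    stds = rng.uniform(0, max_std, size=(x.shape[0], 1))
    return x + rng.normal(size=x.shape) * stds


def temporal_dropout(x, rng, drop_prob=0.05):
    """Drop random time steps and refill them by linear interpolation from kept samples."""
    n = x.shape[1]
    keep = rng.random(n) >= drop_prob
    keep[[0, -1]] = True  # anchor both ends so interpolation is defined everywhere
    idx = np.arange(n)
    return np.stack([np.interp(idx, idx[keep], lead[keep]) for lead in x])


def augment(x, rng):
    """Apply the four augmentations in sequence to one (leads, time) recording."""
    for fn in (frequency_warp, baseline_wander, per_channel_noise, temporal_dropout):
        x = fn(x, rng)
    return x
```

For SimCLR, `augment` would be called twice on the same recording to produce the two views passed to the contrastive loss.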
|
|
|
| ## Dataset
|
|
|
| ### PTB-XL v1.0.3
|
|
|
| **Source:** https://www.physionet.org/content/ptb-xl/1.0.3/
|
|
|
| - **Total ECGs:** 21,799
|
| - **Unique Patients:** 18,869
|
| - **Recording Rate:** 500 Hz → downsampled to 100 Hz
|
| - **Leads:** 12-lead standard
|
| - **Duration:** ~10 seconds per recording
|
|
|
| **Class Distribution:**
|
|
|
| Class | Count | % of ECGs |
|-------|-------|-----------|
| NORM | 9,514 | 43.6% |
| MI | 5,469 | 25.1% |
| STTC | 5,235 | 24.0% |
| CD | 4,898 | 22.5% |
| HYP | 2,649 | 12.2% |
|
|
|
| *Note: Samples can belong to multiple classes*
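Because a record can carry several superclass labels, the class counts above sum to 27,765 even though PTB-XL has 21,799 ECGs. Such targets are naturally expressed as multi-hot vectors. The sketch below shows that encoding, using the class order from the Quick Start example. `multi_hot` is a hypothetical helper, and this is not necessarily how labels were prepared for this model.

```python
import numpy as np

CLASSES = ["NORM", "MI", "STTC", "HYP", "CD"]  # same order as the Quick Start example


def multi_hot(labels, classes=CLASSES):
    """Encode a record's set of superclass labels as a 0/1 vector."""
    vec = np.zeros(len(classes), dtype=np.float32)
    for label in labels:
        vec[classes.index(label)] = 1.0
    return vec


# Class counts from the table above; their sum exceeds the 21,799 ECGs
counts = {"NORM": 9514, "MI": 5469, "STTC": 5235, "CD": 4898, "HYP": 2649}
print(sum(counts.values()))  # 27765
```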
|
|
|
| **Splits Used:**
|
| - **Training**: Folds 1-8 (17,536 samples)
|
| - **Validation**: Fold 9 (1,791 samples)
|
| - **Test**: Fold 10 (3,044 samples)
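PTB-XL ships a `strat_fold` column (1-10) in `ptbxl_database.csv`, and the split above follows the recommended fold assignment. The plain-Python sketch below performs that split and then draws a 10% labeled subset for fine-tuning. It assumes records are dicts with an integer `strat_fold`. Uniform random sampling with a fixed seed is also an assumption, since the card does not say how the labeled 10% was selected.

```python
import random


def split_by_fold(records):
    """Split PTB-XL records by strat_fold: folds 1-8 train, 9 validation, 10 test."""
    train = [r for r in records if r["strat_fold"] <= 8]
    val = [r for r in records if r["strat_fold"] == 9]
    test = [r for r in records if r["strat_fold"] == 10]
    return train, val, test


def labeled_subset(train, fraction=0.10, seed=0):
    """Hypothetical: draw the labeled fine-tuning subset uniformly at random, reproducibly."""
    k = round(len(train) * fraction)
    return random.Random(seed).sample(train, k)
```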
|
|
|
| ## Limitations & Biases
|
|
|
| ### Limitations
|
|
|
**Not validated for clinical use.** For research purposes only.



- Trained exclusively on PTB-XL; generalization to other datasets is unknown

- Requires 12-lead input; 6-lead or lead-converted signals are not supported

- Fine-tuned on only 10% of the labels, so results may understate what the model achieves with full supervision

- Predicts only the 5 trained classes; other diagnoses are outside its scope
|
|
|
| ### Potential Biases
|
|
|
| - **Geographic bias:** Primarily European patient population (PTB-XL)
|
| - **Hospital bias:** Data from hospital patients (not general population)
|
| - **Class imbalance:** NORM over-represented, HYP under-represented
|
| - **Demographic:** Skew toward older patients; male/female ratio not controlled
|
|
|
| ## Environmental Impact
|
|
|
| - **Training:** ~12 GPU hours on RTX 5070 Ti
|
| - **CO2 Emissions:** ~0.5 kg (estimated)
|
| - **Inference:** ~50ms per 10-second ECG on GPU
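The emissions figure is an estimate. The usual arithmetic is: energy = GPU hours × average power draw × data-centre overhead (PUE), and emissions = energy × grid carbon intensity. The card does not state the power draw or grid intensity behind its ~0.5 kg figure, so the function below only shows the formula. `estimate_co2_kg` is hypothetical, and the reader must supply its inputs.

```python
def estimate_co2_kg(gpu_hours, avg_power_kw, grid_kg_per_kwh, pue=1.0):
    """Emissions (kg CO2) = gpu_hours * avg_power_kw * pue (kWh) * grid intensity (kg/kWh)."""
    energy_kwh = gpu_hours * avg_power_kw * pue
    return energy_kwh * grid_kg_per_kwh
```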
|
|
|
| ## License
|
|
|
Apache 2.0. See the LICENSE file in the repository.
|
|
|
|
|
| ## Acknowledgments
|
|
|
- PTB-XL Dataset: PhysioNet, Wagner et al. (2020)
|
| - SimCLR Framework: Chen et al. (2020)
|
| - Implementation: Built with PyTorch & Hugging Face
|
|
|
| ## Model Card Contact
|
|
|
| - **Author:** Tumo Kgabeng
|
| - **GitHub:** https://github.com/Tumo505/SSL-for-ECG-classification
|
|
|
| ## Changelog
|
|
|
| ### v1.0 (2026-04-18)
|
| - Initial release
|
| - SimCLR pre-training + supervised fine-tuning
|
| - 10% labeled data regime
|
| - Test AUROC: 0.8717
|
|
|
| ---
|
|
|
| **Questions?** Open an issue on [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)
|
|
|