---
language: en
license: apache-2.0
app_type: gradio
space: Tumo505/SSL-ECG-Classification
datasets:
- ptb-xl
metrics:
- auroc
- accuracy
tags:
- ecg
- medical
- time-series
- classification
- self-supervised-learning
- ssl
- cardiac
- healthcare
model-index:
- name: SSL-ECG-Classifier
  results:
  - task:
      name: Time Series Classification
      type: tabular-classification
    dataset:
      name: PTB-XL
      type: ptb-xl
      split: test
      args:
        fold: 10
    metrics:
    - name: AUROC
      type: auroc
      value: 0.8717
    - name: Accuracy
      type: accuracy
      value: 0.8234
inference: true
widget:
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_normal.csv
  example_title: "Normal ECG (NORM)"
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_mi.csv
  example_title: "Myocardial Infarction (MI)"
- src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_sttc.csv
  example_title: "ST/T Changes (STTC)"
---

# SSL-ECG-Classifier: Self-Supervised Learning for ECG Classification

**Self-Supervised Learning (SSL)** pre-trained model for ECG cardiovascular disease classification.
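For readers unfamiliar with SimCLR, the NT-Xent contrastive objective it optimizes during pre-training can be sketched in plain NumPy. This is an illustrative sketch, not this repository's training code: the batch size and embedding dimension below are arbitrary, and only the temperature τ = 0.07 matches the value reported under Training Details.

```python
import numpy as np

def nt_xent_loss(z1, z2, tau=0.07):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss.

    z1, z2: (N, D) embeddings of two augmented views of the same N signals.
    Positive pairs are (z1[i], z2[i]); every other sample in the batch
    acts as a negative.
    """
    z = np.concatenate([z1, z2], axis=0)              # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit-normalize rows
    sim = z @ z.T / tau                               # cosine similarity / temperature
    n = len(z1)
    np.fill_diagonal(sim, -np.inf)                    # a sample is never its own negative
    # Index of each sample's positive partner: i <-> i + n
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Cross-entropy over similarities: -log softmax at the positive index
    logsumexp = np.log(np.exp(sim).sum(axis=1))
    loss = -(sim[np.arange(2 * n), pos] - logsumexp)
    return loss.mean()

rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 128))
# Two nearly identical views should score a lower loss than unrelated ones
loss_close = nt_xent_loss(z1, z1 + 0.01 * rng.normal(size=(8, 128)))
loss_far = nt_xent_loss(z1, rng.normal(size=(8, 128)))
print(loss_close < loss_far)  # True
```

Minimizing this loss pulls the two augmented views of the same ECG together in embedding space while pushing apart views of different ECGs, which is what lets the encoder learn from unlabeled recordings.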
## Model Overview

| Property | Value |
|----------|-------|
| **Framework** | SimCLR |
| **Test AUROC** | 0.8717 |
| **Test Accuracy** | 0.8234 |
| **Dataset** | PTB-XL (21.8K ECGs) |
| **Fine-tuning** | 10% labeled data (1,747 samples) |
| **Input** | 12-lead ECG @ 100 Hz (5,000 samples) |
| **Output** | 5-class classification |

## Classes Predicted

- **NORM**: Normal ECG
- **MI**: Myocardial Infarction
- **STTC**: ST/T Changes
- **HYP**: Hypertrophy (LVH)
- **CD**: Conduction Disturbances

## Quick Start

### Python (Transformers)

```python
import torch
from transformers import AutoModel

# Load model
model = AutoModel.from_pretrained("Tumo505/SSL-ECG-Classification", trust_remote_code=True)
model.eval()

# Prepare 12-lead ECG (batch_size, 12 leads, 5000 samples)
ecg = torch.randn(1, 12, 5000)

# Predict
with torch.no_grad():
    output = model(ecg)
    logits = output["logits"]
    probs = torch.softmax(logits, dim=-1)

classes = ["NORM", "MI", "STTC", "HYP", "CD"]
prediction = classes[probs.argmax(dim=-1)[0].item()]
confidence = probs.max().item()
print(f"Prediction: {prediction} ({confidence:.1%})")
```

### Try Online

Click the **"Use this model"** button above to test on the Gradio Space!

### API Endpoint (Deploy)

Click the **"Deploy"** button to get a live inference endpoint:

```bash
curl -X POST https://your-api-url.hf.space/api/predict \
  -H "Authorization: Bearer YOUR_HF_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{ "inputs": [[[... 12-lead ECG array ...]]] }'
```

## Model Architecture

```
Input (B × 12 × 5000)
        ↓
1D CNN Encoder
  - Conv1d(12 → 32) + BatchNorm + ReLU + MaxPool
  - Conv1d(32 → 64) + BatchNorm + ReLU + MaxPool
  - Conv1d(64 → 128) + BatchNorm + ReLU
  - AdaptiveAvgPool1d(1) + Flatten
        ↓
Projection Head (128-dim embedding)
        ↓
Classification Head (5 classes)
        ↓
Output (B × 5) logits
```

## Performance Metrics

### Test Set Results (PTB-XL Fold 10: 3,044 samples)

| Class | Precision | Recall | F1-Score | Support |
|-----------|-----------|--------|----------|---------|
| NORM | 0.897 | 0.882 | 0.889 | 1,275 |
| MI | 0.856 | 0.834 | 0.845 | 904 |
| STTC | 0.871 | 0.859 | 0.865 | 776 |
| HYP | 0.812 | 0.798 | 0.805 | 356 |
| CD | 0.843 | 0.866 | 0.854 | 733 |
| **Macro Avg** | **0.856** | **0.848** | **0.852** | **4,044** |

*Support totals exceed 3,044 because recordings can carry more than one label (see the class-distribution note under Dataset).*

### Comparison to Baselines

| Model | Framework | AUROC | Accuracy | Notes |
|-------|-----------|-------|----------|-------|
| **SimCLR (This)** | **SSL + Supervised** | **0.8717** | **0.8234** | **Recommended** |
| BYOL SSL | SSL momentum | 0.8565 | 0.8134 | Alternative |
| Supervised CNN | None | 0.8606 | 0.8193 | Baseline |

## Training Details

### Pre-training (Unsupervised SSL)

- **Framework:** SimCLR
- **Epochs:** 20
- **Batch Size:** 128
- **Optimizer:** Adam (lr=1e-3)
- **Loss:** Contrastive (NT-Xent with τ=0.07)
- **Data:** All PTB-XL training folds (no labels used)

### Fine-tuning (Supervised)

- **Labeled Data:** 1,747 samples (10% of folds 1-8)
- **Epochs:** 20 with early stopping (patience=5)
- **Batch Size:** 32
- **Optimizer:** Adam (lr=5e-4)
- **Loss:** Focal Loss with class weights
- **Augmentations:** Training-time augmentations (same as pre-training)

### Domain-Adaptive Augmentations

Applied during SSL pre-training:

1. **Frequency warping** (±5% heart rate variation)
2. **Medical mixup** (ECG-aware blending of two signals)
3. **Bandpass filtering** (physiologically grounded)
4. **Segment CutMix** (temporal masking)
5. **Motion artifacts** (baseline wander simulation)
6. **Per-channel noise** (independent Gaussian)
7. **Temporal dropout** (with interpolation)

## Dataset

### PTB-XL v1.0.3

**Source:** https://www.physionet.org/content/ptb-xl/1.0.3/

- **Total ECGs:** 21,799
- **Unique Patients:** 18,869
- **Recording Rate:** 500 Hz → downsampled to 100 Hz
- **Leads:** 12-lead standard
- **Duration:** ~10 seconds per recording

**Class Distribution:**

| Class | Count | Percentage |
|-------|-------|------------|
| NORM | 9,514 | 43.7% |
| MI | 5,469 | 25.1% |
| STTC | 5,235 | 24.0% |
| CD | 4,898 | 22.5% |
| HYP | 2,649 | 12.2% |

*Note: Samples can belong to multiple classes, so percentages sum to more than 100%.*

**Splits Used:**

- **Training**: Folds 1-8 (17,536 samples)
- **Validation**: Fold 9 (1,791 samples)
- **Test**: Fold 10 (3,044 samples)

## Limitations & Biases

### Limitations

- **Not validated for clinical use** - research purposes only
- Trained exclusively on PTB-XL; generalization to other datasets is unknown
- Requires the 12-lead ECG format; does not work with 6-lead or converted signals
- The 10% labeled-data regime may not reflect full model capacity
- Predicts only the 5 trained classes

### Potential Biases

- **Geographic bias:** Primarily European patient population (PTB-XL)
- **Hospital bias:** Data from hospital patients (not the general population)
- **Class imbalance:** NORM over-represented, HYP under-represented
- **Demographic:** Skew toward older patients; male/female ratio not controlled

## Environmental Impact

- **Training:** ~12 GPU hours on an RTX 5070 Ti
- **CO2 Emissions:** ~0.5 kg (estimated)
- **Inference:** ~50 ms per 10-second ECG on GPU

## License

Apache 2.0 - see the LICENSE file in the repository.

## Acknowledgments

- PTB-XL Dataset: PhysioNet, Wagner et al. (2020)
- SimCLR Framework: Chen et al. (2020)
- Implementation: Built with PyTorch & Hugging Face

## Model Card Contact

- **Author:** Tumo Kgabeng
- **GitHub:** https://github.com/Tumo505/SSL-for-ECG-classification

## Changelog

### v1.0 (2026-04-18)

- Initial release
- SimCLR pre-training + supervised fine-tuning
- 10% labeled data regime
- Test AUROC: 0.8717

---

**Questions?** Open an issue on [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)
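The domain-adaptive augmentations listed under Training Details are described only at a high level. As a rough illustration, two of the simpler ones, per-channel Gaussian noise and temporal dropout with interpolation, might be implemented as below; the noise scale, segment lengths, and linear-interpolation fill are assumptions for this sketch, not the repository's actual parameters.

```python
import numpy as np

def per_channel_noise(ecg, sigma=0.02, rng=None):
    """Add independent Gaussian noise to each lead (channel)."""
    if rng is None:
        rng = np.random.default_rng()
    # One noise level per channel, so leads are perturbed independently
    scales = sigma * rng.uniform(0.5, 1.5, size=(ecg.shape[0], 1))
    return ecg + scales * rng.normal(size=ecg.shape)

def temporal_dropout(ecg, max_len=100, rng=None):
    """Mask a random time segment and fill it by linear interpolation."""
    if rng is None:
        rng = np.random.default_rng()
    out = ecg.copy()
    t = ecg.shape[1]
    length = rng.integers(10, max_len)
    start = rng.integers(0, t - length)
    end = start + length
    for ch in range(out.shape[0]):
        # Replace the dropped span with a straight line between its endpoints
        out[ch, start:end] = np.linspace(out[ch, start], out[ch, end],
                                         length, endpoint=False)
    return out

rng = np.random.default_rng(0)
# Toy 12-lead signal: 12 channels x 1000 samples
ecg = np.sin(np.linspace(0, 20 * np.pi, 1000))[None, :].repeat(12, axis=0)
aug = temporal_dropout(per_channel_noise(ecg, rng=rng), rng=rng)
print(aug.shape)  # (12, 1000)
```

In a SimCLR setup, two independently augmented views of the same recording would be produced this way and fed to the encoder as a positive pair.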