---
license: mit
tags:
- ecg
- atrial-fibrillation
- diffusion-model
- data-augmentation
- medical-ai
- time-series
- pytorch
datasets:
- mimic-iv-ecg
language:
- en
pipeline_tag: other
---
# Diffusion-Based ECG Augmentation Model
A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for **training data augmentation** in AFib detection systems.
## Model Description
This model separates ECG signals into class-invariant **content** (beat morphology) and class-specific **style** (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.
### Architecture
| Component | Parameters | Description |
|-----------|-----------|-------------|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |
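The UNet's FiLM conditioning scales and shifts each feature channel by factors projected from the content and style embeddings. A minimal NumPy sketch of the mechanism (the projection matrices and dimensions here are illustrative, not the model's actual ones):

```python
import numpy as np

rng = np.random.default_rng(0)

def film(features, cond, w_gamma, w_beta):
    """FiLM: per-channel scale (gamma) and shift (beta), both projected
    from the conditioning vector (here, a content + style embedding)."""
    gamma = cond @ w_gamma              # (channels,)
    beta = cond @ w_beta                # (channels,)
    return features * gamma[:, None] + beta[:, None]

channels, length, cond_dim = 8, 2500, 16
features = rng.standard_normal((channels, length))  # a UNet feature map
cond = rng.standard_normal(cond_dim)                # content ⊕ style embedding
w_gamma = rng.standard_normal((cond_dim, channels))
w_beta = rng.standard_normal((cond_dim, channels))

out = film(features, cond, w_gamma, w_beta)
print(out.shape)  # (8, 2500)
```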
### Training
- **Dataset**: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- **Stage 1** (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- **Stage 2** (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- **Hardware**: NVIDIA RTX 6000 Ada (48 GB VRAM)
- **Training time**: ~22 hours
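The Stage 1 objective combines the three terms listed above. A sketch of how they might compose for one segment, assuming a standard VAE KL against a unit Gaussian; the loss weights here are illustrative placeholders (the actual values live in the checkpoint's `config`):

```python
import numpy as np

rng = np.random.default_rng(0)

def stage1_loss(x, x_hat, logits, label, mu, logvar, kl_weight=1e-3):
    """Stage 1 (reconstruction) objective: MSE + cross-entropy + KL.
    Weights are illustrative, not the trained model's."""
    mse = np.mean((x - x_hat) ** 2)
    # Cross-entropy of the class logits against the true label
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    ce = -log_probs[label]
    # KL divergence of the VAE posterior N(mu, exp(logvar)) vs N(0, I)
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))
    return mse + ce + kl_weight * kl

x = rng.standard_normal(2500)                 # 10 s Lead II segment at 250 Hz
x_hat = x + 0.01 * rng.standard_normal(2500)  # near-perfect reconstruction
loss = stage1_loss(x, x_hat, logits=np.array([2.0, -1.0]), label=0,
                   mu=np.zeros(64), logvar=np.zeros(64))
print(loss)
```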
### Generation
- **Method**: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
- **Filtering**: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- **Output**: 7,784 accepted synthetic ECGs from test set
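The SDEdit procedure above (partial noising of a source ECG, then guided denoising) can be sketched as follows. The denoiser here is a toy stand-in for the conditional UNet and the update rule is a simplified deterministic step, not the exact DDIM schedule:

```python
import numpy as np

rng = np.random.default_rng(0)

def sdedit(x_source, denoise_fn, noise_strength=0.6, steps=50, cfg_scale=3.0):
    """SDEdit-style generation: noise the source ECG part-way (keeping its
    content), then denoise with classifier-free guidance toward the target
    class. `denoise_fn(x, t, cond)` stands in for the conditional UNet;
    cond=None is the unconditional branch."""
    t0 = noise_strength
    x = np.sqrt(1 - t0) * x_source + np.sqrt(t0) * rng.standard_normal(x_source.shape)
    for t in np.linspace(t0, 0.0, steps, endpoint=False):
        eps_cond = denoise_fn(x, t, cond="afib")
        eps_uncond = denoise_fn(x, t, cond=None)
        # Classifier-free guidance: push toward the conditional prediction
        eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
        x = x - (t0 / steps) * eps  # simplified deterministic (DDIM-like) update
    return x

# Toy denoiser that ignores conditioning; a real model predicts the noise
toy_denoise = lambda x, t, cond: 0.1 * x
x_source = np.sin(np.linspace(0, 20 * np.pi, 2500))  # stand-in for a 10 s ECG
x_gen = sdedit(x_source, toy_denoise)
print(x_gen.shape)  # (2500,)
```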
## Key Results
### Augmentation Viability (5-Fold Cross-Validation)
| Condition | Training Data | Accuracy | F1 Score |
|-----------|--------------|----------|----------|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |
**TOST equivalence test confirms A ≈ C** (p = 0.007, margin ±2%), indicating that replacing 33% of the real training data with synthetic ECGs does not meaningfully degrade classifier performance.
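For reference, a TOST (two one-sided tests) equivalence check can be run from the summary statistics in the table. This sketch uses an unpaired Welch form, so it will not reproduce the reported p = 0.007 exactly (the study's test may pair results by fold):

```python
import numpy as np
from scipy.stats import t

def tost_from_summary(mean1, sd1, mean2, sd2, n, margin):
    """TOST for equivalence of two means within ±margin, from summary
    statistics (Welch unpaired form; illustrative only)."""
    d = mean1 - mean2
    se = np.sqrt(sd1**2 / n + sd2**2 / n)
    df = (sd1**2 / n + sd2**2 / n) ** 2 / (
        (sd1**2 / n) ** 2 / (n - 1) + (sd2**2 / n) ** 2 / (n - 1)
    )
    p_lower = 1 - t.cdf((d + margin) / se, df)  # H0: d <= -margin
    p_upper = t.cdf((d - margin) / se, df)      # H0: d >= +margin
    return max(p_lower, p_upper)

# Accuracies of conditions A and C over 5 folds (values from the table)
p = tost_from_summary(95.63, 0.33, 95.05, 0.50, n=5, margin=2.0)
print(f"TOST p = {p:.4f}")  # p < 0.05 → equivalent within ±2%
```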
### Signal Quality
| Metric | Value |
|--------|-------|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
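PSNR and MSE can be computed directly on the 1-D signals. A minimal sketch, assuming PSNR uses the reference's peak-to-peak amplitude as the data range (conventions vary, so the absolute numbers are illustrative):

```python
import numpy as np

def psnr_1d(ref, test, data_range=None):
    """PSNR for a 1-D signal; data_range defaults to the reference's
    peak-to-peak amplitude."""
    mse = np.mean((ref - test) ** 2)
    if data_range is None:
        data_range = ref.max() - ref.min()
    return 10 * np.log10(data_range**2 / mse)

rng = np.random.default_rng(0)
ref = np.sin(np.linspace(0, 20 * np.pi, 2500))  # stand-in for a real ECG
gen = ref + 0.3 * rng.standard_normal(2500)     # stand-in for a synthetic ECG
print(f"MSE  = {np.mean((ref - gen) ** 2):.4f}")
print(f"PSNR = {psnr_1d(ref, gen):.2f} dB")
```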
## Files
- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics
## Usage
### Option 1: Interactive Demo (Easiest)
Try the model directly in your browser — no code needed:
👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**
Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse pre-loaded examples.
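If your recording is not already 2500 samples at 250 Hz, a minimal way to prepare a compatible `.npy` file is to resample via linear interpolation and crop or pad to length (the demo's exact preprocessing may differ; this is a sketch):

```python
import numpy as np

def prepare_ecg(signal, fs, target_fs=250, target_len=2500):
    """Resample a single-lead ECG to 250 Hz via linear interpolation,
    then crop/pad to exactly 2500 samples (10 s)."""
    n_out = int(round(len(signal) * target_fs / fs))
    t_in = np.linspace(0, 1, len(signal), endpoint=False)
    t_out = np.linspace(0, 1, n_out, endpoint=False)
    resampled = np.interp(t_out, t_in, signal)
    if len(resampled) >= target_len:
        return resampled[:target_len]
    return np.pad(resampled, (0, target_len - len(resampled)))

raw = np.random.default_rng(0).standard_normal(5000)  # e.g. 10 s at 500 Hz
ecg = prepare_ecg(raw, fs=500)
np.save("my_ecg.npy", ecg.astype(np.float32))
print(ecg.shape)  # (2500,)
```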
### Option 2: Download & Use in Python
```python
from huggingface_hub import hf_hub_download
import torch
# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth"
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth"
)
# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")
# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder'] → Style Encoder state dict
# - checkpoint['unet'] → UNet state dict
# - checkpoint['config'] → Training config with all hyperparameters
# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict
# To use the full pipeline, clone the repository:
# git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See: diffusion_pipeline/final_pipeline/ for model architectures
```
### Option 3: Clone the Full Pipeline
```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/
# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```
> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
## Citation
```bibtex
@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
```
## Links
- [GitHub Repository](https://github.com/vlbthambawita/PERA_AF_Detection)
- [Interactive Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)
- Paper (coming soon)