---
license: mit
tags:
- ecg
- atrial-fibrillation
- diffusion-model
- data-augmentation
- medical-ai
- time-series
- pytorch
datasets:
- mimic-iv-ecg
language:
- en
pipeline_tag: other
---
# Diffusion-Based ECG Augmentation Model
A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for **training data augmentation** in AFib detection systems.
## Model Description
This model separates ECG signals into class-invariant **content** (beat morphology) and class-specific **style** (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.
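The content/style split described above can be illustrated with a toy statistic-swapping (AdaIN-style) sketch: treat a signal's per-instance mean and standard deviation as "style" and the normalized waveform as "content". This is an assumption-laden illustration of the idea only, not the trained DiffStyleTS encoders.

```python
import numpy as np

def instance_norm(x):
    """Strip per-signal statistics, leaving a normalized 'content' waveform."""
    mu, sigma = x.mean(), x.std()
    return (x - mu) / (sigma + 1e-8), mu, sigma

def transfer_style(content_signal, style_signal):
    """Re-dress the content of one signal with the statistics of another."""
    content, _, _ = instance_norm(content_signal)
    _, mu_s, sigma_s = instance_norm(style_signal)
    return content * sigma_s + mu_s

# Placeholder arrays standing in for real 10-second, 250 Hz segments
rng = np.random.default_rng(0)
normal_ecg = rng.standard_normal(2500)
afib_ecg = 2.0 * rng.standard_normal(2500) + 0.5

hybrid = transfer_style(normal_ecg, afib_ecg)  # Normal content, AFib "style"
print(hybrid.shape)
```

In the actual model this role is played by the learned encoders (InstanceNorm in the style branch serves a related purpose), and the recombination happens through the conditional UNet rather than a direct statistic swap.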
### Architecture
| Component | Parameters | Description |
|-----------|-----------|-------------|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |
### Training
- **Dataset**: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- **Stage 1** (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- **Stage 2** (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- **Hardware**: NVIDIA RTX 6000 Ada (48 GB VRAM)
- **Training time**: ~22 hours
### Generation
- **Method**: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
- **Filtering**: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- **Output**: 7,784 accepted synthetic ECGs from test set
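The SDEdit entry point above (60% noise addition before denoising) can be sketched as a partial forward diffusion: instead of sampling from pure noise, a real ECG is diffused to 60% of the schedule and handed to the DDIM sampler from there. The step count and beta schedule below are illustrative assumptions, not the model's actual training configuration.

```python
import numpy as np

T = 1000                               # assumed total diffusion steps
betas = np.linspace(1e-4, 0.02, T)     # assumed linear beta schedule
alpha_bars = np.cumprod(1.0 - betas)

def sdedit_noise(x0, strength=0.6, rng=None):
    """Diffuse x0 forward to t = strength * T using q(x_t | x_0)."""
    rng = rng or np.random.default_rng()
    t = int(strength * T) - 1
    a_bar = alpha_bars[t]
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps
    return x_t, t                      # hand x_t to the DDIM sampler at step t

rng = np.random.default_rng(0)
ecg = rng.standard_normal(2500)        # placeholder for a real segment
x_t, t_start = sdedit_noise(ecg, strength=0.6, rng=rng)
print(t_start, x_t.shape)
```

From `x_t`, the pipeline would run the 50 DDIM denoising steps with CFG scale 3.0, conditioned on the content and (swapped) style embeddings.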
## Key Results
### Augmentation Viability (5-Fold Cross-Validation)
| Condition | Training Data | Accuracy | F1 Score |
|-----------|--------------|----------|----------|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |
**TOST equivalence test confirms A ≈ C** (p = 0.007, equivalence margin ±2%): replacing 33% of the real training data with synthetic ECGs does not measurably degrade classifier performance.
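For readers unfamiliar with TOST (two one-sided tests), the check can be sketched from the table's summary statistics alone. This sketch uses a normal approximation over the 5-fold means and standard deviations; the authors' exact fold-level analysis yields the reported p = 0.007.

```python
import math

def one_sided_p(z):
    """Upper-tail p-value under a standard normal approximation."""
    return 0.5 * math.erfc(z / math.sqrt(2.0))

def tost(mean_a, sd_a, mean_b, sd_b, n, margin):
    """Equivalence holds if BOTH one-sided tests reject at the margin."""
    diff = mean_a - mean_b
    se = math.sqrt((sd_a**2 + sd_b**2) / n)
    p_lower = one_sided_p((diff + margin) / se)   # H0: diff <= -margin
    p_upper = one_sided_p((margin - diff) / se)   # H0: diff >= +margin
    return max(p_lower, p_upper)                  # TOST p-value

# Condition A (real only) vs Condition C (augmented): accuracy over 5 folds
p = tost(95.63, 0.33, 95.05, 0.50, n=5, margin=2.0)
print(p < 0.05)   # equivalent within the ±2% margin
```

Rejecting both one-sided nulls means the accuracy difference is confidently inside the ±2% band, which is a stronger statement than merely failing to detect a difference.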
### Signal Quality
| Metric | Value |
|--------|-------|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
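The MSE and PSNR figures above are standard per-signal reconstruction metrics. A minimal sketch of how they are typically computed is below; the PSNR dynamic range (peak amplitude of the reference signal) is an assumption, as the model card does not state the normalization convention.

```python
import math

def mse(ref, gen):
    """Mean squared error between reference and generated signals."""
    return sum((r - g) ** 2 for r, g in zip(ref, gen)) / len(ref)

def psnr(ref, gen):
    """Peak signal-to-noise ratio in dB (peak = max |reference|, assumed)."""
    peak = max(abs(v) for v in ref)
    return 10.0 * math.log10(peak ** 2 / mse(ref, gen))

# Toy sine "ECG" and a slightly offset "reconstruction"
ref = [math.sin(2 * math.pi * i / 250) for i in range(2500)]
gen = [v + 0.05 for v in ref]
print(mse(ref, gen), psnr(ref, gen))
```

SSIM, originally an image metric, is computed analogously over windowed means, variances, and covariances of the two signals.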
## Files
- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics
## Usage
### Option 1: Interactive Demo (Easiest)
Try the model directly in your browser — no code needed:
👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**
Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse pre-loaded examples.
### Option 2: Download & Use in Python
```python
from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth",
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth",
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")
# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
#   git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See diffusion_pipeline/final_pipeline/ for the model architectures
```
### Option 3: Clone the Full Pipeline
```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/
# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```
> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
## Citation
```bibtex
@misc{pera_af_detection_2025,
title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
year={2025},
institution={University of Peradeniya, Sri Lanka},
collaboration={SimulaMet, Oslo, Norway}
}
```
## Links
- [GitHub Repository](https://github.com/vlbthambawita/PERA_AF_Detection)
- [Interactive Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)
- Paper (coming soon)