---
license: mit
tags:
  - ecg
  - atrial-fibrillation
  - diffusion-model
  - data-augmentation
  - medical-ai
  - time-series
  - pytorch
datasets:
  - mimic-iv-ecg
language:
  - en
pipeline_tag: other
---

# Diffusion-Based ECG Augmentation Model

A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for **training data augmentation** in AFib detection systems.

## Model Description

This model separates ECG signals into class-invariant **content** (beat morphology) and class-specific **style** (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.

### Architecture

| Component | Parameters | Description |
|-----------|-----------|-------------|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |
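
The table says the Conditional UNet uses FiLM conditioning on content and style. A minimal PyTorch sketch of a FiLM layer is below. It assumes the content and style codes are concatenated into one conditioning vector that predicts a per-channel scale (γ) and shift (β). The class name `FiLM`, the concatenation, and all dimensions are illustrative assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn


class FiLM(nn.Module):
    """Feature-wise Linear Modulation: out = gamma * h + beta, per channel.

    Hypothetical sketch: conditioning layout and sizes are assumptions.
    """

    def __init__(self, cond_dim: int, num_channels: int) -> None:
        super().__init__()
        # A single linear layer predicts both gamma and beta for every channel.
        self.to_gamma_beta = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # h: (batch, channels, time) UNet feature map; cond: (batch, cond_dim)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        # Broadcast the per-channel scale and shift across the time axis.
        return gamma.unsqueeze(-1) * h + beta.unsqueeze(-1)


# Assumed sizes: 64-d content code, 32-d style code, 128 channels, 625 time steps.
content = torch.randn(4, 64)
style = torch.randn(4, 32)
film = FiLM(cond_dim=64 + 32, num_channels=128)
features = torch.randn(4, 128, 625)
out = film(features, torch.cat([content, style], dim=-1))
print(out.shape)  # torch.Size([4, 128, 625])
```

Predicting γ and β from the combined code lets the style vector rescale every feature channel without changing the UNet's temporal structure.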

### Training

- **Dataset**: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- **Stage 1** (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- **Stage 2** (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- **Hardware**: NVIDIA RTX 6000 Ada (48 GB VRAM)
- **Training time**: ~22 hours
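
Stage 1 combines a reconstruction term (MSE), a classification term (cross-entropy) and the KL divergence of the VAE content encoder. A minimal sketch of that objective and the epoch-based stage switch is below, assuming a standard Gaussian VAE (the encoder outputs `mu` and `logvar`). The weights `w_ce` and `w_kl`, and what `pred`/`target`/`logits` refer to, are hypothetical placeholders, not the trained configuration. The Stage 2 flip and similarity losses are not sketched because their definitions are not given here.

```python
import torch
import torch.nn.functional as F


def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over batch."""
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kl.flatten(1).sum(dim=1).mean()


def stage1_loss(pred, target, logits, labels, mu, logvar, w_ce=1.0, w_kl=1e-3):
    """Hypothetical Stage 1 objective: MSE + w_ce * CE + w_kl * KL (weights assumed).

    pred/target: e.g. predicted vs. true diffusion noise (assumption).
    logits/labels: class predictions for AFib vs. Normal (assumption).
    """
    mse = F.mse_loss(pred, target)
    ce = F.cross_entropy(logits, labels)
    kl = kl_to_standard_normal(mu, logvar)
    return mse + w_ce * ce + w_kl * kl


def stage_for_epoch(epoch: int) -> int:
    """Epochs 1-50 use Stage 1 (reconstruction); epochs 51-100 use Stage 2 (style transfer)."""
    return 1 if epoch <= 50 else 2
```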

### Generation

- **Method**: SDEdit (60% noise addition, 50 DDIM denoising steps, classifier-free guidance (CFG) scale 3.0)
- **Filtering**: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- **Output**: 7,784 accepted synthetic ECGs from test set
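
The generation recipe above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the pipeline's code: it assumes a linear beta schedule with T = 1000, reads "60% noise" as starting the reverse process at t0 = 0.6·T, spaces the 50 DDIM steps evenly over [0, t0], and conditions on a simple class label (the real model conditions on content and style codes). The noise predictor `eps_model` and the validator `plausibility_score` are hypothetical stand-ins.

```python
import numpy as np

T = 1000                               # assumed number of diffusion timesteps
BETAS = np.linspace(1e-4, 0.02, T)     # assumed linear beta schedule
ALPHA_BAR = np.cumprod(1.0 - BETAS)    # cumulative product of (1 - beta_t)


def cfg_eps(eps_model, x, t, cond, scale=3.0):
    """Classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond)."""
    eps_uncond = eps_model(x, t, None)   # None = unconditional (assumed convention)
    eps_cond = eps_model(x, t, cond)
    return eps_uncond + scale * (eps_cond - eps_uncond)


def sdedit_ddim(x0, eps_model, cond, strength=0.6, n_steps=50, scale=3.0, rng=None):
    """SDEdit: noise a real ECG up to t0 = strength * T, then denoise with DDIM (eta = 0)."""
    rng = np.random.default_rng() if rng is None else rng
    t0 = int(round(strength * T)) - 1
    # Jump straight to x_t0 with the closed-form forward process q(x_t0 | x_0).
    noise = rng.standard_normal(x0.shape)
    x = np.sqrt(ALPHA_BAR[t0]) * x0 + np.sqrt(1.0 - ALPHA_BAR[t0]) * noise
    # n_steps evenly spaced timesteps from t0 down to 0.
    timesteps = np.linspace(t0, 0, n_steps).round().astype(int)
    for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
        eps = cfg_eps(eps_model, x, t, cond, scale)
        # Estimate the clean signal, then move deterministically to t_prev.
        x0_hat = (x - np.sqrt(1.0 - ALPHA_BAR[t]) * eps) / np.sqrt(ALPHA_BAR[t])
        x = np.sqrt(ALPHA_BAR[t_prev]) * x0_hat + np.sqrt(1.0 - ALPHA_BAR[t_prev]) * eps
    # Final step at t = 0: return the predicted clean signal.
    eps = cfg_eps(eps_model, x, timesteps[-1], cond, scale)
    return (x - np.sqrt(1.0 - ALPHA_BAR[0]) * eps) / np.sqrt(ALPHA_BAR[0])


def keep_plausible(candidates, plausibility_score, threshold=0.7):
    """Keep candidates whose (hypothetical) validator score is at least the threshold."""
    return [c for c in candidates if plausibility_score(c) >= threshold]


def dummy_model(x, t, cond):
    """Placeholder noise predictor that always predicts zero noise."""
    return np.zeros_like(x)


ecg = np.sin(np.linspace(0, 20 * np.pi, 2500))
synthetic = sdedit_ddim(ecg, dummy_model, cond="afib", rng=np.random.default_rng(0))
print(synthetic.shape)  # (2500,)
```

Starting from a partially noised real ECG, rather than pure noise, is what lets SDEdit keep the source beat morphology while the guided denoiser imposes the target class.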

## Key Results

### Augmentation Viability (5-Fold Cross-Validation)

| Condition | Training Data | Accuracy | F1 Score |
|-----------|--------------|----------|----------|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |
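
Condition C trains on a 67:33 mix of real and synthetic ECGs. A minimal sketch of assembling such a set is below. It assumes the training-set size is held fixed (so synthetic samples replace real ones rather than being added) and that samples are drawn uniformly without replacement; the function `mix_training_set` is hypothetical, not the repository's split logic.

```python
import numpy as np


def mix_training_set(real_idx, synth_idx, n_total, synth_frac=0.33, seed=0):
    """Return (real, synthetic) index arrays for a fixed-size mixed training set.

    Hypothetical sketch: uniform sampling without replacement; size stays n_total.
    """
    rng = np.random.default_rng(seed)
    n_synth = int(round(synth_frac * n_total))
    n_real = n_total - n_synth
    if n_real > len(real_idx) or n_synth > len(synth_idx):
        raise ValueError("not enough samples for the requested mix")
    real = rng.choice(real_idx, size=n_real, replace=False)
    synth = rng.choice(synth_idx, size=n_synth, replace=False)
    return real, synth


real, synth = mix_training_set(np.arange(1000), np.arange(500), n_total=900)
print(len(real), len(synth))  # 603 297
```

In practice you may want to sample separately within each class (AFib vs. Normal) so the mixed set keeps the same class balance as the real-only baseline; that is omitted here for brevity.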

**A TOST (two one-sided tests) equivalence test shows A ≈ C** (p = 0.007, equivalence margin ±2%), indicating that replacing 33% of the real training data with synthetic ECGs does not degrade classifier performance beyond that margin.
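
For readers reproducing the equivalence check, a minimal paired TOST sketch follows. It assumes the test is run on per-fold score differences (A minus C, in percentage points) with the stated ±2% margin, using SciPy's Student-t distribution. Whether the original analysis paired the folds or used a Welch-style test, and which metric it tested, is not stated here, so treat both as assumptions.

```python
import numpy as np
from scipy import stats


def paired_tost(a, b, margin=2.0):
    """Paired TOST: is the mean of (a - b) equivalent to zero within +/- margin?

    Returns the TOST p-value, i.e. the larger of the two one-sided p-values.
    Equivalence is declared when it is below the chosen alpha (e.g. 0.05).
    """
    d = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    n = d.size
    se = d.std(ddof=1) / np.sqrt(n)   # standard error of the mean difference
    df = n - 1
    # Test 1, H0: mean diff <= -margin. Small p means the difference is above -margin.
    p_lower = stats.t.sf((d.mean() + margin) / se, df)
    # Test 2, H0: mean diff >= +margin. Small p means the difference is below +margin.
    p_upper = stats.t.cdf((d.mean() - margin) / se, df)
    return float(max(p_lower, p_upper))


# Usage: pass the per-fold scores (in %) for conditions A and C, in the same fold order.
# p = paired_tost(fold_scores_A, fold_scores_C, margin=2.0)
```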

### Signal Quality

| Metric | Value |
|--------|-------|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
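
The table does not state how PSNR was normalized. A minimal NumPy sketch of the usual 1-D definitions is below, assuming PSNR uses the reference signal's peak-to-peak amplitude as the peak value; a different peak (for example a fixed range after normalization) changes the dB values, so this is not guaranteed to reproduce the numbers above. To report mean ± std as in the table, compute the metrics per ECG pair and then aggregate.

```python
from typing import Optional

import numpy as np


def mse(ref: np.ndarray, est: np.ndarray) -> float:
    """Mean squared error between two equal-length 1-D signals."""
    ref, est = np.asarray(ref, dtype=float), np.asarray(est, dtype=float)
    return float(np.mean((ref - est) ** 2))


def psnr(ref: np.ndarray, est: np.ndarray, data_range: Optional[float] = None) -> float:
    """PSNR in dB = 10 * log10(peak^2 / MSE).

    By assumption, peak defaults to the reference signal's peak-to-peak amplitude.
    """
    ref = np.asarray(ref, dtype=float)
    peak = float(ref.max() - ref.min()) if data_range is None else float(data_range)
    err = mse(ref, est)
    return float("inf") if err == 0 else float(10.0 * np.log10(peak ** 2 / err))


print(round(psnr(np.array([0.0, 1.0]), np.array([0.0, 0.0])), 2))  # 3.01
```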

## Files

- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics

## Usage

### Option 1: Interactive Demo (Easiest)

Try the model directly in your browser — no code needed:

👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**

Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse pre-loaded examples.
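
If your recording is not already a 10-second segment at 250 Hz, a minimal NumPy sketch for converting it to the 2,500-sample input the demo expects is below. It uses simple linear interpolation via `np.interp` (a band-limited resampler such as `scipy.signal.resample_poly` is a better choice for real recordings) and assumes the demo accepts a 1-D float array. The function name `to_demo_input` and the output file name are hypothetical.

```python
import numpy as np

TARGET_FS = 250      # Hz, as stated on this card
TARGET_LEN = 2500    # 10 s x 250 Hz


def to_demo_input(signal: np.ndarray, fs: float) -> np.ndarray:
    """Take the first 10 s of a 1-D ECG sampled at fs Hz and resample it to 250 Hz."""
    signal = np.asarray(signal, dtype=float).ravel()
    n_needed = int(round(10 * fs))
    if signal.size < n_needed:
        raise ValueError(f"need at least 10 s of signal ({n_needed} samples at {fs} Hz)")
    segment = signal[:n_needed]
    t_src = np.arange(n_needed) / fs             # original sample times (s)
    t_dst = np.arange(TARGET_LEN) / TARGET_FS    # target sample times (s)
    return np.interp(t_dst, t_src, segment).astype(np.float32)


# Example: 12 s of a 500 Hz recording (random placeholder data).
raw = np.random.default_rng(0).standard_normal(6000)
x = to_demo_input(raw, fs=500)
print(x.shape)  # (2500,)
# np.save("ecg_input.npy", x)  # hypothetical file name; upload this file to the demo
```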

### Option 2: Download & Use in Python

```python
from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth"
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth"
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")

# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
# git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See: diffusion_pipeline/final_pipeline/ for model architectures
```
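
Once the checkpoint is loaded, a quick sanity check is to confirm the documented keys are present and compare per-component sizes with the Architecture table. A minimal sketch follows; the key names come from the comments above, and the helper `summarize_checkpoint` is hypothetical. State dicts also hold buffers such as BatchNorm running statistics, so these element counts can be slightly higher than the parameter counts in the table.

```python
from collections.abc import Mapping

import torch

COMPONENTS = ("content_encoder", "style_encoder", "unet")


def summarize_checkpoint(checkpoint: Mapping) -> dict:
    """Count tensor elements per component state dict (parameters plus buffers)."""
    missing = [k for k in COMPONENTS if k not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing: {missing}")
    return {
        name: sum(t.numel() for t in checkpoint[name].values() if torch.is_tensor(t))
        for name in COMPONENTS
    }


# Usage with the `checkpoint` loaded above:
# for name, n in summarize_checkpoint(checkpoint).items():
#     print(f"{name}: {n / 1e6:.2f}M elements")
```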

### Option 3: Clone the Full Pipeline

```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/

# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```

> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
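
The usual PyTorch pattern is to construct each module with the same hyperparameters used in training, then call `load_state_dict`. A minimal, generic sketch is below. The helper `load_component` is hypothetical, and the stand-in `nn.Linear` replaces the real DiffStyleTS and AFibResLSTM classes, whose constructors live in the repository and can likely be configured from the hyperparameters stored in `checkpoint['config']`.

```python
import torch
import torch.nn as nn


def load_component(module: nn.Module, state_dict: dict, device: str = "cpu") -> nn.Module:
    """Load weights into an already-constructed module and switch it to eval mode.

    strict=True makes PyTorch raise if any key is missing or unexpected, which
    surfaces architecture/config mismatches early.
    """
    module.load_state_dict(state_dict, strict=True)
    return module.to(device).eval()


# Stand-in example: in practice, construct the repository's model classes instead.
trained = nn.Linear(8, 4)
restored = load_component(nn.Linear(8, 4), trained.state_dict())
print(torch.equal(restored.weight, trained.weight))  # True
```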

## Citation

```bibtex
@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
```

## Links

- [GitHub Repository](https://github.com/vlbthambawita/PERA_AF_Detection)
- [Interactive Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)
- Paper (coming soon)