---
license: mit
tags:
- ecg
- atrial-fibrillation
- diffusion-model
- data-augmentation
- medical-ai
- time-series
- pytorch
datasets:
- mimic-iv-ecg
language:
- en
pipeline_tag: other
---

# Diffusion-Based ECG Augmentation Model

A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for **training data augmentation** in AFib detection systems.

## Model Description

This model separates ECG signals into class-invariant **content** (beat morphology) and class-specific **style** (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.

### Architecture

| Component | Parameters | Description |
|-----------|-----------|-------------|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |
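
The FiLM conditioning in the UNet applies a per-channel scale and shift derived from the content and style embeddings. A minimal NumPy sketch of the operation (the projection that predicts `gamma` and `beta` from the embeddings is omitted; shapes here are illustrative, not the model's actual dimensions):

```python
import numpy as np

def film(features, gamma, beta):
    """Feature-wise Linear Modulation: scale and shift each channel.

    features: (channels, time) feature map from a UNet block
    gamma, beta: (channels,) modulation parameters predicted from the
    concatenated content + style embeddings (projection omitted here).
    """
    return gamma[:, None] * features + beta[:, None]

# Toy example: 4 channels, 8 time steps
feats = np.ones((4, 8))
gamma = np.array([2.0, 1.0, 0.5, 0.0])
beta = np.array([0.0, 1.0, 0.0, 3.0])
out = film(feats, gamma, beta)
print(out[:, 0])  # → [2.  2.  0.5 3. ]
```

Because the modulation is broadcast over the time axis, the same conditioning is applied uniformly across the whole 10-second window at each UNet resolution.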

### Training

- **Dataset**: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- **Stage 1** (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- **Stage 2** (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- **Hardware**: NVIDIA RTX 6000 Ada (48 GB VRAM)
- **Training time**: ~22 hours
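
The Stage 1 objective combines the three terms listed above. A NumPy sketch of how they compose (the loss weights `w_ce` and `beta_kl` are illustrative assumptions, not the values used in training):

```python
import numpy as np

def stage1_loss(x, x_hat, logits, label, mu, logvar, w_ce=0.1, beta_kl=1e-3):
    """Stage 1: reconstruction MSE + cross-entropy + KL divergence.

    Weights w_ce and beta_kl are placeholder assumptions for illustration.
    """
    mse = np.mean((x - x_hat) ** 2)                      # reconstruction
    log_probs = logits - np.log(np.sum(np.exp(logits)))  # log-softmax
    ce = -log_probs[label]                               # AFib vs NSR prediction
    kl = -0.5 * np.mean(1 + logvar - mu**2 - np.exp(logvar))  # VAE prior term
    return mse + w_ce * ce + beta_kl * kl

# Toy call: perfect reconstruction, uninformative logits, standard-normal posterior
loss = stage1_loss(np.zeros(2500), np.zeros(2500), np.array([0.0, 0.0]),
                   label=0, mu=np.zeros(64), logvar=np.zeros(64))
```

In this toy call the MSE and KL terms vanish, so the loss reduces to the weighted cross-entropy of a uniform two-class prediction.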

### Generation

- **Method**: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
- **Filtering**: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- **Output**: 7,784 accepted synthetic ECGs from the test set
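
SDEdit partially noises the source ECG (here up to 60% of the schedule) and then denoises it deterministically, combining conditional and unconditional noise predictions via classifier-free guidance. A toy sketch with a stub denoiser and an assumed linear noise schedule, for shape and control-flow illustration only (not the model's actual schedule or network):

```python
import numpy as np

def sdedit(x, denoise_fn, T=1000, strength=0.6, steps=50, cfg_scale=3.0):
    """SDEdit-style generation sketch (stub denoiser, toy alpha schedule)."""
    rng = np.random.default_rng(0)
    t_start = int(strength * T)
    alphas = np.linspace(0.9999, 0.98, T).cumprod()  # assumed toy schedule
    # 1. Partially noise the source ECG up to t_start (60% of the chain)
    a = alphas[t_start - 1]
    x_t = np.sqrt(a) * x + np.sqrt(1 - a) * rng.standard_normal(x.shape)
    # 2. Deterministic DDIM denoising with classifier-free guidance
    ts = np.linspace(t_start - 1, 0, steps).astype(int)
    for i, t in enumerate(ts):
        eps_c = denoise_fn(x_t, t, cond=True)       # conditioned on content+style
        eps_u = denoise_fn(x_t, t, cond=False)      # unconditional
        eps = eps_u + cfg_scale * (eps_c - eps_u)   # CFG combination, scale 3.0
        a_t = alphas[t]
        a_prev = alphas[ts[i + 1]] if i + 1 < steps else 1.0
        x0 = (x_t - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)  # predicted clean signal
        x_t = np.sqrt(a_prev) * x0 + np.sqrt(1 - a_prev) * eps
    return x_t

# Stub denoiser that predicts zero noise (for shape-checking only)
out = sdedit(np.zeros(2500), lambda x, t, cond: np.zeros_like(x))
print(out.shape)  # (2500,)
```

Because the chain starts from a partially noised real signal rather than pure noise, the output retains the source's content while the conditioning steers its style.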

## Key Results

### Augmentation Viability (5-Fold Cross-Validation)

| Condition | Training Data | Accuracy | F1 Score |
|-----------|--------------|----------|----------|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |

**A TOST equivalence test confirms A ≈ C** (p = 0.007, margin ±2%), showing that replacing 33% of the real training data with synthetic ECGs does not degrade classifier performance.
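
TOST (two one-sided tests) declares equivalence when both one-sided t-tests reject: the mean difference is significantly above the lower margin and significantly below the upper margin. A dependency-free sketch using the ±2% margin; the per-fold accuracies below are hypothetical illustrations, not the paper's actual fold scores, and the critical value uses a pooled df ≈ 8 approximation:

```python
import math

def tost_equivalent(a, b, margin, t_crit):
    """Two one-sided t-tests: are the means of a and b within ±margin?

    Returns True if both one-sided tests reject at the given critical t
    (t_crit ≈ 1.860 for alpha = 0.05, df = 8).
    """
    n1, n2 = len(a), len(b)
    m1, m2 = sum(a) / n1, sum(b) / n2
    v1 = sum((x - m1) ** 2 for x in a) / (n1 - 1)
    v2 = sum((x - m2) ** 2 for x in b) / (n2 - 1)
    se = math.sqrt(v1 / n1 + v2 / n2)
    diff = m1 - m2
    t_lower = (diff + margin) / se   # H0: diff <= -margin
    t_upper = (diff - margin) / se   # H0: diff >= +margin
    return t_lower > t_crit and t_upper < -t_crit

# Hypothetical per-fold accuracies (illustrative, not the reported folds)
acc_real = [95.2, 95.5, 95.9, 95.7, 95.8]   # condition A
acc_aug = [94.6, 95.0, 95.3, 95.2, 95.1]    # condition C
print(tost_equivalent(acc_real, acc_aug, margin=2.0, t_crit=1.860))  # True
```

Note that equivalence is a stronger claim than a non-significant difference: a plain t-test failing to reject says nothing, whereas TOST actively bounds the difference inside the margin.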

### Signal Quality

| Metric | Value |
|--------|-------|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |
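
PSNR relates directly to the MSE row above: it is the log-ratio of the squared signal range to the MSE. A minimal sketch (the reference range used by the authors is an assumption; peak-to-peak of the reference signal is one common choice):

```python
import numpy as np

def psnr(reference, generated, data_range=None):
    """PSNR in dB; data_range defaults to the reference's peak-to-peak span."""
    mse = np.mean((reference - generated) ** 2)
    if data_range is None:
        data_range = reference.max() - reference.min()
    return 10.0 * np.log10(data_range ** 2 / mse)

ref = np.sin(np.linspace(0, 20 * np.pi, 2500))                 # toy "ECG"
gen = ref + 0.1 * np.random.default_rng(0).standard_normal(2500)
print(round(psnr(ref, gen), 1))
```

Because PSNR is logarithmic in the error, the reported 12.58 dB corresponds to a substantially larger per-sample deviation than the toy example above, which is expected for style-transferred (rather than reconstructed) signals.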

## Files

- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics

## Usage

### Option 1: Interactive Demo (Easiest)

Try the model directly in your browser — no code needed:

👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**

Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse pre-loaded examples.
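
If your recording is at a different sampling rate or length, it needs to be brought to 2500 samples at 250 Hz first. A sketch using simple linear-interpolation resampling (the demo's own preprocessing may differ; the function name is illustrative):

```python
import numpy as np

def to_demo_input(signal, fs, target_fs=250, target_len=2500):
    """Resample a 1-D ECG to 250 Hz and crop/zero-pad to 2500 samples (10 s)."""
    signal = np.asarray(signal, dtype=np.float32)
    # Linear-interpolation resample to the target rate
    n_out = int(round(len(signal) * target_fs / fs))
    x_old = np.linspace(0.0, 1.0, len(signal))
    x_new = np.linspace(0.0, 1.0, n_out)
    resampled = np.interp(x_new, x_old, signal)
    # Crop or zero-pad to exactly 10 seconds
    out = np.zeros(target_len, dtype=np.float32)
    n = min(target_len, len(resampled))
    out[:n] = resampled[:n]
    return out

ecg = to_demo_input(np.random.randn(5000), fs=500)  # 10 s at 500 Hz → 250 Hz
np.save("ecg_input.npy", ecg)
print(ecg.shape)  # (2500,)
```

For clinical-grade resampling an anti-aliasing filter would be preferable to plain interpolation; this sketch only ensures the array matches the expected shape.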

### Option 2: Download & Use in Python

```python
from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth"
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth"
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")

# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
#   git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See diffusion_pipeline/final_pipeline/ for the model architectures.
```

### Option 3: Clone the Full Pipeline

```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/

# Download weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```

> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.

## Citation

```bibtex
@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
```

## Links

- [GitHub Repository](https://github.com/vlbthambawita/PERA_AF_Detection)
- [Interactive Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)
- Paper (coming soon)