---
license: mit
tags:
- ecg
- atrial-fibrillation
- diffusion-model
- data-augmentation
- medical-ai
- time-series
- pytorch
datasets:
- mimic-iv-ecg
language:
- en
pipeline_tag: other
---

# Diffusion-Based ECG Augmentation Model

A disentangled diffusion model (DiffStyleTS) for generating synthetic ECG signals via style transfer between Atrial Fibrillation (AFib) and Normal Sinus Rhythm classes. Designed for **training data augmentation** in AFib detection systems.

## Model Description

This model separates ECG signals into class-invariant **content** (beat morphology) and class-specific **style** (rhythm characteristics), then generates synthetic ECGs by transferring the style of one class onto the content of another.

### Architecture

| Component | Parameters | Description |
|-----------|-----------|-------------|
| Content Encoder (VAE) | 4.3M | Extracts class-invariant temporal patterns using BatchNorm |
| Style Encoder (CNN) | 567K | Captures class-discriminative features using InstanceNorm |
| Conditional UNet | 14.3M | Denoises with FiLM conditioning on content + style |
| **Total** | **19.1M** | |

### Training

- **Dataset**: MIMIC-IV ECG (~140,000 segments, Lead II, 250 Hz, 10-second windows)
- **Stage 1** (Epochs 1-50): Reconstruction (MSE + Cross-Entropy + KL divergence)
- **Stage 2** (Epochs 51-100): Style transfer (MSE + Flip loss + Similarity loss)
- **Hardware**: NVIDIA RTX 6000 Ada (48 GB VRAM)
- **Training time**: ~22 hours

### Generation

- **Method**: SDEdit (60% noise addition, 50 DDIM denoising steps, CFG scale 3.0)
- **Filtering**: Clinical plausibility validator (morphological + physiological checks, threshold 0.7)
- **Output**: 7,784 accepted synthetic ECGs from the test set

## Key Results

### Augmentation Viability (5-Fold Cross-Validation)

| Condition | Training Data | Accuracy | F1 Score |
|-----------|--------------|----------|----------|
| A (Real only) | 18,681 original ECGs | 95.63 ± 0.33% | 95.65 ± 0.35% |
| B (Synthetic only) | 7,784 generated (×3) | 85.94 ± 1.32% | 86.70 ± 1.24% |
| C (Augmented) | 67% real + 33% synthetic | 95.05 ± 0.50% | 95.09 ± 0.46% |

**A TOST equivalence test confirms A ≈ C** (p = 0.007, margin ±2%): replacing 33% of the real training data with synthetic ECGs does not degrade classifier performance.

### Signal Quality

| Metric | Value |
|--------|-------|
| PSNR | 12.58 ± 2.09 dB |
| SSIM | 0.471 ± 0.110 |
| MSE | 0.005 ± 0.004 |

## Files

- `diffusion_model.pth` — Trained diffusion model (Stage 2, Epoch 100)
- `classifier_model.pth` — ResNet-BiLSTM AFib classifier
- `model_metadata.json` — Training configuration and final metrics

## Usage

### Option 1: Interactive Demo (Easiest)

Try the model directly in your browser — no code needed:

👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**

Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse the pre-loaded examples.

### Option 2: Download & Use in Python

```python
from huggingface_hub import hf_hub_download
import torch

# Download model files from Hugging Face
diffusion_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="diffusion_model.pth",
)
classifier_path = hf_hub_download(
    repo_id="TharakaDil2001/diffusion-ecg-augmentation",
    filename="classifier_model.pth",
)

# Load the diffusion model checkpoint
checkpoint = torch.load(diffusion_path, map_location="cpu")
# The checkpoint contains:
# - checkpoint['content_encoder'] → Content Encoder state dict
# - checkpoint['style_encoder']   → Style Encoder state dict
# - checkpoint['unet']            → UNet state dict
# - checkpoint['config']          → Training config with all hyperparameters

# Load the classifier checkpoint
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict

# To use the full pipeline, clone the repository:
#   git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
# See diffusion_pipeline/final_pipeline/ for the model architectures.
```

### Option 3: Clone the Full Pipeline

```bash
# Clone the full codebase with all model architectures
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/

# Download the weights
pip install huggingface_hub
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
"
```

> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.

## Citation

```bibtex
@misc{pera_af_detection_2025,
  title={Diffusion-Based Data Augmentation for Atrial Fibrillation Detection},
  author={Dilshan, D.M.T. and Karunarathne, K.N.P.},
  year={2025},
  institution={University of Peradeniya, Sri Lanka},
  collaboration={SimulaMet, Oslo, Norway}
}
```

## Links

- [GitHub Repository](https://github.com/vlbthambawita/PERA_AF_Detection)
- [Interactive Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)
- Paper (coming soon)
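
## Appendix: Generation Sketch

The SDEdit generation procedure described above (60% noise addition, 50 DDIM steps, CFG scale 3.0) can be sketched in plain PyTorch. This is a minimal illustration, not the repository's actual implementation: the linear beta schedule, the 1000-step training horizon, the `dummy_denoiser` stand-in, and the `cond` interface are all assumptions; the real schedule and conditioning come from `checkpoint['config']` and the DiffStyleTS classes in the repo.

```python
import torch

# Hypothetical stand-in for the trained conditional UNet: any callable
# mapping (x_t, t, cond) -> predicted noise with the same shape as x_t.
def dummy_denoiser(x_t, t, cond=None):
    return torch.zeros_like(x_t)

def sdedit_generate(x, denoiser, num_train_steps=1000, noise_frac=0.6,
                    ddim_steps=50, cfg_scale=3.0, cond=None):
    """SDEdit-style partial noising + deterministic DDIM denoising (sketch).

    x: (batch, 1, 2500) ECG segment (10 s at 250 Hz). The real signal is
    noised to `noise_frac` of the diffusion schedule, then denoised with
    DDIM (eta = 0) and classifier-free guidance.
    """
    # Linear beta schedule (assumed; the actual schedule is in the config).
    betas = torch.linspace(1e-4, 0.02, num_train_steps)
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)

    # Partial forward diffusion: jump to 60% of the schedule.
    t_start = int(noise_frac * num_train_steps) - 1
    a_bar = alphas_bar[t_start]
    x_t = a_bar.sqrt() * x + (1 - a_bar).sqrt() * torch.randn_like(x)

    timesteps = torch.linspace(t_start, 0, ddim_steps).long()
    for i, t in enumerate(timesteps):
        # Classifier-free guidance: blend conditional and unconditional noise.
        eps_c = denoiser(x_t, t, cond)
        eps_u = denoiser(x_t, t, None)
        eps = eps_u + cfg_scale * (eps_c - eps_u)

        # Predict x0, then take the deterministic DDIM step to t_prev.
        a_t = alphas_bar[t]
        x0_pred = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()
        if i + 1 < len(timesteps):
            a_prev = alphas_bar[timesteps[i + 1]]
            x_t = a_prev.sqrt() * x0_pred + (1 - a_prev).sqrt() * eps
        else:
            x_t = x0_pred
    return x_t

x = torch.randn(1, 1, 2500)  # placeholder Lead II segment
y = sdedit_generate(x, dummy_denoiser)
print(y.shape)  # torch.Size([1, 1, 2500])
```

In practice `denoiser` would be the loaded Conditional UNet and `cond` the content/style embeddings from the two encoders; generated signals then pass through the clinical plausibility filter before being accepted.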