Update usage instructions with huggingface_hub download and multiple usage options
Browse files
README.md
CHANGED
|
@@ -74,22 +74,66 @@ This model separates ECG signals into class-invariant **content** (beat morpholo
|
|
| 74 |
|
| 75 |
## Usage
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
```python
|
|
|
|
| 78 |
import torch
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# Load the diffusion model checkpoint
|
| 81 |
-
checkpoint = torch.load(
|
| 82 |
|
| 83 |
# The checkpoint contains:
|
| 84 |
-
# - content_encoder state dict
|
| 85 |
-
# - style_encoder state dict
|
| 86 |
-
# - unet state dict
|
| 87 |
-
# - config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
```
|
| 92 |
|
|
|
|
|
|
|
| 93 |
## Citation
|
| 94 |
|
| 95 |
```bibtex
|
|
|
|
| 74 |
|
| 75 |
## Usage
|
| 76 |
|
| 77 |
+
### Option 1: Interactive Demo (Easiest)
|
| 78 |
+
|
| 79 |
+
Try the model directly in your browser — no code needed:
|
| 80 |
+
|
| 81 |
+
👉 **[Launch Demo](https://huggingface.co/spaces/TharakaDil2001/ecg-augmentation-demo)**
|
| 82 |
+
|
| 83 |
+
Upload an ECG (`.npy` or `.csv`, 2500 samples at 250 Hz) or browse pre-loaded examples.
|
| 84 |
+
|
| 85 |
+
### Option 2: Download & Use in Python
|
| 86 |
+
|
| 87 |
```python
|
| 88 |
+
from huggingface_hub import hf_hub_download
|
| 89 |
import torch
|
| 90 |
|
| 91 |
+
# Download model files from Hugging Face
|
| 92 |
+
diffusion_path = hf_hub_download(
|
| 93 |
+
repo_id="TharakaDil2001/diffusion-ecg-augmentation",
|
| 94 |
+
filename="diffusion_model.pth"
|
| 95 |
+
)
|
| 96 |
+
classifier_path = hf_hub_download(
|
| 97 |
+
repo_id="TharakaDil2001/diffusion-ecg-augmentation",
|
| 98 |
+
filename="classifier_model.pth"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
# Load the diffusion model checkpoint
|
| 102 |
+
checkpoint = torch.load(diffusion_path, map_location="cpu")
|
| 103 |
|
| 104 |
# The checkpoint contains:
|
| 105 |
+
# - checkpoint['content_encoder'] → Content Encoder state dict
|
| 106 |
+
# - checkpoint['style_encoder'] → Style Encoder state dict
|
| 107 |
+
# - checkpoint['unet'] → UNet state dict
|
| 108 |
+
# - checkpoint['config'] → Training config with all hyperparameters
|
| 109 |
+
|
| 110 |
+
# Load the classifier checkpoint
|
| 111 |
+
cls_checkpoint = torch.load(classifier_path, map_location="cpu")
|
| 112 |
+
# - cls_checkpoint['model_state_dict'] → AFibResLSTM state dict
|
| 113 |
+
|
| 114 |
+
# To use the full pipeline, clone the repository:
|
| 115 |
+
# git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
|
| 116 |
+
# See: diffusion_pipeline/final_pipeline/ for model architectures
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### Option 3: Clone the Full Pipeline
|
| 120 |
|
| 121 |
+
```bash
|
| 122 |
+
# Clone the full codebase with all model architectures
|
| 123 |
+
git clone https://github.com/vlbthambawita/PERA_AF_Detection.git
|
| 124 |
+
cd PERA_AF_Detection/diffusion_pipeline/final_pipeline/
|
| 125 |
+
|
| 126 |
+
# Download weights
|
| 127 |
+
pip install huggingface_hub
|
| 128 |
+
python -c "
|
| 129 |
+
from huggingface_hub import hf_hub_download
|
| 130 |
+
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'diffusion_model.pth', local_dir='.')
|
| 131 |
+
hf_hub_download('TharakaDil2001/diffusion-ecg-augmentation', 'classifier_model.pth', local_dir='.')
|
| 132 |
+
"
|
| 133 |
```
|
| 134 |
|
| 135 |
+
> **Note**: The model architectures (DiffStyleTS, AFibResLSTM) are defined in the repository code. You need the architecture classes to instantiate the models before loading the state dicts.
|
| 136 |
+
|
| 137 |
## Citation
|
| 138 |
|
| 139 |
```bibtex
|