Add comprehensive model card for EEGDM

#1
by nielsr HF Staff - opened
Files changed (1)
  1. README.md +235 -3
README.md CHANGED
@@ -1,3 +1,235 @@
- ---
- license: mit
- ---
---
license: mit
pipeline_tag: feature-extraction
---

# EEGDM: EEG Representation Learning via Generative Diffusion Model

[📝 Paper](https://huggingface.co/papers/2508.14086) - [🌐 Project Page](https://aimplifier.github.io/projects/eegdm/) - [💻 Code](https://github.com/jhpuah/EEGDM)

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/title.png" width="166">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/ssmdp_cap.png" width="1066">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/pool_cap.png" width="1066">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/lft_cap.png" width="1066">
</div>

## 🌌 Introduction

EEGDM is a self-supervised diffusion model for EEG signal representation learning. Unlike the "tokenization-then-masking" approaches used in existing EEG foundation models, EEGDM learns robust, meaningful representations through progressive noise corruption and denoising.
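The "progressive noise corruption" half of that idea is just the DDPM forward process. The following is an illustrative pure-Python sketch (not the repository's implementation); the cosine schedule formula is the standard one, and the 10 Hz sine "EEG segment" is a toy assumption:

```python
import math
import random

def cosine_alpha_bar(t: float, s: float = 0.008) -> float:
    """Cumulative signal-retention coefficient of a cosine noise schedule, t in [0, 1]."""
    f = math.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    f0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
    return f / f0

def corrupt(x0, t: float, rng: random.Random):
    """Sample x_t ~ q(x_t | x_0) = sqrt(a_bar)*x_0 + sqrt(1 - a_bar)*eps."""
    a_bar = cosine_alpha_bar(t)
    return [math.sqrt(a_bar) * v + math.sqrt(1 - a_bar) * rng.gauss(0, 1)
            for v in x0]

rng = random.Random(0)
# toy 10 Hz "EEG" segment at 256 samples
signal = [math.sin(2 * math.pi * 10 * i / 256) for i in range(256)]
noisy = corrupt(signal, t=0.5, rng=rng)
# at t=0 the sample is the clean signal; near t=1 it is almost pure noise
```

A denoising network trained to invert this corruption is what yields the latent representations that the rest of the pipeline consumes.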

EEGDM is distinguished by three key innovations:

1. **First Application of Diffusion Models for EEG Representation Learning:** This work pioneers the use of diffusion models for extracting EEG signal representations, rather than only for signal generation and data augmentation, opening a new research direction in neurological signal processing.
2. **Structured State-Space Model Architecture (SSMDP):** EEGDM introduces a neural architecture based on structured state-space models, designed specifically for diffusion pre-training, enabling better capture of the temporal dynamics inherent in EEG signals.
3. **Latent Fusion Transformer for Downstream Tasks:** The framework incorporates a novel latent fusion transformer (LFT) that effectively utilizes the learned diffusion representations for downstream classification tasks such as seizure detection, addressing the challenge of translating generative representations into discriminative ones.

The proposed method addresses critical limitations of current EEG analysis, including the difficulty of learning robust representations given limited high-quality annotations and high signal variability across subjects and conditions, while potentially offering computational advantages over existing transformer-based EEG foundation models.

## 😮 Highlights

• We presented EEGDM, a diffusion model-based framework for learning EEG signal representations and classifying multi-event EEG, extending diffusion models beyond signal generation and data augmentation.

• We developed structured state-space model diffusion pretraining (SSMDP) to capture the temporal dynamics of EEG signals and trained it via the forward and reverse processes of DDPM for representation learning.

• We proposed LFT to leverage and fuse the latent representations from SSMDP for downstream classification tasks.

• We empirically compared our method with current state-of-the-art approaches on the multi-event TUEV dataset to show its competitiveness, and provided a detailed ablation study to analyze its components.

## 📈 Main Results

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result1.png" width="466">
</div>

## ✂️ Ablation

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result2.png" width="566">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result3.png" width="566">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result4.png" width="566">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result5.png" width="566">
</div>

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result6.png" width="566">
</div>

## 🧠 Generation Sample

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/GenerationResult.png" width="566">
</div>

## ⚙️ Quick Start

First, set up the environment with Conda: [https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)

```bash
conda create -n eegdm python=3.11
conda activate eegdm
```

Then, install the dependencies:

```bash
pip install -r requirements.txt
```

The `requirements.txt` file is exported directly from our working environment (NVIDIA GeForce RTX 4090, CUDA version 12.4). If your hardware is incompatible, do the following instead:

1. Install PyTorch following the official guide: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)

2. Run:
```bash
pip install numpy==1.26.4 hydra-core mne lightning pyhealth ema-pytorch diffusers einops wandb scipy
```

We use Weights & Biases ([https://wandb.ai/site/](https://wandb.ai/site/)) for logging, so you will need an account. Alternatively, replace instances of `WandbLogger` with your own logger; see the PyTorch Lightning documentation for available options: [https://lightning.ai/docs/pytorch/stable/extensions/logging.html](https://lightning.ai/docs/pytorch/stable/extensions/logging.html)

### Usage Examples

```bash
python main.py [preprocessing=?] [pretrain=?] [cache=?] [finetune=?] [report=?] [extra=?]
```

Replace `?` with a config file name (without extension). The file must be placed inside `conf/`, under the subdirectory of the same name.

For example:

```bash
python main.py pretrain=base
```

runs pretraining with the config specified in `conf/pretrain/base.yaml`.

You can override config values on the command line; see the Hydra documentation ([https://hydra.cc/docs/intro/](https://hydra.cc/docs/intro/)). For example:

```bash
python main.py finetune=base finetune.rng_seeding.seed=10
```

runs finetuning with the config specified in `conf/finetune/base.yaml` and sets the RNG seed to 10.
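Conceptually, a dotted override such as `finetune.rng_seeding.seed=10` walks the nested config and replaces one leaf. A minimal pure-Python sketch of that semantics (illustrative only; Hydra's real override grammar is richer, so use Hydra itself in practice):

```python
def apply_override(cfg: dict, override: str) -> dict:
    """Apply one 'a.b.c=value' style override to a nested dict config."""
    path, _, raw = override.partition("=")
    keys = path.split(".")
    node = cfg
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    # naive scalar parsing: int if possible, otherwise keep the string
    try:
        value = int(raw)
    except ValueError:
        value = raw
    node[keys[-1]] = value
    return cfg

cfg = {"finetune": {"rng_seeding": {"seed": 0}, "lr": "1e-3"}}
apply_override(cfg, "finetune.rng_seeding.seed=10")
# cfg["finetune"]["rng_seeding"]["seed"] is now 10; other keys are untouched
```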

The `extra` config is special: the function specified in its `target` field will be loaded, and the config will be passed to that function. This is a quick and dirty way to add experiments that do not fit well into the established workflow.
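The `target` mechanism described above can be sketched with `importlib`; this is an assumed shape (the repository's actual field names and call convention may differ), with a stdlib function standing in for an experiment script:

```python
import importlib

def run_extra(extra_cfg: dict):
    """Load the function named by `target` and pass it the whole config.
    Hypothetical sketch of the described behavior, not the real loader."""
    module_name, _, func_name = extra_cfg["target"].rpartition(".")
    func = getattr(importlib.import_module(module_name), func_name)
    return func(extra_cfg)

# demo: "json.dumps" plays the role of an experiment entry point
result = run_extra({"target": "json.dumps", "note": "hello"})
```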

### Experiments

**Preprocessing:**

We follow the general preprocessing logic of LaBraM: [https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py](https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py)

To produce single-channel EEG signals for diffusion model pretraining, run:

```bash
python main.py preprocessing=pretrain
```
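As a rough illustration of what "single-channel" preparation means here (a hypothetical sketch only; the actual pipeline follows LaBraM's preprocessing, including filtering and resampling): each channel of a multi-channel recording is cut into its own fixed-length segments.

```python
def to_single_channel_segments(recording, segment_len):
    """Split a (channels x samples) recording into per-channel,
    non-overlapping segments of `segment_len` samples each.
    Illustrative sketch, not the repository's preprocessing."""
    segments = []
    for channel in recording:
        for start in range(0, len(channel) - segment_len + 1, segment_len):
            segments.append(channel[start:start + segment_len])
    return segments

# toy recording: 3 channels x 1000 samples
recording = [[float(c * 1000 + i) for i in range(1000)] for c in range(3)]
segments = to_single_channel_segments(recording, segment_len=250)
# 3 channels x 4 segments each = 12 single-channel pretraining segments
```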

To produce signals for finetuning, run:

```bash
python main.py preprocessing=faithful
```

**Pre-training:**

```bash
python main.py pretrain=?
```

where `?` is `base`, `linear`, or `nolaw`.

`base` uses a cosine noise scheduler and performs mu-law-based extreme value suppression; `linear` uses a linear noise scheduler; `nolaw` does not perform value suppression.
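The mu-law suppression used by `base` can be illustrated with standard mu-law companding, which compresses extreme amplitudes while keeping small ones nearly linear. This sketch assumes mu = 255 and inputs scaled to [-1, 1]; the repository's exact transform may differ:

```python
import math

MU = 255.0  # assumed companding constant

def mu_law_compress(x: float) -> float:
    """Compress a value in [-1, 1]; large magnitudes are suppressed."""
    return math.copysign(math.log1p(MU * abs(x)) / math.log1p(MU), x)

def mu_law_expand(y: float) -> float:
    """Inverse transform, recovering the original amplitude."""
    return math.copysign(math.expm1(abs(y) * math.log1p(MU)) / MU, y)

compressed = [mu_law_compress(x) for x in (-1.0, -0.01, 0.0, 0.01, 1.0)]
# extremes map to +/-1, while +/-0.01 already maps to roughly +/-0.23
```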

**Caching:**

If noise injection is disabled, the latent tokens can be cached to avoid repeated computation.

The test data is untouched during caching.

See `conf/cache` for available options.

```bash
python main.py cache=base
```
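The cache-plus-metadata idea (latents computed once, then validated against the model hyperparameters at finetune time) can be sketched as follows. File names and metadata fields here are illustrative assumptions, not the repository's on-disk format:

```python
import json
import os
import tempfile

def cache_latents(cache_dir, hparams: dict, latents):
    """Store latents alongside a metadata file recording the
    hyperparameters they were produced with (illustrative format)."""
    os.makedirs(cache_dir, exist_ok=True)
    with open(os.path.join(cache_dir, "meta.json"), "w") as f:
        json.dump(hparams, f, sort_keys=True)
    with open(os.path.join(cache_dir, "latents.json"), "w") as f:
        json.dump(latents, f)

def load_latents(cache_dir, hparams: dict):
    """Return cached latents only if the metadata matches `hparams`."""
    with open(os.path.join(cache_dir, "meta.json")) as f:
        if json.load(f) != hparams:
            raise ValueError("cache metadata inconsistent with model hyperparameters")
    with open(os.path.join(cache_dir, "latents.json")) as f:
        return json.load(f)

tmp = tempfile.mkdtemp()
cache_latents(tmp, {"d_model": 128, "n_layers": 4}, [[0.1, 0.2], [0.3, 0.4]])
lat = load_latents(tmp, {"d_model": 128, "n_layers": 4})
# loading with different hyperparameters raises ValueError
```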

**Fine-tuning:**

If data is cached, the code checks the cache metadata to ensure that it is consistent with the model hyperparameters.

See `conf/finetune` for available options.

In our experiments, `finetune.rng_seeding.seed` is set to 0, 1, 2, 3, and 4 to produce 5 checkpoints:

```bash
python main.py finetune=base finetune.rng_seeding.seed=0
```

**Reporting:**

If the test data cannot be distributed evenly across devices, some samples will be duplicated, causing inaccuracy in the reported metrics. Using `report` avoids this issue.

`report` also calculates the mean and standard deviation of metrics over multiple checkpoints.

```bash
python main.py report=base
```
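Both points above can be seen in a small sketch: padding a test set so it divides evenly across devices duplicates samples (biasing the metrics), and the reported figures aggregate mean/std over the per-seed checkpoints. This is illustrative code with made-up accuracy numbers, not the repository's report implementation:

```python
import statistics

def shard_with_padding(samples, n_devices):
    """Distribute samples across devices, duplicating the head of the list
    so every device gets the same count (the source of the inaccuracy)."""
    n = len(samples)
    per_device = -(-n // n_devices)  # ceiling division
    padded = samples + samples[: per_device * n_devices - n]
    return [padded[i * per_device:(i + 1) * per_device] for i in range(n_devices)]

shards = shard_with_padding(list(range(10)), n_devices=4)
# 10 samples over 4 devices -> 12 slots, so 2 samples are evaluated twice

# aggregating one metric over the seed-0..4 checkpoints, as `report` does
per_seed_acc = [0.81, 0.79, 0.80, 0.82, 0.78]  # made-up numbers
mean_acc = statistics.mean(per_seed_acc)
std_acc = statistics.stdev(per_seed_acc)
```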

**Other:**

Scripts for certain ablation experiments are in `src/extra`:

```bash
python main.py extra=reduce_sampling extra.rate=0.95
python main.py extra=no_fusion extra.rng_seeding.seed=0
python main.py extra=report_no_fusion
python main.py extra=mean_fusion extra.rng_seeding.seed=0
python main.py extra=report_mean_fusion
```

All seeds need to be iterated from 0 to 4.

## ℹ️ Unused Code

This repo is still under active development and contains several pieces of unused/untested code. Any functionality implied by the code but not mentioned in the paper should be considered experimental. Documentation for this code (if any) might be outdated or unreliable.

## 📖 Citation

If you use this work, please cite:

```bibtex
@misc{puah2025eegdm,
  title={{EEGDM: EEG Representation Learning via Generative Diffusion Model}},
  author={Jian Hao Puah and Jiaheng Li and Suman K. Chakravorty and Chaitanya Kharyal and Yi-Ting Li and Matthew T. Bianchi and Michael D. Place and Mengyu Wang and David W. Bates and David A. Roberson and John W. Guttag},
  year={2025},
  eprint={2508.14086},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

## 🤝 Acknowledgments

This work is inspired by and builds upon various open-source projects and research in diffusion models and EEG processing. We acknowledge the contributions of the communities behind PyTorch, Hugging Face Diffusers, MNE-Python, and other related libraries.

## 💬 Discussion and Collaboration

We welcome discussions and collaborations to improve EEGDM. Please feel free to open issues or pull requests on GitHub.