jhpuah committed on
Commit d0bb8b8 · verified · 1 Parent(s): bd8889f

Update README.md

Files changed (1): README.md (+148, -37)
README.md CHANGED
@@ -35,22 +35,23 @@ Our EEGDM is a novel self-supervised diffusion model designed for superior EEG s
 
 EEGDM is distinguished by three key innovations:
 
- 1. **First Application of Diffusion Models for EEG Representation Learning:** This work pioneers the use of diffusion models for extracting EEG signal representations rather than just signal generation and data augmentation, opening up a new research direction in neurological signal processing.
- 2. **Structured State-Space Model Architecture (SSMDP):** EEGDM introduces a specialized neural architecture based on structured state-space models specifically designed for diffusion pre-training, enabling better capture of the temporal dynamics inherent in EEG signals.
- 3. **Latent Fusion Transformer for Downstream Tasks:** The framework incorporates a novel latent fusion transformer (LFT) that effectively utilizes the learned diffusion representations for downstream classification tasks like seizure detection, addressing the challenge of translating generative representations to discriminative tasks.
 The proposed method addresses critical limitations in current EEG analysis, including the difficulty of learning robust representations due to limited high-quality annotations and high signal variability across subjects and conditions, while potentially offering computational advantages over existing transformer-based EEG foundation models.
 
 ## 😮 Highlights
 
- • We presented EEGDM, a diffusion model-based framework for learning EEG signal representations and classification of multi-event EEG, extending diffusion model beyond signal generation and data augmentation.
 
- • We developed structured state-space model diffusion pretraining (SSMDP) to capture the temporal dynamics of EEG signals and trained it via the forward and reverse process of DDPM for representation learning.
 
 • We proposed LFT to leverage and fuse the latent representations from SSMDP for downstream classification tasks.
 
- • We empirically compared our method with current state-of-the-art approaches on multi-event dataset TUEV to show its competitiveness and provided a detailed ablation study to analyze its components.
 
 ## 📈 Main result
 
 <div align="center">
 <br>
@@ -58,32 +59,50 @@ The proposed method addresses critical limitations in current EEG analysis, incl
 </div>
 
 ## ✂️ Ablation
 
 <div align="center">
- <br>
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result2.png" width="566">
 </div>
 
 <div align="center">
- <br>
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result3.png" width="566">
 </div>
 
 <div align="center">
- <br>
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result4.png" width="566">
 </div>
 
 <div align="center">
- <br>
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result5.png" width="566">
 </div>
 
- <div align="center">
 <br>
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result6.png" width="566">
 </div>
 
 ## 🧠 Generation Sample
 
 <div align="center">
@@ -95,10 +114,12 @@ The proposed method addresses critical limitations in current EEG analysis, incl
 
 * **[2025-07-16]** Initial setup and README update.
 * **[2025-08-11]** Main pages and experiment result update.
 
 ## ⚙️ Quick Start
 
- First, set up the environment with Conda: [https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)
 
 ```bash
 conda create -n eegdm python=3.11
@@ -108,54 +129,72 @@ Then, install dependencies:
 ```bash
 pip install -r requirements.txt
 ```
- The `requirement.txt` file is exported directly from our working environment (NVIDIA GeForce RTX 4090, CUDA Version: 12.4), if your hardware is incompatible, do the following instead:
 
- 1. Install torch following the official guide: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
 
 2. Run:
 ```bash
 pip install numpy==1.26.4 hydra-core mne lightning pyhealth ema-pytorch diffusers einops wandb scipy
 ```
 
- We use Weight and Bias ([https://wandb.ai/site/](https://wandb.ai/site/)) for logging, and you will need an account for that. Alternatively, replace instances of `WandbLogger` to your own logger, check Pytorch Lightning documentation for available options: [https://lightning.ai/docs/pytorch/stable/extensions/logging.html](https://lightning.ai/docs/pytorch/stable/extensions/logging.html)
 
 ### Usage Examples:
 
 ```bash
 python main.py [preprocessing=?] [pretrain=?] [cache=?] [finetune=?] [report=?] [extra=?]
 ```
- Replace "?" with config file name (without extension).
 The file must be put inside "conf", under the directory with the same name.
 
 e.g.
 ```bash
 python main.py pretrain=base
 ```
- Run pretraining with config specified in `conf/pretrain/base.yaml`.
 
- You can override config in command line,
- see Hydra documentation ([https://hydra.cc/docs/intro/](https://hydra.cc/docs/intro/)). E.g.
 ```bash
 python main.py finetune=base finetune.rng_seeding.seed=10
 ```
- Run finetuning with config specified in `conf/finetune/base.yaml`, and set the rng seed to 10.
 
 
 `extra` config is special: the function specified in its `target` field will be loaded,
- and the config will be passed to that function. This is a quick and dirty way to add experiments that does not fit well to the established workflow.
-
 
 ### Experiments:
 **Preprocessing:**
 
- We follow the general preprocessing logic of LaBraM: [https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py](https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py)
 
- To produce single-channel EEG signal for diffusion model pretraining, run:
 ```bash
 python main.py preprocessing=pretrain
 ```
 
- To produce signal for finetuning, run:
 ```bash
 python main.py preprocessing=faithful
 ```
@@ -167,26 +206,28 @@ python main.py pretrain=?
 ```
 Where `?` is `base`, `linear` or `nolaw`.
 
- `base` uses cosine noise scheduler and perform mu-law based extreme value suppression. `linear` uses linear noise scheduler, and `nolaw` does not perform value suppression.
 
 **Caching:**
 
- If noise injection is disabled, the latent tokens can be cached to avoid repeated computation.
 
- The test data is untouched during caching.
 
- See `conf/cache` for available options.
 ```bash
 python main.py cache=base
 ```
 
 **Fine-tuning:**
 
- If data is cached, the code will check metadata to ensure that it is consistent with the model hyperparameter.
 
 See `conf/finetune` for available options.
 
- In our experiment, `finetune.rng_seeding.seed` is set to 0, 1, 2, 3 and 4 to produce 5 checkpoints
 
 ```bash
 python main.py finetune=base finetune.rng_seeding.seed=0
@@ -196,16 +237,16 @@ python main.py finetune=base finetune.rng_seeding.seed=0
 
 If testing data cannot be distributed evenly across devices, certain data will be duplicated and cause inaccuracy in the reported metrics. Using `report` will avoid this issue.
 
- `report` also calculate the mean and standard deviation of metrices of multiple checkpoints.
 ```bash
 python main.py report=base
 ```
 
- **Other**
 
 Scripts of certain ablation experiments are put in `src/extra`:
 ```bash
- python main.py extra=reduce_sampling extra.rate=0.95
 python main.py extra=no_fusion extra.rng_seeding.seed=0
 python main.py extra=report_no_fusion
 python main.py extra=mean_fusion extra.rng_seeding.seed=0
@@ -213,14 +254,82 @@ python main.py extra=report_mean_fusion
 ```
 All seeds need to be iterated from 0 to 4
 
 ## ℹ️ Unused Code
- This repo is still under active development, and left in several pieces of unused/untested code. Any functionality implied by the code but not mentioned in the paper shall be considered experimental. Documentation about these code (if any) might be outdated or unreliable.
 
 ## 📖 Citation
 
 If you use this work, please cite:
 
- ```bibtex
 @misc{puah2025eegdm,
 title={{EEGDM: EEG Representation Learning via Generative Diffusion Model}},
 author={Jia Hong Puah and Sim Kuan Goh and Ziwei Zhang and Zixuan Ye and Chow Khuen Chan and Kheng Seang Lim and Si Lei Fong and Kok Sin Woon},
@@ -237,4 +346,6 @@ This work is inspired by and builds upon various open-source projects and resear
 
 ## 💬 Discussion and Collaboration
 
- We welcome discussions and collaborations to improve EEGDM. Please feel free to open issues or pull requests on GitHub.
 
 EEGDM is distinguished by three key innovations:
 
+ 1. First Application of Diffusion Models for EEG Representation Learning: This work pioneers the use of diffusion models for extracting EEG signal representations rather than just signal generation and data augmentation, opening up a new research direction in neurological signal processing.
+ 2. Structured State-Space Model Architecture (SSMDP): EEGDM introduces a specialized neural architecture based on structured state-space models specifically designed for diffusion pre-training, enabling better capture of the temporal dynamics inherent in EEG signals.
+ 3. Latent Fusion Transformer for Downstream Tasks: The framework incorporates a novel latent fusion transformer (LFT) that effectively utilizes the learned diffusion representations for downstream classification tasks like seizure detection, addressing the challenge of translating generative representations to discriminative tasks.
 The proposed method addresses critical limitations in current EEG analysis, including the difficulty of learning robust representations due to limited high-quality annotations and high signal variability across subjects and conditions, while potentially offering computational advantages over existing transformer-based EEG foundation models.
 
 ## 😮 Highlights
 
+ • We presented EEGDM, a diffusion model-based framework for learning EEG signal representations and classification of multi-event EEG, extending diffusion models beyond signal generation and data augmentation.
 
+ • We developed the structured state-space model diffusion pretraining (SSMDP) to capture the temporal dynamics of EEG signals and trained it via the forward and reverse processes of DDPM for representation learning.
 
 • We proposed LFT to leverage and fuse the latent representations from SSMDP for downstream classification tasks.
 
+ • We empirically compared our method with current state-of-the-art approaches on the multi-event dataset TUEV to show its competitiveness and provided a detailed ablation study to analyze its components.
 
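For intuition, the DDPM forward (noising) process used during SSMDP pretraining can be sketched on a toy signal. This is an illustrative sketch only: the linear beta schedule, number of steps, and signal length here are assumptions, not the repo's actual hyperparameters.

```python
import numpy as np

# Toy illustration of the DDPM forward process q(x_t | x_0) on a
# synthetic single-channel signal. Schedule and sizes are assumptions.
rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)    # assumed linear beta schedule
alphas_bar = np.cumprod(1.0 - betas)  # cumulative product \bar{alpha}_t

def q_sample(x0, t):
    """Sample x_t ~ q(x_t | x_0) in closed form."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * noise

x0 = np.sin(np.linspace(0.0, 4.0 * np.pi, 256))  # toy "EEG" segment
x_mid = q_sample(x0, 500)    # partially noised
x_end = q_sample(x0, T - 1)  # close to pure Gaussian noise
```

The reverse process learns to undo this corruption step by step; the latents produced along the way are what LFT later fuses for classification.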
 ## 📈 Main result
+ EEGDM outperforms various EEG FMs despite having less pretraining data and fewer trainable parameters. Moreover, finetuning EEGDM does not update the pretrained parameters, allowing one backbone to serve multiple downstream tasks simultaneously.
 
 <div align="center">
 <br>
 </div>
 
 ## ✂️ Ablation
+ DDPM is a framework with many moving parts. In this section, we show that our design choices are necessary for improved performance.
 
 <div align="center">
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result2.png" width="566">
 </div>
 
+
 <div align="center">
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result3.png" width="566">
 </div>
 
+ <br/>
+ Another ablation shows that the latent activities of every part of the diffusion backbone contain representations useful for classification, and their quality tends to increase as the layers deepen.
+
 <div align="center">
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result4.png" width="566">
 </div>
 
+ <br/>
+ The latent fusion module is the largest trainable component of the LFT. Here, we show that it cannot be replaced by non-parameterized alternatives such as average pooling or flattening.
+
 <div align="center">
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result5.png" width="566">
 </div>
 
 <br>
+ The unique formulation of SSMDP and LFT enables the EEGDM framework to operate at a different sampling rate without retraining, at the cost of degraded performance.
+
+ <div align="center">
 <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result6.png" width="566">
 </div>
 
+ ## 🔀 Generalize to CHB-MIT
+ To verify the robustness of the learned representations in cross-domain generalization, we finetuned the model on a dataset with unseen characteristics.
+
+ More specifically, the model pretrained on TUEV (containing sharp waves and artifacts) is finetuned on CHB-MIT for seizure detection.
+
+ The results show that EEGDM outperforms other FMs despite having a much smaller pretraining set that lacks variety, indicating high generalizability and robustness.
+
+ <div align="center">
+ <br>
+ <img src="https://github.com/jhpuah/EEGDM/raw/main/assets/ResultChb.png" width="566">
+ </div>
+
 ## 🧠 Generation Sample
 
 <div align="center">
 
 * **[2025-07-16]** Initial setup and README update.
 * **[2025-08-11]** Main pages and experiment result update.
+ * **[2025-08-27]** Preprint V2.
+ * **[2025-10-02]** Update README to match preprint V2.
 
 ## ⚙️ Quick Start
 
+ First, set up the environment with Conda: https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html
 
 ```bash
 conda create -n eegdm python=3.11
 ```bash
 pip install -r requirements.txt
 ```
+ The `requirements.txt` file is exported directly from our working environment (NVIDIA GeForce RTX 4090, CUDA Version: 12.4). If your hardware is incompatible, do the following instead:
 
+ 1. Install torch following the official guide: https://pytorch.org/get-started/locally/
 
 2. Run:
 ```bash
 pip install numpy==1.26.4 hydra-core mne lightning pyhealth ema-pytorch diffusers einops wandb scipy
 ```
 
+ We use Weights & Biases (https://wandb.ai/site/) for logging, and you will need an account for that. Alternatively, replace instances of `WandbLogger` with your own logger; check the PyTorch Lightning documentation for available options: https://lightning.ai/docs/pytorch/stable/extensions/logging.html
+
 
 ### Usage Examples:
 
 ```bash
 python main.py [preprocessing=?] [pretrain=?] [cache=?] [finetune=?] [report=?] [extra=?]
 ```
+ Replace "?" with the config file name (without extension).
 The file must be put inside "conf", under the directory with the same name.
 
 e.g.
 ```bash
 python main.py pretrain=base
 ```
+ Run pretraining with the config specified in `conf/pretrain/base.yaml`.
 
+ You can override config in the command line;
+ see the Hydra documentation (https://hydra.cc/docs/intro/). E.g.
 ```bash
 python main.py finetune=base finetune.rng_seeding.seed=10
 ```
+ Run finetuning with the config specified in `conf/finetune/base.yaml`, and set the RNG seed to 10.
 
 
 `extra` config is special: the function specified in its `target` field will be loaded,
+ and the config will be passed to that function. This is a quick and dirty way to add experiments that do not fit well into the established workflow.
+
+ An example `extra` config:
+ ```yaml
+ # Specify the script and function to load
+ target:
+   _target_: src.util.dynamic_load
+   item: src.extra.<script name>.<function name>
+
+ # Everything will be passed to the specified function
+ # Including the "target" field above too
+ config1: configcontent
+ config2:
+   - 1
+   - 2
+ # ...
+ ```
 
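For reference, a dotted-path loader in the spirit of `src.util.dynamic_load` can be sketched as follows. This is a hypothetical sketch of the idea, not the repo's actual implementation or signature:

```python
import importlib

# Hypothetical sketch of a dotted-path loader like `src.util.dynamic_load`:
# resolve "package.module.attr" to the attribute itself, so configs can
# name a function and Hydra can instantiate/call it.
def dynamic_load(item: str):
    module_path, _, attr = item.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

sqrt = dynamic_load("math.sqrt")  # resolves the stdlib function
```

Hydra's `hydra.utils.instantiate` then calls the object named by `_target_`, passing the remaining config fields as arguments.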
 ### Experiments:
 **Preprocessing:**
 
+ We follow the general preprocessing logic of LaBraM: https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py
+
+ To produce the single-channel EEG signals for diffusion model pretraining, run:
 ```bash
 python main.py preprocessing=pretrain
 ```
 
+ To produce signals for finetuning, run:
 ```bash
 python main.py preprocessing=faithful
 ```
 
 ```
 Where `?` is `base`, `linear` or `nolaw`.
 
+ `base` uses a cosine noise scheduler and performs mu-law based extreme value suppression. `linear` uses a linear noise scheduler, and `nolaw` does not perform value suppression.
 
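The extreme value suppression can be pictured as standard mu-law companding; a minimal sketch, assuming inputs normalized to [-1, 1] (the `mu` value and function names are assumptions, not the repo's implementation):

```python
import numpy as np

# Mu-law companding as a sketch of extreme value suppression: large
# amplitudes are compressed logarithmically while small ones stay roughly
# linear. mu=255 and the function names are assumptions for illustration.
def mu_law_compress(x, mu=255.0):
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)

def mu_law_expand(y, mu=255.0):
    # Inverse transform, recovering the original amplitudes.
    return np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu

x = np.array([-1.0, -0.1, 0.0, 0.1, 1.0])
y = mu_law_compress(x)  # extremes pulled toward the bulk of the data
```

The transform is invertible, so suppression during training does not discard amplitude information.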
 **Caching:**
 
+ If noise injection is disabled, the latent tokens can be cached to avoid repeated computation. This speeds up finetuning and reduces memory usage significantly.
 
+ The test data are untouched during caching: the model can handle both cached and uncached data.
 
+ See `conf/cache` for available options. Note that the cached TUEV takes 94 GB, and CHB-MIT 480 GB.
 ```bash
 python main.py cache=base
 ```
 
 **Fine-tuning:**
 
+ <!-- Use `finetune.data_is_cached=<boolean>` to tell -->
+
+ If data is cached, the code will check the metadata to ensure that it is consistent with the model hyperparameters.
 
 See `conf/finetune` for available options.
 
+ In our experiment, `finetune.rng_seeding.seed` is set to 0, 1, 2, 3, and 4 to produce 5 checkpoints.
 
 ```bash
 python main.py finetune=base finetune.rng_seeding.seed=0
 
 If testing data cannot be distributed evenly across devices, certain data will be duplicated and cause inaccuracy in the reported metrics. Using `report` will avoid this issue.
 
+ `report` also calculates the mean and standard deviation of metrics across multiple checkpoints.
 ```bash
 python main.py report=base
 ```
 
+ **Other Ablation**
 
 Scripts of certain ablation experiments are put in `src/extra`:
 ```bash
+ python main.py extra=reduce_sampling extra.rate=0.95  # 200 Hz (original sampling rate) * 0.95 = 190 Hz
 python main.py extra=no_fusion extra.rng_seeding.seed=0
 python main.py extra=report_no_fusion
 python main.py extra=mean_fusion extra.rng_seeding.seed=0
 ```
 All seeds need to be iterated from 0 to 4.
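Since each run handles one seed, iterating seeds 0 to 4 can be scripted with a small driver. This is a convenience sketch, not part of the repo; it prints the commands in dry-run mode and can launch them via `subprocess` if desired:

```python
import subprocess

# Convenience sketch (not part of the repo): build the five finetuning
# commands, one checkpoint per RNG seed. With dry_run=True the commands
# are only printed; set dry_run=False to actually launch them.
def finetune_all_seeds(config="finetune=base", seeds=range(5), dry_run=True):
    cmds = [["python", "main.py", config, f"finetune.rng_seeding.seed={s}"]
            for s in seeds]
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return cmds

cmds = finetune_all_seeds()
```

Hydra's multirun mode (`python main.py -m finetune=base finetune.rng_seeding.seed=0,1,2,3,4`) may be an alternative, if it fits the workflow.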
 
+ **CHB-MIT**
+
+ Using the `backbone.ckpt` pretrained on TUEV, the following commands cache and finetune EEGDM on CHB-MIT, then report the results:
+ ```bash
+ python main.py cache=base_chbmit
+ python main.py finetune=base_chbmit_bin_filt finetune.rng_seeding.seed=0
+ python main.py report=base_chbmit_bin
+ ```
+ All seeds need to be iterated from 0 to 4.
+
+ ## 🔬 Reproducibility
+ PyTorch does not guarantee reproducibility across different environments: https://docs.pytorch.org/docs/stable/notes/randomness.html
+
+ Regardless, we released the checkpoints trained in our environment on HuggingFace:
+ * `backbone.ckpt`: the single-channel diffusion model trained on the TUEV training set, RNG seed 0. This checkpoint allows you to skip `pretrain`, and it is not required to run `report`.
+
+ * `classifier.ckpt`: the model finetuned on TUEV for EEG event classification, RNG seed 0. This model can be used directly in `report`:
+
+ ```bash
+ python main.py report=base report.checkpoint=["<path to the downloaded checkpoint>"]
+ ```
+
+ * `chbmit_classifier.ckpt`: the model finetuned on the CHB-MIT dataset, using the `backbone.ckpt` pretrained on TUEV, RNG seed 0. This model can be used directly in `report`:
+
+ ```bash
+ python main.py report=base_chbmit_bin report.checkpoint=["<path to the downloaded checkpoint>"]
+ ```
+
+ <!-- ## Repo Structure
+ `main.py` is the entry point of this repo.
+
+ `src/` contains the scripts for generic
+ `src/extra/` contains the scripts of extra...
+
+ `model/`
+
+ `dataloader/`
+
+ `conf/`
+
+ Finally, `assets` contains images used in this README file.
+
+ During pretraining and finetuning, the scripts may create new directories:
+ * `data/`: training, validation, and testing data; if `cache` is used, cached latents will be put under `data/cached`
+
+ * `gen/`: EEG signal samples generated by SSMDP
+
+ * `checkpoint/`: model checkpoints of SSMDP and LFT.
+
+ Others are logs created by dependencies (`lightning_logs` by PyTorch Lightning, `outputs` by Hydra, etc.). -->
+
 ## ℹ️ Unused Code
+ This repo is still under active development and contains several pieces of unused/untested code. Any functionality implied by the code but not mentioned in the paper shall be considered experimental. Documentation about it (if any) might be outdated or unreliable.
+
+ In particular, the layerwise learning rate and weight decay for LFT will not work. It is best to leave `lrd_kwargs` untouched, or set it to `null`.
+
+ ## 🗺️ Roadmap
+ Current aim: clean up the mess by Dec 2025
+ * Proper documentation of class parameters and available options; add user-friendly error messages
+ * Refactor `model.classifier.MHAStack`: it makes calculating the depth of a layer unnecessarily complicated, hindering the implementation of layerwise learning rate decay
+ * Clean up config files: most files are copy-pasted from the respective `base.yaml` with only one or two lines changed; there must be a better way
+ * `hydra.utils.instantiate` + `src.util.dynamic_load`: horrible
+ * Rename classes: `dataloader.TUEVDataset` is used for other datasets as well, `model.classifier` should be `LatentFusionTransformer`, etc.
+ * Optimize the code: parallelize `cache` and `report`, optimize checkpoint size, check `TODO`s in the code...
+ * Remove `preprocessing` from the workflow; it should be a directory of standalone scripts, as in other FM repos
+ * Remove unused code
 
 ## 📖 Citation
 
 If you use this work, please cite:
 
+ ```
 @misc{puah2025eegdm,
 title={{EEGDM: EEG Representation Learning via Generative Diffusion Model}},
 author={Jia Hong Puah and Sim Kuan Goh and Ziwei Zhang and Zixuan Ye and Chow Khuen Chan and Kheng Seang Lim and Si Lei Fong and Kok Sin Woon},
 
 ## 💬 Discussion and Collaboration
 
+ We welcome discussions and collaborations to improve EEGDM. Please feel free to open issues or pull requests on GitHub.