EEGDM is distinguished by three key innovations:

1. **First Application of Diffusion Models for EEG Representation Learning:** This work pioneers the use of diffusion models for extracting EEG signal representations rather than just signal generation and data augmentation, opening up a new research direction in neurological signal processing.
2. **Structured State-Space Model Architecture (SSMDP):** EEGDM introduces a specialized neural architecture based on structured state-space models specifically designed for diffusion pre-training, enabling better capture of the temporal dynamics inherent in EEG signals.
3. **Latent Fusion Transformer for Downstream Tasks:** The framework incorporates a novel latent fusion transformer (LFT) that effectively utilizes the learned diffusion representations for downstream classification tasks like seizure detection, addressing the challenge of translating generative representations to discriminative tasks.

The proposed method addresses critical limitations in current EEG analysis, including the difficulty of learning robust representations due to limited high-quality annotations and high signal variability across subjects and conditions, while potentially offering computational advantages over existing transformer-based EEG foundation models.

## 😮 Highlights

• We presented EEGDM, a diffusion model-based framework for learning EEG signal representations and classifying multi-event EEG, extending diffusion models beyond signal generation and data augmentation.

• We developed structured state-space model diffusion pretraining (SSMDP) to capture the temporal dynamics of EEG signals and trained it via the forward and reverse processes of DDPM for representation learning.

• We proposed LFT to leverage and fuse the latent representations from SSMDP for downstream classification tasks.

• We empirically compared our method with current state-of-the-art approaches on the multi-event dataset TUEV to show its competitiveness and provided a detailed ablation study to analyze its components.
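For context on the SSMDP bullet above: DDPM pretraining corrupts each signal with a closed-form forward (noising) process and trains the network to reverse it. A minimal sketch of the forward step, using an illustrative schedule and signal rather than the repo's actual implementation:

```python
import numpy as np

def forward_diffuse(x0, t, alpha_bar, rng):
    # Closed-form DDPM forward step:
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    eps = rng.standard_normal(x0.shape)
    xt = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return xt, eps  # the denoiser learns to predict eps from (x_t, t)

T = 1000
betas = np.linspace(1e-4, 0.02, T)           # illustrative linear schedule
alpha_bar = np.cumprod(1.0 - betas)          # cumulative signal-retention factor

rng = np.random.default_rng(0)
x0 = np.sin(np.linspace(0.0, 8.0 * np.pi, 256))  # stand-in for one EEG channel
xt, eps = forward_diffuse(x0, t=500, alpha_bar=alpha_bar, rng=rng)
```

As `t` grows, `alpha_bar[t]` shrinks and `x_t` approaches pure noise; representation learning happens because the denoiser must model the signal's temporal structure to undo this.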
## 📈 Main result

EEGDM outperforms various EEG foundation models (FMs) despite being trained on less data with fewer trainable parameters. On top of that, finetuning EEGDM does not update the pretrained parameters, allowing one backbone to serve multiple downstream tasks simultaneously.
<div align="center">
<br>
</div>
## ✂️ Ablation

DDPM is a framework with many moving parts. In this section, we show that our design choices are necessary for improved performance.

<div align="center">
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result2.png" width="566">
</div>
<div align="center">
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result3.png" width="566">
</div>

<br/>

Another ablation shows that the latent activities of every part of the diffusion backbone contain representations useful for classification, and their quality tends to increase as the layers deepen.

<div align="center">
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result4.png" width="566">
</div>
<br/>

The latent fusion module is the largest trainable component of the LFT. Here, we show that it cannot be replaced by non-parameterized methods such as average pooling and flattening.

<div align="center">
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result5.png" width="566">
</div>
<br>

The unique formulation of SSMDP and LFT enables the EEGDM framework to operate at a different sampling rate without retraining, at the cost of degraded performance.

<div align="center">
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/Result6.png" width="566">
</div>
## 🔀 Generalize to CHB-MIT

To verify the robustness of the learned representations in cross-domain generalization, we finetuned the model on a dataset with unseen characteristics.

More specifically, the model pretrained on TUEV (containing sharp waves and artifacts) is finetuned on CHB-MIT for seizure detection.

The results show that EEGDM outperforms other FMs despite having a much smaller pretraining set that lacks variety, indicating high generalizability and robustness.

<div align="center">
<br>
<img src="https://github.com/jhpuah/EEGDM/raw/main/assets/ResultChb.png" width="566">
</div>
## 🧠 Generation Sample

<div align="center">
</div>
* **[2025-07-16]** Initial setup and README update.
* **[2025-08-11]** Main pages and experiment result update.
* **[2025-08-27]** Preprint V2.
* **[2025-10-02]** Update README to match preprint V2.
## ⚙️ Quick Start

First, set up the environment with Conda: https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html

```bash
conda create -n eegdm python=3.11
```

Then, install dependencies:

```bash
pip install -r requirements.txt
```
The `requirements.txt` file is exported directly from our working environment (NVIDIA GeForce RTX 4090, CUDA version 12.4). If your hardware is incompatible, do the following instead:

1. Install torch following the official guide: https://pytorch.org/get-started/locally/
2. Run:
```bash
pip install numpy==1.26.4 hydra-core mne lightning pyhealth ema-pytorch diffusers einops wandb scipy
```

We use Weights & Biases (https://wandb.ai/site/) for logging, and you will need an account for that. Alternatively, replace instances of `WandbLogger` with your own logger; see the PyTorch Lightning documentation for available options: https://lightning.ai/docs/pytorch/stable/extensions/logging.html
### Usage Examples:

```bash
python main.py [preprocessing=?] [pretrain=?] [cache=?] [finetune=?] [report=?] [extra=?]
```
Replace "?" with the config file name (without extension). The file must be put inside "conf", under the directory with the same name.
e.g.
```bash
python main.py pretrain=base
```
Run pretraining with the config specified in `conf/pretrain/base.yaml`.

You can override config options on the command line; see the Hydra documentation (https://hydra.cc/docs/intro/). E.g.
```bash
python main.py finetune=base finetune.rng_seeding.seed=10
```
Run finetuning with the config specified in `conf/finetune/base.yaml`, and set the RNG seed to 10.
The `extra` config is special: the function specified in its `target` field will be loaded, and the config will be passed to that function. This is a quick and dirty way to add experiments that do not fit well into the established workflow.

An example `extra` config:
```yaml
# Specify the script and function to load
target:
  _target_: src.util.dynamic_load
  item: src.extra.<script name>.<function name>

# Everything will be passed to the specified function,
# including the "target" field above
config1: configcontent
config2:
  - 1
  - 2
# ...
```
### Experiments:
**Preprocessing:**

We follow the general preprocessing logic of LaBraM: https://github.com/935963004/LaBraM/blob/main/dataset_maker/make_TUEV.py

To produce the single-channel EEG signals for diffusion model pretraining, run:
```bash
python main.py preprocessing=pretrain
```

To produce signals for finetuning, run:
```bash
python main.py preprocessing=faithful
```
**Pretraining:**
```bash
python main.py pretrain=?
```
Where `?` is `base`, `linear` or `nolaw`.

`base` uses a cosine noise scheduler and performs mu-law-based extreme value suppression. `linear` uses a linear noise scheduler, and `nolaw` does not perform value suppression.
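For intuition, here are minimal sketches of the two ingredients named above; the constants (`mu`, `s`) are illustrative and may differ from the values used in this repo:

```python
import numpy as np

def mu_law(x, mu=255.0):
    # Mu-law companding: suppresses extreme amplitudes of a signal in [-1, 1]
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)

def inverse_mu_law(y, mu=255.0):
    # Exact inverse of the compression above
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu

def cosine_alpha_bar(T, s=0.008):
    # Cosine noise schedule (Nichol & Dhariwal, 2021): alpha_bar for t = 0..T
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1.0 + s) * np.pi / 2.0) ** 2
    return f / f[0]

x = np.tanh(np.random.default_rng(0).standard_normal(1024))  # normalized signal
roundtrip = inverse_mu_law(mu_law(x))
```

Mu-law compression is invertible, so the suppression can be undone after generation; the cosine schedule decays `alpha_bar` more gently than a linear one early in the trajectory.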
**Caching:**

If noise injection is disabled, the latent tokens can be cached to avoid repeated computation. This speeds up finetuning and reduces memory usage significantly.

The test data are untouched during caching: the model can handle both cached and uncached data.

See `conf/cache` for available options. Note that the size of the cached TUEV is 94 GB, and 480 GB for CHB-MIT.
```bash
python main.py cache=base
```
**Fine-tuning:**

<!-- Use `finetune.data_is_cached=<boolean>` to tell -->

If data is cached, the code will check the metadata to ensure that it is consistent with the model hyperparameters.
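In spirit, that check amounts to comparing the hyperparameters recorded at caching time against the current model config; a hypothetical sketch (names are illustrative, not the repo's actual code):

```python
def check_cache_metadata(cached: dict, current: dict) -> None:
    """Refuse to finetune on latents produced under different hyperparameters."""
    mismatched = {
        key: (cached[key], current[key])
        for key in cached.keys() & current.keys()
        if cached[key] != current[key]
    }
    if mismatched:
        raise ValueError(f"Cached latents are stale, re-run `cache`: {mismatched}")

# Matching metadata passes silently; any differing key raises.
check_cache_metadata({"d_model": 256, "n_layers": 8},
                     {"d_model": 256, "n_layers": 8})
```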
See `conf/finetune` for available options.

In our experiment, `finetune.rng_seeding.seed` is set to 0, 1, 2, 3, and 4 to produce 5 checkpoints.

```bash
python main.py finetune=base finetune.rng_seeding.seed=0
```

**Reporting:**

If testing data cannot be distributed evenly across devices, certain data will be duplicated and cause inaccuracy in the reported metrics. Using `report` avoids this issue.
`report` also calculates the mean and standard deviation of metrics across multiple checkpoints.
```bash
python main.py report=base
```
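The aggregation `report` performs is the usual mean ± sample standard deviation over the per-seed checkpoints; e.g. with placeholder scores (NOT results from the paper):

```python
import numpy as np

# Placeholder per-seed scores, not actual results
scores = np.array([0.80, 0.82, 0.81, 0.83, 0.79])
mean, std = scores.mean(), scores.std(ddof=1)  # ddof=1: sample std over seeds
print(f"{mean:.3f} \u00b1 {std:.3f}")
```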
**Other Ablation**

Scripts for certain ablation experiments are put in `src/extra`:
```bash
python main.py extra=reduce_sampling extra.rate=0.95 # 200 Hz (original sampling rate) * 0.95 = 190 Hz
python main.py extra=no_fusion extra.rng_seeding.seed=0
python main.py extra=report_no_fusion
python main.py extra=mean_fusion extra.rng_seeding.seed=0
python main.py extra=report_mean_fusion
```
All seeds need to be iterated from 0 to 4.
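A small helper for the five-seed sweep, shown with `echo` so the commands can be inspected first (pipe to `sh`, or drop the `echo`, to actually run them):

```shell
for seed in 0 1 2 3 4; do
  echo "python main.py extra=no_fusion extra.rng_seeding.seed=$seed"
done
echo "python main.py extra=report_no_fusion"
```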
**CHB-MIT**

Using the `backbone.ckpt` pretrained on TUEV, the following commands cache and finetune EEGDM on CHB-MIT, then report the result:
```bash
python main.py cache=base_chbmit
python main.py finetune=base_chbmit_bin_filt finetune.rng_seeding.seed=0
python main.py report=base_chbmit_bin
```
All seeds need to be iterated from 0 to 4.
## 🔬 Reproducibility

PyTorch does not guarantee reproducibility across different environments: https://docs.pytorch.org/docs/stable/notes/randomness.html

Regardless, we released the checkpoints trained in our environment on HuggingFace:

* `backbone.ckpt`: The single-channel diffusion model trained on the TUEV training set, RNG seed 0. This checkpoint allows you to skip `pretrain`, and it is not required to run `report`.

* `classifier.ckpt`: The finetuned model on TUEV for EEG event classification, RNG seed 0. This model can be used directly in `report`:

```bash
python main.py report=base report.checkpoint=["<path to the downloaded checkpoint>"]
```

* `chbmit_classifier.ckpt`: The finetuned model on the CHB-MIT dataset, using the `backbone.ckpt` pretrained on TUEV, RNG seed 0. This model can be used directly in `report`:

```bash
python main.py report=base_chbmit_bin report.checkpoint=["<path to the downloaded checkpoint>"]
```
<!-- ## Repo Structure
`main.py` is the entry point of this repo.

`src/` contains the scripts for generic
`src/extra/` contains the scripts of extra...

`model/`

`dataloader/`

`conf/`

Finally, `assets` contains images used in this README file.

During pretraining and finetuning, the scripts may create new directories:
* `data/`: training, validation, and testing data; if `cache` is used, cached latents will be put under `data/cached`

* `gen/`: EEG signal samples generated by SSMDP

* `checkpoint/`: model checkpoints of SSMDP and LFT.

Others are logs by dependencies (`lightning_logs` by PyTorch Lightning, `outputs` by Hydra, etc.). -->
## ℹ️ Unused Code

This repo is still under active development and has several pieces of unused/untested code. Any functionality implied in the code but not mentioned in the paper should be considered experimental. Documentation about it (if any) may be outdated or unreliable.

In particular, the layerwise learning rate and weight decay for LFT will not work. Best to leave `lrd_kwargs` untouched, or set it to `null`.

## 🗺️ Roadmap

Current aim: clean up the mess by Dec 2025
* Proper documentation of class parameters and available options; add user-friendly error messages
* Refactor `model.classifier.MHAStack`: it makes calculating the depth of a layer unnecessarily complicated, hindering the implementation of layerwise learning rate decay
* Clean up config files: most files are copy-pasted from the respective `base.yaml` with only one or two lines changed; there must be a better way
* `hydra.utils.instantiate` + `src.util.dynamic_load`: horrible
* Rename classes: `dataloader.TUEVDataset` is used for other datasets as well, `model.classifier` should be `LatentFusionTransformer`, etc.
* Optimize the code: parallelize `cache` and `report`, optimize checkpoint size, check `TODO`s in the code...
* Remove `preprocessing` from the workflow; it should be a directory of standalone scripts, as in other FM repos
* Remove unused code
## 📖 Citation

If you use this work, please cite:

```
@misc{puah2025eegdm,
      title={{EEGDM: EEG Representation Learning via Generative Diffusion Model}},
      author={Jia Hong Puah and Sim Kuan Goh and Ziwei Zhang and Zixuan Ye and Chow Khuen Chan and Kheng Seang Lim and Si Lei Fong and Kok Sin Woon},
}
```

This work is inspired by and builds upon various open-source projects and research.

## 💬 Discussion and Collaboration

We welcome discussions and collaborations to improve EEGDM. Please feel free to open issues or pull requests on GitHub.