eugenehp committed on
Commit fd53db3 · verified · 1 Parent(s): ce89add

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +292 -0
README.md ADDED
@@ -0,0 +1,292 @@
+ ---
+ license: mit
+ tags:
+ - eeg
+ - bci
+ - brain-computer-interface
+ - foundation-model
+ - vit
+ - masked-autoencoder
+ - mae
+ - neuroscience
+ - safetensors
+ - burn
+ - rust
+ language:
+ - en
+ library_name: steegformer-rs
+ pipeline_tag: feature-extraction
+ ---
+
+ # ST-EEGFormer: Safetensors Weights
+
+ Pre-converted [safetensors](https://github.com/huggingface/safetensors) weights for the [ST-EEGFormer](https://github.com/LiuyinYang1101/STEEGFormer) EEG foundation model, ready for use with **[steegformer-rs](https://github.com/eugenehp/steegformer-rs)** (pure-Rust inference on [Burn 0.20](https://burn.dev)) or any framework that supports safetensors.
+
+ Weights are converted from the official PyTorch `.pth` checkpoints published at [LiuyinYang1101/STEEGFormer](https://github.com/LiuyinYang1101/STEEGFormer/releases).
+
+ ST-EEGFormer won **1st Place** in the NeurIPS 2025 EEG Foundation Challenge and was accepted at **ICLR 2026**.
+
+ ## Model Files
+
+ ### Encoder Only (for inference / embedding extraction)
+
+ | File | Variant | Params | Size | Layers | Heads | embed_dim |
+ |------|---------|--------|------|--------|-------|-----------|
+ | [`ST-EEGFormer_small_encoder.safetensors`](ST-EEGFormer_small_encoder.safetensors) | **Small** | 25.6 M | 102 MB | 8 | 8 | 512 |
+ | [`ST-EEGFormer_base_encoder.safetensors`](ST-EEGFormer_base_encoder.safetensors) | **Base** | 85.6 M | 342 MB | 12 | 12 | 768 |
+ | [`ST-EEGFormer_large_encoder.safetensors`](ST-EEGFormer_large_encoder.safetensors) | **Large** | 303.0 M | 1,212 MB | 24 | 16 | 1024 |
+ | [`ST-EEGFormer_largeV2_encoder.safetensors`](ST-EEGFormer_largeV2_encoder.safetensors) | **Large V2** | 303.1 M | 1,212 MB | 24 | 16 | 1024 |
+
+ ### Full MAE (encoder + decoder, for reconstruction / fine-tuning)
+
+ | File | Variant | Params | Size | Decoder dim | Decoder depth |
+ |------|---------|--------|------|-------------|---------------|
+ | [`ST-EEGFormer_small_mae.safetensors`](ST-EEGFormer_small_mae.safetensors) | **Small** | 33.1 M | 132 MB | 384 | 4 |
+ | [`ST-EEGFormer_base_mae.safetensors`](ST-EEGFormer_base_mae.safetensors) | **Base** | 111.5 M | 446 MB | 512 | 8 |
+ | [`ST-EEGFormer_large_mae.safetensors`](ST-EEGFormer_large_mae.safetensors) | **Large** | 329.1 M | 1,316 MB | 512 | 8 |
+ | [`ST-EEGFormer_largeV2_mae.safetensors`](ST-EEGFormer_largeV2_mae.safetensors) | **Large V2** | 329.3 M | 1,317 MB | 512 | 8 |
+
+ ### Config
+
+ | File | Description |
+ |------|-------------|
+ | [`config.json`](config.json) | Model hyperparameters for all variants |
+
+ > **Large V2** has undergone further pre-training on the HBN dataset for the NeurIPS 2025 EEG Foundation Challenge.
+
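The Size column in the tables above is consistent with every parameter being stored as a 4-byte float32 value (the conversion snippet in this README casts tensors with `.float()`). A minimal sketch of that arithmetic follows; `size_mb` is a hypothetical helper, and MB is taken to mean 10^6 bytes, which is an assumption about how the sizes were reported.

```python
# Rough size check: parameters (in millions) x 4 bytes per float32 value.
# `size_mb` is a hypothetical helper, not part of any library.
BYTES_PER_FLOAT32 = 4


def size_mb(params_millions: float) -> float:
    """Approximate file size in MB (10^6 bytes), ignoring the small header."""
    return params_millions * BYTES_PER_FLOAT32


# Parameter counts copied from the encoder table above.
encoder_params = {"small": 25.6, "base": 85.6, "large": 303.0, "largeV2": 303.1}
for name, millions in encoder_params.items():
    print(f"{name}: ~{size_mb(millions):.0f} MB")
```

The same rule reproduces the MAE sizes, for example 33.1 M parameters giving roughly 132 MB.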
+ ## Quick Start: Rust
+
+ ```bash
+ # Add the crate to an existing Cargo project
+ cargo add steegformer-rs
+
+ # Download weights
+ huggingface-cli download eugenehp/ST-EEGFormer \
+   ST-EEGFormer_small_encoder.safetensors \
+   config.json \
+   --local-dir weights/
+
+ # Run inference (assumes the current package provides an `infer` binary
+ # target, e.g. a checkout of the steegformer-rs repository)
+ cargo run --release --bin infer -- \
+   --config weights/config.json \
+   --weights weights/ST-EEGFormer_small_encoder.safetensors
+ ```
+
+ ### Library API
+
+ ```rust
+ use steegformer_rs::{STEEGFormerEncoder, ModelConfig, data};
+ use std::path::Path;
+
+ // `B` (a Burn backend type) and `device` (its device) are not defined in
+ // this snippet; the caller must bring both into scope.
+
+ // Load model
+ let cfg = ModelConfig::small();
+ let (encoder, _ms) = STEEGFormerEncoder::<B>::load_from_config(
+     cfg,
+     Path::new("ST-EEGFormer_small_encoder.safetensors"),
+     device,
+ )?;
+
+ // Build input: 4 channels × 6 seconds @ 128 Hz (768 samples per channel)
+ let channels = &["Fz", "C3", "C4", "Pz"];
+ let signal = vec![0.0f32; channels.len() * 768];
+ let batch = data::build_batch_named::<B>(signal, channels, 768, &device);
+
+ // Extract embeddings
+ let result = encoder.run_batch(&batch)?;
+ println!("Embedding shape: {:?}", result.shape); // [512], the Small embed_dim
+ ```
+
+ ## Quick Start: Python
+
+ ```python
+ from safetensors.torch import load_file
+
+ # `models_mae_eeg` is not a pip package; it is assumed to be importable from
+ # the original ST-EEGFormer code base.
+ from models_mae_eeg import mae_vit_small_patch16
+
+ # Load encoder weights
+ state_dict = load_file("ST-EEGFormer_small_encoder.safetensors")
+
+ # Build the model and load the weights. strict=False tolerates model
+ # parameters (e.g. decoder weights) that are absent from the encoder-only file.
+ model = mae_vit_small_patch16()
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ ```
+
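To see which tensors a downloaded file contains without installing PyTorch, the safetensors header can be read directly: the file starts with an 8-byte little-endian length, followed by a JSON header of that length. The sketch below uses only the standard library; `read_safetensors_header` is a hypothetical helper name, not part of the `safetensors` package.

```python
import json
import struct
import sys


def read_safetensors_header(path):
    """Return {tensor_name: (dtype, shape)} without loading any tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return {name: (info["dtype"], tuple(info["shape"]))
            for name, info in header.items()}


# Usage: python inspect_header.py ST-EEGFormer_small_encoder.safetensors
if __name__ == "__main__" and len(sys.argv) > 1:
    for name, (dtype, shape) in sorted(read_safetensors_header(sys.argv[1]).items()):
        print(f"{name:50s} {dtype} {shape}")
```

The printed names and shapes can then be compared against the key listing in the Weight Key Format section.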
+ ## Architecture
+
+ ```
+ EEG signal (B, C, T): up to 142 channels, 128 Hz, ≤ 6 s
+         │
+         ▼
+ ┌──────────────────────────────────────┐
+ │  PatchEmbedEEG                       │
+ │  Unfold → 16-sample patches          │
+ │  Linear(16, embed_dim)               │
+ │  → (B, num_patches × C, D)           │
+ └──────────────────────────────────────┘
+         │
+    + Sinusoidal Temporal PE (fixed)
+    + Learned Channel Embedding (nn.Embedding(145, D))
+         │
+         ▼
+ ┌──────────────────────────────────────┐
+ │  [CLS] token prepend                 │
+ └──────────────────────────────────────┘
+         │
+         ▼
+ ┌──────────────────────────────────────┐
+ │  N × Transformer Encoder Block       │
+ │  Pre-norm: LN → MHSA → residual      │
+ │            LN → FFN → residual       │
+ │  (qkv_bias=True, GELU activation)    │
+ └──────────────────────────────────────┘
+         │
+         ▼
+ ┌──────────────────────────────────────┐
+ │  LayerNorm → CLS token               │
+ │  → (B, embed_dim) embedding          │
+ └──────────────────────────────────────┘
+ ```
+
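The diagram implies a simple rule for the sequence length the transformer sees: one token per (channel, patch) pair, plus the [CLS] token. The sketch below assumes non-overlapping patches (stride equal to the 16-sample patch size), which the diagram does not state explicitly; `encoder_sequence_length` is a hypothetical helper.

```python
PATCH_SIZE = 16  # samples per patch, from the diagram above


def encoder_sequence_length(num_channels: int, num_samples: int) -> int:
    """Tokens seen by the transformer: channels x patches, plus [CLS]."""
    # Assumes stride == patch size; trailing samples that do not fill a
    # whole patch are ignored in this sketch.
    num_patches = num_samples // PATCH_SIZE
    return num_channels * num_patches + 1  # +1 for the [CLS] token


for channels in (4, 22, 64):
    print(channels, encoder_sequence_length(channels, 768))
# 4 193
# 22 1057
# 64 3073
```

Because self-attention cost grows roughly quadratically with sequence length, this 16-fold growth in tokens between 4 and 64 channels is one plausible reason latency rises faster than linearly with channel count.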
+ ### MAE Pre-training (decoder, included in `*_mae.safetensors`)
+
+ ```
+ Encoder output (25% of tokens)
+         │
+         ▼
+ Linear(embed_dim → decoder_dim)
+    + Insert mask tokens at masked positions
+    + Decoder temporal/channel PE
+         │
+         ▼
+ M × Decoder Transformer Blocks
+         │
+         ▼
+ Linear(decoder_dim → patch_size)
+    → Reconstructed EEG patches
+ ```
+
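The masking and un-masking steps in the diagram can be illustrated with plain Python lists. This is a sketch of generic MAE-style random masking, assuming tokens are dropped uniformly at random so that 25% remain visible; the original training code may select tokens differently, and both function names are hypothetical.

```python
import random


def random_masking(num_tokens: int, keep_ratio: float = 0.25, seed: int = 0):
    """Split token indices into a visible set (encoder input) and a masked set."""
    rng = random.Random(seed)
    order = list(range(num_tokens))
    rng.shuffle(order)
    num_keep = int(num_tokens * keep_ratio)
    keep_ids = sorted(order[:num_keep])    # tokens passed to the encoder
    masked_ids = sorted(order[num_keep:])  # positions the decoder must reconstruct
    return keep_ids, masked_ids


def insert_mask_tokens(visible, keep_ids, num_tokens, mask_token):
    """Rebuild the full-length sequence for the decoder."""
    full = [mask_token] * num_tokens
    for token, idx in zip(visible, keep_ids):
        full[idx] = token  # encoded tokens return to their original positions
    return full


keep_ids, masked_ids = random_masking(192)       # e.g. 4 channels x 48 patches
encoded = [f"enc{idx}" for idx in keep_ids]      # stand-in for encoder outputs
decoder_input = insert_mask_tokens(encoded, keep_ids, 192, "[MASK]")
print(len(keep_ids), len(masked_ids), len(decoder_input))
```

In the real model the list entries would be embedding vectors, and the decoder positional encodings would be added after this step.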
+ ## Numerical Parity (Rust vs Python)
+
+ Intermediate outputs of the Rust port were checked stage by stage against the official PyTorch implementation (Small variant, 8 blocks):
+
+ | Stage | RMSE | Pearson r |
+ |---|---|---|
+ | Patch embedding | 0.000000 | 1.000000 |
+ | Channel embedding | 0.000000 | 1.000000 |
+ | Temporal encoding | 0.000000 | 1.000000 |
+ | After positional encoding | 0.000000 | 1.000000 |
+ | After transformer block 0 | 0.000004 | 1.000000 |
+ | **Full encoder (8 blocks)** | **0.000001** | **1.000000** |
+
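For readers who want to run a similar comparison on their own outputs, the two metrics in the table can be computed as below. This is a standard-library sketch of the usual definitions, not the script used to produce the table, and the sample values are made up for illustration.

```python
import math


def rmse(a, b):
    """Root-mean-square error between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def pearson_r(a, b):
    """Pearson correlation coefficient between two equal-length sequences."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Made-up values standing in for flattened Rust and PyTorch outputs.
rust_out = [0.12, -0.53, 0.98, 0.07]
torch_out = [0.12, -0.53, 0.98, 0.07]
print(f"RMSE={rmse(rust_out, torch_out):.6f}  r={pearson_r(rust_out, torch_out):.6f}")
```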
+ ## Benchmarks
+
+ **Platform:** Apple M4 Pro, 64 GB RAM, macOS (arm64)
+
+ ### Inference Latency: ST-EEGFormer-Small (22 ch × 768 samples)
+
+ | Backend | Mean | Min |
+ |---|---|---|
+ | Rust CPU (NdArray + Accelerate) | 608.4 ms | 601.4 ms |
+ | Python CPU (PyTorch 2.6) | 78.1 ms | 77.2 ms |
+ | **Rust GPU (Burn wgpu + Metal)** | **38.1 ms** | **7.9 ms** |
+ | Python MPS (PyTorch + Metal) | 19.2 ms | 19.0 ms |
+
+ ### Channel Scaling (T=768)
+
+ | Channels | Rust CPU | Python CPU | Rust GPU | Python MPS |
+ |---|---|---|---|---|
+ | 4 | 75.5 ms | 21.8 ms | 11.5 ms | 4.0 ms |
+ | 22 | 596.0 ms | 77.9 ms | 32.7 ms | 19.3 ms |
+ | 64 | 3853.2 ms | 301.9 ms | 119.4 ms | 90.1 ms |
+
+ ## Weight Key Format
+
+ ### Encoder keys
+
+ ```
+ patch_embed.proj.weight                          [embed_dim, 16]
+ patch_embed.proj.bias                            [embed_dim]
+ cls_token                                        [1, 1, embed_dim]
+ enc_channel_emd.channel_transformation.weight    [145, embed_dim]
+ enc_temporal_emd.pe                              [1, 512, embed_dim]
+ blocks.{i}.norm1.weight                          [embed_dim]
+ blocks.{i}.norm1.bias                            [embed_dim]
+ blocks.{i}.attn.qkv.weight                       [3*embed_dim, embed_dim]
+ blocks.{i}.attn.qkv.bias                         [3*embed_dim]
+ blocks.{i}.attn.proj.weight                      [embed_dim, embed_dim]
+ blocks.{i}.attn.proj.bias                        [embed_dim]
+ blocks.{i}.norm2.weight                          [embed_dim]
+ blocks.{i}.norm2.bias                            [embed_dim]
+ blocks.{i}.mlp.fc1.weight                        [4*embed_dim, embed_dim]
+ blocks.{i}.mlp.fc1.bias                          [4*embed_dim]
+ blocks.{i}.mlp.fc2.weight                        [embed_dim, 4*embed_dim]
+ blocks.{i}.mlp.fc2.bias                          [embed_dim]
+ norm.weight                                      [embed_dim]
+ norm.bias                                        [embed_dim]
+ ```
+
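Summing the shapes listed above reproduces the parameter counts in the encoder table for Small, Base, and Large. The sketch below does that arithmetic; `encoder_param_count` is a hypothetical helper, and counting the stored `enc_temporal_emd.pe` tensor alongside the learned weights is an assumption that happens to match the reported totals.

```python
def encoder_param_count(embed_dim: int, depth: int, patch_size: int = 16,
                        num_channel_ids: int = 145, max_patches: int = 512,
                        mlp_ratio: int = 4) -> int:
    """Total number of values stored under the encoder keys listed above."""
    d = embed_dim
    hidden = mlp_ratio * d
    stem = (
        d * patch_size + d      # patch_embed.proj weight + bias
        + d                     # cls_token
        + num_channel_ids * d   # enc_channel_emd.channel_transformation.weight
        + max_patches * d       # enc_temporal_emd.pe
    )
    block = (
        2 * d                   # norm1 weight + bias
        + 3 * d * d + 3 * d     # attn.qkv weight + bias
        + d * d + d             # attn.proj weight + bias
        + 2 * d                 # norm2 weight + bias
        + hidden * d + hidden   # mlp.fc1 weight + bias
        + d * hidden + d        # mlp.fc2 weight + bias
    )
    return stem + depth * block + 2 * d  # final norm weight + bias


for name, dim, depth in [("small", 512, 8), ("base", 768, 12), ("large", 1024, 24)]:
    print(f"{name}: {encoder_param_count(dim, depth) / 1e6:.1f} M")
```

Large V2 is listed at 303.1 M, slightly above what this formula gives for Large; the key listing does not explain the difference, so this sketch leaves it out.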
+ ### Decoder keys (MAE only)
+
+ ```
+ decoder_embed.weight                             [dec_dim, embed_dim]
+ decoder_embed.bias                               [dec_dim]
+ mask_token                                       [1, 1, dec_dim]
+ dec_channel_emd.channel_transformation.weight    [145, dec_dim]
+ dec_temporal_emd.pe                              [1, 512, dec_dim]
+ decoder_blocks.{i}.*                             (same structure as encoder)
+ decoder_norm.weight                              [dec_dim]
+ decoder_norm.bias                                [dec_dim]
+ decoder_pred.weight                              [16, dec_dim]
+ decoder_pred.bias                                [16]
+ ```
+
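The decoder shapes account for the gap between the encoder-only and full-MAE parameter counts in the Model Files tables. The sketch below sums them, assuming the decoder blocks use the same 4x MLP ratio as the encoder blocks (the listing only says "same structure"); both helper names are hypothetical.

```python
def block_params(dim: int, mlp_ratio: int = 4) -> int:
    """Values in one transformer block: two LayerNorms, attention, and MLP."""
    hidden = mlp_ratio * dim
    return (
        2 * dim + 2 * dim               # norm1, norm2
        + 3 * dim * dim + 3 * dim       # attn.qkv
        + dim * dim + dim               # attn.proj
        + hidden * dim + hidden         # mlp.fc1
        + dim * hidden + dim            # mlp.fc2
    )


def decoder_param_count(embed_dim: int, dec_dim: int, dec_depth: int,
                        patch_size: int = 16, num_channel_ids: int = 145,
                        max_patches: int = 512) -> int:
    """Total number of values stored under the decoder keys listed above."""
    return (
        dec_dim * embed_dim + dec_dim        # decoder_embed
        + dec_dim                            # mask_token
        + num_channel_ids * dec_dim          # dec_channel_emd
        + max_patches * dec_dim              # dec_temporal_emd.pe
        + dec_depth * block_params(dec_dim)  # decoder_blocks
        + 2 * dec_dim                        # decoder_norm
        + patch_size * dec_dim + patch_size  # decoder_pred
    )


# (embed_dim, dec_dim, dec_depth) taken from the Model Files tables.
for name, cfg in {"small": (512, 384, 4), "base": (768, 512, 8), "large": (1024, 512, 8)}.items():
    print(f"{name}: {decoder_param_count(*cfg) / 1e6:.1f} M decoder parameters")
```

For the Small variant this gives about 7.6 M, in line with the difference between 33.1 M (MAE) and 25.6 M (encoder only) once rounding is taken into account.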
+ ## Conversion
+
+ These weights were converted from the official `.pth` files:
+
+ ```python
+ import torch
+ from safetensors.torch import save_file
+
+ # weights_only=False unpickles arbitrary objects; use it only on trusted checkpoints.
+ ckpt = torch.load("checkpoint.pth", map_location="cpu", weights_only=False)
+ state_dict = ckpt["model"]
+
+ # Encoder only: keep keys with an encoder prefix, cast to contiguous float32
+ encoder = {k: v.float().contiguous() for k, v in state_dict.items()
+            if any(k.startswith(p) for p in
+                   ["patch_embed.", "cls_token", "enc_", "blocks.", "norm."])}
+ save_file(encoder, "encoder.safetensors")
+ ```
+
+ Or use the included conversion script:
+
+ ```bash
+ python scripts/convert_to_safetensors.py --all
+ ```
+
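The prefix filter in the conversion snippet relies on decoder keys starting with `decoder_`, `dec_`, or `mask_token`, so none of them match an encoder prefix. A small standard-library sketch of that check follows, using key names from the Weight Key Format section; `is_encoder_key` is a hypothetical helper.

```python
# Prefixes copied from the conversion snippet above.
ENCODER_PREFIXES = ("patch_embed.", "cls_token", "enc_", "blocks.", "norm.")


def is_encoder_key(key: str) -> bool:
    """True if the key would be kept in the encoder-only file."""
    return key.startswith(ENCODER_PREFIXES)  # str.startswith accepts a tuple


sample_keys = [
    "cls_token",
    "blocks.0.attn.qkv.weight",
    "norm.weight",
    "enc_temporal_emd.pe",
    "mask_token",
    "decoder_blocks.0.norm1.weight",
    "decoder_norm.bias",
    "decoder_pred.weight",
]
encoder_keys = [k for k in sample_keys if is_encoder_key(k)]
print(encoder_keys)
```

Note that `decoder_blocks.` and `decoder_norm.` are rejected because `startswith` matches only at the beginning of the key, not anywhere inside it.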
+ ## Citation
+
+ ```bibtex
+ @inproceedings{yang2026_steegformer,
+   title={Are {EEG} Foundation Models Worth It? Comparative Evaluation
+          with Traditional Decoders in Diverse {BCI} Tasks},
+   author={Liuyin Yang and Qiang Sun and Ang Li and Marc M. Van Hulle},
+   booktitle={The Fourteenth International Conference on Learning Representations},
+   year={2026},
+   url={https://openreview.net/forum?id=5Xwm8e6vbh}
+ }
+ ```
+
+ ## License
+
+ MIT, the same license as the original ST-EEGFormer release.
+
+ ## Links
+
+ | | |
+ |---|---|
+ | **Rust crate** | [github.com/eugenehp/steegformer-rs](https://github.com/eugenehp/steegformer-rs) |
+ | **Original code** | [github.com/LiuyinYang1101/STEEGFormer](https://github.com/LiuyinYang1101/STEEGFormer) |
+ | **Original weights** | [GitHub Releases](https://github.com/LiuyinYang1101/STEEGFormer/releases) |
+ | **Paper** | [OpenReview (ICLR 2026)](https://openreview.net/forum?id=5Xwm8e6vbh) |
+ | **Burn framework** | [burn.dev](https://burn.dev) |