---
license: mit
tags:
- eeg
- bci
- brain-computer-interface
- foundation-model
- vit
- masked-autoencoder
- mae
- neuroscience
- safetensors
- burn
- rust
language:
- en
library_name: steegformer-rs
pipeline_tag: feature-extraction
---

# ST-EEGFormer — Safetensors Weights

Pre-converted [safetensors](https://github.com/huggingface/safetensors) weights for the [ST-EEGFormer](https://github.com/LiuyinYang1101/STEEGFormer) EEG foundation model. They are ready for use with **[steegformer-rs](https://github.com/eugenehp/steegformer-rs)** (pure-Rust inference on [Burn 0.20](https://burn.dev)) or with any framework that supports safetensors.

The weights were converted from the official PyTorch `.pth` checkpoints published at [LiuyinYang1101/STEEGFormer](https://github.com/LiuyinYang1101/STEEGFormer/releases).

ST-EEGFormer won **1st Place** in the NeurIPS 2025 EEG Foundation Challenge and was accepted at **ICLR 2026**.

## Model Files

### Encoder Only (for inference / embedding extraction)

| File | Variant | Params | Size | Layers | Heads | embed_dim |
|------|---------|--------|------|--------|-------|-----------|
| [`ST-EEGFormer_small_encoder.safetensors`](ST-EEGFormer_small_encoder.safetensors) | **Small** | 25.6 M | 102 MB | 8 | 8 | 512 |
| [`ST-EEGFormer_base_encoder.safetensors`](ST-EEGFormer_base_encoder.safetensors) | **Base** | 85.6 M | 342 MB | 12 | 12 | 768 |
| [`ST-EEGFormer_large_encoder.safetensors`](ST-EEGFormer_large_encoder.safetensors) | **Large** | 303.0 M | 1,212 MB | 24 | 16 | 1024 |
| [`ST-EEGFormer_largeV2_encoder.safetensors`](ST-EEGFormer_largeV2_encoder.safetensors) | **Large V2** | 303.1 M | 1,212 MB | 24 | 16 | 1024 |

### Full MAE (encoder + decoder, for reconstruction / fine-tuning)

| File | Variant | Params | Size | Decoder dim | Decoder depth |
|------|---------|--------|------|-------------|---------------|
| [`ST-EEGFormer_small_mae.safetensors`](ST-EEGFormer_small_mae.safetensors) | **Small** | 33.1 M | 132 MB | 384 | 4 |
| [`ST-EEGFormer_base_mae.safetensors`](ST-EEGFormer_base_mae.safetensors) | **Base** | 111.5 M | 446 MB | 512 | 8 |
| [`ST-EEGFormer_large_mae.safetensors`](ST-EEGFormer_large_mae.safetensors) | **Large** | 329.1 M | 1,316 MB | 512 | 8 |
| [`ST-EEGFormer_largeV2_mae.safetensors`](ST-EEGFormer_largeV2_mae.safetensors) | **Large V2** | 329.3 M | 1,317 MB | 512 | 8 |

### Config

| File | Description |
|------|-------------|
| [`config.json`](config.json) | Model hyperparameters for all variants |

> **Large V2** was further pre-trained on the HBN dataset for the NeurIPS 2025 EEG Foundation Challenge.

## Quick Start — Rust

```bash
# Install
cargo add steegformer-rs

# Download weights
huggingface-cli download eugenehp/ST-EEGFormer \
  ST-EEGFormer_small_encoder.safetensors \
  config.json \
  --local-dir weights/

# Run inference (from a clone of the steegformer-rs repository,
# since `cargo run --bin` only runs binaries of the current workspace)
cargo run --release --bin infer -- \
  --config weights/config.json \
  --weights weights/ST-EEGFormer_small_encoder.safetensors
```

### Library API

```rust
use steegformer_rs::{STEEGFormerEncoder, ModelConfig, data};
use std::path::Path;

// `B` is the Burn backend type and `device` its device (e.g. NdArray on CPU
// or wgpu on GPU); both are assumed to be set up by the caller.

// Load model
let cfg = ModelConfig::small();
let (encoder, _ms) = STEEGFormerEncoder::<B>::load_from_config(
    cfg,
    Path::new("ST-EEGFormer_small_encoder.safetensors"),
    device,
)?;

// Build input: 4 channels × 6 seconds @ 128 Hz = 768 samples per channel
let channels = &["Fz", "C3", "C4", "Pz"];
let signal = vec![0.0f32; channels.len() * 768];
let batch = data::build_batch_named::<B>(signal, channels, 768, &device);

// Extract embeddings
let result = encoder.run_batch(&batch)?;
println!("Embedding shape: {:?}", result.shape); // [512]
```

## Quick Start — Python

```python
from safetensors.torch import load_file

# Load encoder weights
state_dict = load_file("ST-EEGFormer_small_encoder.safetensors")

# Build the MAE model from the original ST-EEGFormer code (models_mae_eeg) and
# load the encoder weights into it. strict=False is needed because the encoder
# file has no decoder parameters; those stay randomly initialized.
from models_mae_eeg import mae_vit_small_patch16
model = mae_vit_small_patch16()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("unexpected keys:", unexpected)  # expected to be empty
model.eval()
```

## Architecture

```
EEG signal (B, C, T) — up to 142 channels, 128 Hz, ≤ 6s
        │
        ▼
┌──────────────────────────────────────┐
│ PatchEmbedEEG                        │
│   Unfold → 16-sample patches         │
│   Linear(16, embed_dim)              │
│   → (B, num_patches × C, D)          │
└──────────────────────────────────────┘
        │
  + Sinusoidal Temporal PE (fixed)
  + Learned Channel Embedding (nn.Embedding(145, D))
        │
        ▼
┌──────────────────────────────────────┐
│ [CLS] token prepend                  │
└──────────────────────────────────────┘
        │
        ▼
┌──────────────────────────────────────┐
│ N × Transformer Encoder Block        │
│   Pre-norm: LN → MHSA → residual     │
│             LN → FFN  → residual     │
│   (qkv_bias=True, GELU activation)   │
└──────────────────────────────────────┘
        │
        ▼
┌──────────────────────────────────────┐
│ LayerNorm → CLS token                │
│   → (B, embed_dim) embedding         │
└──────────────────────────────────────┘
```

### MAE Pre-training (decoder, included in `*_mae.safetensors`)

```
Encoder output (visible 25% of tokens)
        │
        ▼
Linear(embed_dim → decoder_dim)
  + Insert mask tokens at masked positions
  + Decoder temporal/channel PE
        │
        ▼
M × Decoder Transformer Blocks
        │
        ▼
Linear(decoder_dim → patch_size)
  → Reconstructed EEG patches
```

## Numerical Parity (Rust vs Python)

Compared against the official PyTorch implementation at each stage of the forward pass (Small variant):

| Stage | RMSE | Pearson r |
|---|---|---|
| Patch embedding | 0.000000 | 1.000000 |
| Channel embedding | 0.000000 | 1.000000 |
| Temporal encoding | 0.000000 | 1.000000 |
| After positional encoding | 0.000000 | 1.000000 |
| After transformer block 0 | 0.000004 | 1.000000 |
| **Full encoder (8 blocks)** | **0.000001** | **1.000000** |

## Benchmarks

**Platform:** Apple M4 Pro, 64 GB RAM, macOS (arm64)

### Inference Latency — ST-EEGFormer-Small (22 channels × 768 samples)

| Backend | Mean | Min |
|---|---|---|
| Rust CPU (NdArray + Accelerate) | 608.4 ms | 601.4 ms |
| Python CPU (PyTorch 2.6) | 78.1 ms | 77.2 ms |
| **Rust GPU (Burn wgpu + Metal)** | **38.1 ms** | **7.9 ms** |
| Python MPS (PyTorch + Metal) | 19.2 ms | 19.0 ms |

### Channel Scaling (T = 768 samples)

| Channels | Rust CPU | Python CPU | Rust GPU | Python MPS |
|---|---|---|---|---|
| 4 | 75.5 ms | 21.8 ms | 11.5 ms | 4.0 ms |
| 22 | 596.0 ms | 77.9 ms | 32.7 ms | 19.3 ms |
| 64 | 3853.2 ms | 301.9 ms | 119.4 ms | 90.1 ms |

## Weight Key Format

### Encoder keys

```
patch_embed.proj.weight                          [embed_dim, 16]
patch_embed.proj.bias                            [embed_dim]
cls_token                                        [1, 1, embed_dim]
enc_channel_emd.channel_transformation.weight    [145, embed_dim]
enc_temporal_emd.pe                              [1, 512, embed_dim]
blocks.{i}.norm1.weight                          [embed_dim]
blocks.{i}.norm1.bias                            [embed_dim]
blocks.{i}.attn.qkv.weight                       [3*embed_dim, embed_dim]
blocks.{i}.attn.qkv.bias                         [3*embed_dim]
blocks.{i}.attn.proj.weight                      [embed_dim, embed_dim]
blocks.{i}.attn.proj.bias                        [embed_dim]
blocks.{i}.norm2.weight                          [embed_dim]
blocks.{i}.norm2.bias                            [embed_dim]
blocks.{i}.mlp.fc1.weight                        [4*embed_dim, embed_dim]
blocks.{i}.mlp.fc1.bias                          [4*embed_dim]
blocks.{i}.mlp.fc2.weight                        [embed_dim, 4*embed_dim]
blocks.{i}.mlp.fc2.bias                          [embed_dim]
norm.weight                                      [embed_dim]
norm.bias                                        [embed_dim]
```

### Decoder keys (MAE only)

```
decoder_embed.weight                             [dec_dim, embed_dim]
decoder_embed.bias                               [dec_dim]
mask_token                                       [1, 1, dec_dim]
dec_channel_emd.channel_transformation.weight    [145, dec_dim]
dec_temporal_emd.pe                              [1, 512, dec_dim]
decoder_blocks.{i}.*                             (same structure as the encoder blocks)
decoder_norm.weight                              [dec_dim]
decoder_norm.bias                                [dec_dim]
decoder_pred.weight                              [16, dec_dim]
decoder_pred.bias                                [16]
```

## Conversion

These weights were converted from the official `.pth` files:

```python
import torch
from safetensors.torch import save_file

# weights_only=False unpickles arbitrary objects stored in the checkpoint;
# only do this for trusted files such as the official releases.
ckpt = torch.load("checkpoint.pth", map_location="cpu", weights_only=False)
state_dict = ckpt["model"]

# Encoder only: keep the encoder prefixes, drop decoder_* and mask_token.
# Tensors are cast to float32 and made contiguous (save_file requires contiguous tensors).
ENCODER_PREFIXES = ("patch_embed.", "cls_token", "enc_", "blocks.", "norm.")
encoder = {k: v.float().contiguous() for k, v in state_dict.items()
           if k.startswith(ENCODER_PREFIXES)}
save_file(encoder, "encoder.safetensors")
```

Or use the included conversion script:

```bash
python scripts/convert_to_safetensors.py --all
```

## Citation

```bibtex
@inproceedings{yang2026_steegformer,
  title={Are {EEG} Foundation Models Worth It? Comparative Evaluation with Traditional Decoders in Diverse {BCI} Tasks},
  author={Liuyin Yang and Qiang Sun and Ang Li and Marc M. Van Hulle},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=5Xwm8e6vbh}
}
```

## License

MIT — same as the original ST-EEGFormer release.

## Links

| Resource | Link |
|---|---|
| **Rust crate** | [github.com/eugenehp/steegformer-rs](https://github.com/eugenehp/steegformer-rs) |
| **Original code** | [github.com/LiuyinYang1101/STEEGFormer](https://github.com/LiuyinYang1101/STEEGFormer) |
| **Original weights** | [GitHub Releases](https://github.com/LiuyinYang1101/STEEGFormer/releases) |
| **Paper** | [OpenReview (ICLR 2026)](https://openreview.net/forum?id=5Xwm8e6vbh) |
| **Burn framework** | [burn.dev](https://burn.dev) |
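The key names and shapes listed under Weight Key Format can also be checked against a downloaded file without PyTorch or Burn. The sketch below is a minimal, hypothetical helper: `read_safetensors_header` and `summarize_encoder` are illustrative names, not part of steegformer-rs or the safetensors package. It relies only on the safetensors file layout (an 8-byte little-endian header length followed by a JSON header mapping tensor names to `dtype`, `shape` and `data_offsets`) and infers the encoder width, depth, MLP ratio and stored parameter count from tensor shapes. The number of attention heads cannot be recovered from shapes, so it is not reported.

```python
import json
import math
import struct


def read_safetensors_header(path):
    """Read only the JSON header of a .safetensors file (tensor data is not loaded)."""
    with open(path, "rb") as f:
        # The file starts with an unsigned 64-bit little-endian header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # Optional free-form metadata entry; it is not a tensor.
    header.pop("__metadata__", None)
    return header


def summarize_encoder(header):
    """Infer encoder hyperparameters from tensor shapes (hypothetical helper)."""
    embed_dim = header["cls_token"]["shape"][-1]
    # Encoder blocks are named blocks.{i}.*; decoder_blocks.* do not match this prefix.
    block_ids = {k.split(".")[1] for k in header if k.startswith("blocks.")}
    mlp_hidden = header["blocks.0.mlp.fc1.weight"]["shape"][0]
    # Counts every stored tensor (including the fixed temporal PE buffer);
    # for *_mae files this also includes the decoder.
    num_params = sum(math.prod(t["shape"]) for t in header.values())
    return {
        "embed_dim": embed_dim,
        "depth": len(block_ids),
        "mlp_ratio": mlp_hidden // embed_dim,
        "num_params": num_params,
    }
```

For example, `summarize_encoder(read_safetensors_header("weights/ST-EEGFormer_small_encoder.safetensors"))` should report values that line up with the Layers, embed_dim and Params columns of the Model Files table. Reading only the header keeps the check fast even for the 1.2 GB Large files.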