A foundation model for tokenized spatial-temporal representation learning of electroencephalography signal data
Overview
EEGFoundation is a foundation model that treats neural dynamics as a discrete semantic language rather than as images, which is the approach taken by vision-based EEG analysis paradigms. It combines amplitude-aware tokenization with channel-independent pretraining on an EEG corpus of more than 27,000 hours, so that the model learns neural oscillation patterns intended to generalize across diverse EEG analysis tasks.
Fig.1 The EEGFoundation framework for spatiotemporal sequence modeling
Model Architecture
EEGFoundation follows a three-stage hierarchical approach:
1. Amplitude-Aware Tokenization: Continuous EEG signals are normalized and quantized into discrete symbolic tokens that preserve micro-voltage fluctuations while filtering high-frequency noise.
2. Temporal Pretraining: Using a RoFormer encoder with Rotary Position Embeddings, the model learns universal temporal dynamics from channel-independent EEG streams.
3. Spatiotemporal Fusion: Cross-channel attention dynamically aggregates local representations into a coherent global context for robust downstream task performance.
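The tokenization stage can be illustrated with a minimal sketch. This is not the repository's tokenizer: the function name `tokenize_amplitude`, the vocabulary size of 256, the 4-sigma clipping range, and the choice of uniform bins are all assumptions made for illustration, and the sketch does no noise filtering. It only shows the general idea of normalizing each channel independently and mapping amplitudes to integer token ids.

```python
import numpy as np

def tokenize_amplitude(signal, vocab_size=256, clip_sigma=4.0):
    """Quantize a (channels, time) array into integer tokens, channel by channel.

    vocab_size and clip_sigma are illustrative defaults, not values from the project.
    """
    signal = np.asarray(signal, dtype=np.float64)
    # Per-channel z-score normalization (each channel is treated independently)
    mean = signal.mean(axis=-1, keepdims=True)
    std = signal.std(axis=-1, keepdims=True)
    z = (signal - mean) / np.maximum(std, 1e-8)
    # Clip outliers so a few large artifacts do not stretch the bin range
    z = np.clip(z, -clip_sigma, clip_sigma)
    # vocab_size - 1 interior edges split the clipped range into vocab_size uniform bins
    edges = np.linspace(-clip_sigma, clip_sigma, vocab_size + 1)[1:-1]
    return np.digitize(z, edges)

rng = np.random.default_rng(0)
eeg = rng.normal(scale=50.0, size=(20, 2000))  # synthetic stand-in for microvolt data
tokens = tokenize_amplitude(eeg)
print(tokens.shape, tokens.min(), tokens.max())
```

Each channel yields its own token sequence, which matches the channel-independent setup described above; a learned or non-uniform codebook would be a natural alternative to the uniform bins assumed here.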
Quick Start
Environment Setup
# Clone the repository
git clone https://github.com/yourusername/EEGFoundation_github.git
cd EEGFoundation_github
# Install dependencies (Python 3.10+ required)
# Quote each requirement so the shell does not treat ">" as output redirection
pip install "torch>=2.0.0" "transformers>=4.30.0" "numpy>=1.24.0" "scipy>=1.10.0"
pip install "mne>=1.4.0" "einops>=0.6.0" "matplotlib>=3.7.0"
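After installation, a quick check along these lines can confirm which of the listed packages are present. This is an optional sketch rather than a script shipped with the repository; the helper name `installed_versions` is hypothetical, and it only reports versions without comparing them against the minimums above.

```python
from importlib.metadata import PackageNotFoundError, version

# Package names taken from the pip commands in this section
REQUIRED = ["torch", "transformers", "numpy", "scipy", "mne", "einops", "matplotlib"]

def installed_versions(packages):
    """Map each package name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

report = installed_versions(REQUIRED)
for name, ver in report.items():
    print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```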
Basic Usage
from src.models.downstream_EEGFoundation import load_downstream_model
import torch
import numpy as np
# Load pre-trained model for motor imagery classification
model = load_downstream_model(
    model_path="models/BCIC-2a_model.pth",
    config_path="configs/BCIC_IV_2a_config.json"
)
# Prepare input data (random placeholder tensors, for illustration only)
batch_size = 2
num_channels = 20   # EEG channels per recording
seq_length = 2000   # time points per channel
eeg_signal = torch.randn(batch_size, num_channels, seq_length).float()
embedding = torch.randn(batch_size, 512).float()
# Forward pass
with torch.no_grad():
    outputs = model(input_ids=eeg_signal, embedding_data=embedding)
    predictions = torch.softmax(outputs['logits'], dim=-1)
print(f"Predictions shape: {predictions.shape}")
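To turn the softmax output into readable labels, something like the sketch below may help. It works on a NumPy array (for example `predictions.cpu().numpy()`). The helper `decode_predictions` is hypothetical, and the label order in `CLASS_NAMES` is an assumption: BCI Competition IV dataset 2a has four motor imagery classes, but the index-to-class mapping used by this checkpoint should be checked against the dataset configuration.

```python
import numpy as np

# Assumed label order (hypothetical); verify it against the config before relying on it
CLASS_NAMES = ["left hand", "right hand", "feet", "tongue"]

def decode_predictions(probs, class_names=CLASS_NAMES):
    """Return (label, probability) for the most likely class of each row."""
    probs = np.asarray(probs)
    top = probs.argmax(axis=-1)
    return [(class_names[i], float(probs[row, i])) for row, i in enumerate(top)]

# Hand-written probabilities standing in for real model output
example = np.array([[0.1, 0.7, 0.1, 0.1],
                    [0.25, 0.25, 0.4, 0.1]])
decoded = decode_predictions(example)
for label, p in decoded:
    print(f"{label}: {p:.2f}")
```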
Demo Data
import numpy as np
# Load example data
demo_data = np.load("demo_data/eeg_data.npy")
print(f"Demo data shape: {demo_data.shape}")
# The demo_data directory contains:
# - eeg_data.npy: Sample EEG recordings
# - downstream_eeg_data.npz: Processed data for downstream tasks
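The array names stored in `downstream_eeg_data.npz` are not listed here, so it can be useful to inspect an archive before using it. The sketch below shows one way to do that with `np.load`; the helper `summarize_npz` is hypothetical, and the demonstration runs on a synthetic archive whose keys `x` and `y` are made up for illustration and are not the keys of the real file.

```python
import os
import tempfile

import numpy as np

def summarize_npz(path):
    """Return {array_name: (shape, dtype_string)} for every array in an .npz archive."""
    with np.load(path) as archive:
        return {
            name: (archive[name].shape, str(archive[name].dtype))
            for name in archive.files
        }

# Demonstration on a temporary synthetic archive (keys "x" and "y" are hypothetical)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.npz")
    np.savez(path, x=np.zeros((2, 20, 2000), dtype=np.float32), y=np.array([0, 1]))
    summary = summarize_npz(path)

for name, (shape, dtype) in summary.items():
    print(name, shape, dtype)
```

Pointing `summarize_npz` at `demo_data/downstream_eeg_data.npz` should list whatever arrays that file actually contains.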
License
This project is licensed under the Apache License 2.0. See the LICENSE file for details.
