---
license: apache-2.0
library_name: pytorch
pipeline_tag: other
---

### These are the pretrained weights of CBraMod.

# CBraMod: A Criss-Cross Brain Foundation Model for EEG Decoding

[![Paper](https://img.shields.io/badge/arXiv-2412.07236-red)](https://arxiv.org/abs/2412.07236) [![Paper](https://img.shields.io/badge/Paper-ICLR-008B8B)](https://openreview.net/forum?id=NPNUHgHF2w) [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-FFD21E)](https://huggingface.co/weighting666/CBraMod) ![GitHub Repo stars](https://img.shields.io/github/stars/wjq-learning/CBraMod)

πŸ” About | πŸ”¨ Setup | 🚒 How to Pretrain | β›΅ How to Finetune | πŸš€ Quick Start | πŸ”— Citation

πŸ”₯ NEWS: Our paper "_CBraMod: A Criss-Cross Brain Foundation Model for EEG Decoding_" has been accepted by ICLR 2025!

## πŸ” About

We propose **CBraMod**, a novel EEG foundation model for EEG decoding across various clinical and BCI applications. The preprint version of our paper is available on [arXiv](https://arxiv.org/abs/2412.07236). The camera-ready version of the paper will be available on [OpenReview](https://openreview.net/forum?id=NPNUHgHF2w).
## πŸ”¨ Setup

Install [Python](https://www.python.org/downloads/) and [PyTorch](https://pytorch.org/get-started/locally/), then install the other requirements:
```commandline
pip install -r requirements.txt
```

## 🚒 How to Pretrain

You can pretrain CBraMod on our pretraining dataset or your own custom pretraining dataset using the following code:
```commandline
python pretrain_main.py
```
We have released a pretrained checkpoint on [Hugging FaceπŸ€—](https://huggingface.co/weighting666/CBraMod).

## β›΅ How to Finetune

You can finetune CBraMod on our selected downstream datasets using the following code:
```commandline
python finetune_main.py
```

## πŸš€ Quick Start

You can fine-tune the pretrained CBraMod on your custom downstream dataset using the following example code:
```python
import torch
import torch.nn as nn
from models.cbramod import CBraMod
from einops.layers.torch import Rearrange

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pretrained backbone and drop its reconstruction head.
model = CBraMod().to(device)
model.load_state_dict(torch.load('pretrained_weights/pretrained_weights.pth', map_location=device))
model.proj_out = nn.Identity()

# Classification head: here for 22 channels, 4 time segments,
# 200 points per patch, and 4 output classes.
classifier = nn.Sequential(
    Rearrange('b c s p -> b (c s p)'),
    nn.Linear(22 * 4 * 200, 4 * 200),
    nn.ELU(),
    nn.Dropout(0.1),
    nn.Linear(4 * 200, 200),
    nn.ELU(),
    nn.Dropout(0.1),
    nn.Linear(200, 4),
).to(device)

# mock_eeg.shape = (batch_size, num_of_channels, time_segments, points_per_patch)
mock_eeg = torch.randn((8, 22, 4, 200)).to(device)

# logits.shape = (batch_size, num_of_classes)
logits = classifier(model(mock_eeg))
```

## πŸ”— Citation

If you use this repository in your research or applications, please cite it using the following BibTeX:
```bibtex
@inproceedings{wang2025cbramod,
    title={{CB}raMod: A Criss-Cross Brain Foundation Model for {EEG} Decoding},
    author={Jiquan Wang and Sha Zhao and Zhiling Luo and Yangxuan Zhou and Haiteng Jiang and Shijian Li and Tao Li and Gang Pan},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=NPNUHgHF2w}
}
```

## ⭐ Star History
Code: https://github.com/wjq-learning/CBraMod
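The Quick Start above stops at a single forward pass. The sketch below extends it with one supervised training step on the classification head. It is a minimal illustration, not the repository's fine-tuning script: an `nn.Identity` stand-in replaces the pretrained CBraMod backbone so the snippet runs standalone, and the optimizer choice and learning rate are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained backbone; in practice, build
# `backbone` from CBraMod with `proj_out = nn.Identity()` as in Quick Start.
backbone = nn.Identity()

# Classification head matching the Quick Start shapes:
# (batch, 22 channels, 4 segments, 200 points) -> 4 classes.
classifier = nn.Sequential(
    nn.Flatten(),                  # (b, 22, 4, 200) -> (b, 22*4*200)
    nn.Linear(22 * 4 * 200, 200),
    nn.ELU(),
    nn.Dropout(0.1),
    nn.Linear(200, 4),
)

# Illustrative hyperparameters (assumptions, not from the repo).
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# One mock training step on random data.
eeg = torch.randn(8, 22, 4, 200)
labels = torch.randint(0, 4, (8,))

logits = classifier(backbone(eeg))   # logits.shape = (8, 4)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

When fine-tuning the real backbone, you would typically pass both `backbone.parameters()` and `classifier.parameters()` to the optimizer, possibly with a smaller learning rate for the pretrained weights.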