---
license: mit
language:
- en
tags:
- fmri
- neuroscience
- brain
- foundation-model
- vision-transformer
- jepa
- burn
- rust
datasets:
- ukbiobank
pipeline_tag: feature-extraction
library_name: brainjepa-rs
---

# Brain-JEPA (safetensors)

Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).
|
|
## Model description

Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer (ViT) with:

- **Brain gradient positioning** for spatial (ROI) positional embeddings
- **Temporal patch embedding** via a 1D convolution along the time axis (kernel size 16)
- **JEPA pretraining** (Joint Embedding Predictive Architecture) with spatiotemporal masking: a predictor learns to predict the target encoder's representations of masked patches

The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
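
The temporal patch embedding can be illustrated with a small Rust sketch. A weight of shape `[768, 1, 1, 16]` acts as a 1 x 16 kernel over each ROI's time series. The sketch assumes non-overlapping patches (stride 16), which is consistent with a 450 x 160 input producing 450 x (160 / 16) = 4500 patches; it also assumes a per-channel bias and ROI-major token ordering. `temporal_patch_embed` is a hypothetical helper, not part of brainjepa-rs.

```rust
/// Hypothetical helper: non-overlapping temporal patch embedding for one sample.
/// - `x`: fMRI signal laid out row-major as [rois][time]
/// - `weight`: [embed_dim][patch], i.e. the [768, 1, 1, 16] kernel with its
///   two singleton dimensions dropped
/// - `bias`: [embed_dim] (assumed to exist)
/// Returns one `embed_dim`-vector per patch, ROI-major (assumed ordering).
fn temporal_patch_embed(
    x: &[f32],
    rois: usize,
    time: usize,
    weight: &[f32],
    bias: &[f32],
    patch: usize,
) -> Vec<Vec<f32>> {
    let embed_dim = bias.len();
    assert_eq!(x.len(), rois * time, "x must hold rois * time values");
    assert_eq!(weight.len(), embed_dim * patch, "weight must hold embed_dim * patch values");
    assert_eq!(time % patch, 0, "time must be a multiple of the patch size");
    let patches_per_roi = time / patch;
    let mut tokens = Vec::with_capacity(rois * patches_per_roi);
    for r in 0..rois {
        for p in 0..patches_per_roi {
            // `patch` consecutive time points of ROI `r` (stride == patch, no overlap).
            let start = r * time + p * patch;
            let window = &x[start..start + patch];
            // Each output channel is the dot product of its kernel with the window.
            let token: Vec<f32> = (0..embed_dim)
                .map(|c| {
                    let kernel = &weight[c * patch..(c + 1) * patch];
                    bias[c] + kernel.iter().zip(window).map(|(w, v)| w * v).sum::<f32>()
                })
                .collect();
            tokens.push(token);
        }
    }
    tokens
}

fn main() {
    let (rois, time, patch, embed_dim) = (450, 160, 16, 768);
    let x = vec![0.5f32; rois * time];
    let weight = vec![0.01f32; embed_dim * patch];
    let bias = vec![0.0f32; embed_dim];
    let tokens = temporal_patch_embed(&x, rois, time, &weight, &bias, patch);
    // 450 ROIs * (160 / 16) patches per ROI = 4500 tokens of width 768.
    println!("{} tokens x {} dims", tokens.len(), tokens[0].len());
}
```

With these shapes the encoder sees 4500 tokens, matching the `[4500, 768]` output listed below.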
|
|
## Files

| File | Description | Size |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |
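
The gradient file is a plain CSV of 450 x 30 numbers. As a minimal, std-only sketch, it can be parsed and shape-checked as below. Whether the file has a header row is not stated here, so the sketch assumes that a non-numeric first line is a header and skips it. `parse_gradient_csv` is a hypothetical helper, not the brainjepa-rs loader.

```rust
use std::fs;

/// Hypothetical helper: parse the gradient CSV into `rows` x `cols` values.
/// Assumption: a non-numeric first line is a header and is skipped; any
/// other malformed line is reported as an error.
fn parse_gradient_csv(text: &str, rows: usize, cols: usize) -> Result<Vec<Vec<f32>>, String> {
    let mut out = Vec::with_capacity(rows);
    // Blank lines are ignored; `i` counts the remaining (non-empty) lines.
    for (i, line) in text.lines().filter(|l| !l.trim().is_empty()).enumerate() {
        let parsed: Result<Vec<f32>, _> =
            line.split(',').map(|s| s.trim().parse::<f32>()).collect();
        match parsed {
            Ok(row) if row.len() == cols => out.push(row),
            Ok(row) => {
                return Err(format!("line {}: expected {} columns, got {}", i + 1, cols, row.len()))
            }
            Err(_) if i == 0 => continue, // header row
            Err(e) => return Err(format!("line {}: {}", i + 1, e)),
        }
    }
    if out.len() != rows {
        return Err(format!("expected {} rows, got {}", rows, out.len()));
    }
    Ok(out)
}

fn main() {
    match fs::read_to_string("data/gradient_mapping_450.csv") {
        Ok(text) => match parse_gradient_csv(&text, 450, 30) {
            Ok(g) => println!("loaded {} x {} gradient matrix", g.len(), g[0].len()),
            Err(e) => eprintln!("invalid gradient file: {e}"),
        },
        Err(e) => eprintln!("could not read gradient file: {e}"),
    }
}
```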

### Weight key structure

Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`). Representative encoder keys and their shapes (`{i}` is the block index, 0 to 11):

```
encoder.patch_embed.proj.weight      [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight      [768]
encoder.blocks.{i}.attn.qkv.weight   [2304, 768]
encoder.blocks.{i}.attn.proj.weight  [768, 768]
encoder.blocks.{i}.mlp.fc1.weight    [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight    [768, 3072]
encoder.norm.weight                  [768]
...
```
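
The fused `attn.qkv.weight` stacks the query, key, and value projections along its first dimension (2304 = 3 x 768). The sketch below assumes the timm-style attention layout, where the fused output is reshaped as `[3, heads, head_dim]`, so rows are ordered `[q; k; v]` and each projection holds 12 contiguous blocks of 64 rows per head. `split_qkv` and `head_rows` are illustrative helpers operating on a row-major buffer.

```rust
/// Split a row-major fused QKV weight of shape [3 * dim, dim] into
/// (q, k, v), each [dim, dim]. Assumes rows are ordered [q; k; v].
fn split_qkv(qkv: &[f32], dim: usize) -> (&[f32], &[f32], &[f32]) {
    assert_eq!(qkv.len(), 3 * dim * dim, "expected a [3 * dim, dim] matrix");
    let block = dim * dim;
    (&qkv[..block], &qkv[block..2 * block], &qkv[2 * block..])
}

/// Rows of one projection ([dim, dim]) that belong to head `h`,
/// assuming each head owns a contiguous block of `head_dim` rows.
fn head_rows(proj: &[f32], dim: usize, head_dim: usize, h: usize) -> &[f32] {
    let start = h * head_dim * dim;
    &proj[start..start + head_dim * dim]
}

fn main() {
    let (dim, heads) = (768, 12);
    let head_dim = dim / heads; // 64
    // Fill every row with its row index so the slices are easy to inspect.
    let qkv: Vec<f32> = (0..3 * dim)
        .flat_map(|r| std::iter::repeat(r as f32).take(dim))
        .collect();
    let (q, k, v) = split_qkv(&qkv, dim);
    println!("first fused row of q: {}, k: {}, v: {}", q[0], k[0], v[0]);
    let k_head3 = head_rows(k, dim, head_dim, 3);
    println!("key head 3 starts at fused row {}", k_head3[0]);
}
```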
|
|
For inference, use the `target_encoder.*` weights: an exponential moving average (EMA) of the context encoder's weights, maintained during pretraining.
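
For reference, a JEPA-style EMA update moves each target weight toward the corresponding context-encoder weight as `target = m * target + (1 - m) * context`, with a momentum `m` close to 1. The element-wise sketch below is illustrative only; the momentum value in `main` is a hypothetical placeholder, and the schedule actually used for this checkpoint is not stated here.

```rust
/// One EMA step, element-wise: target <- m * target + (1 - m) * online.
fn ema_update(target: &mut [f32], online: &[f32], momentum: f32) {
    assert_eq!(target.len(), online.len(), "weights must have the same shape");
    for (t, &o) in target.iter_mut().zip(online) {
        *t = momentum * *t + (1.0 - momentum) * o;
    }
}

fn main() {
    let mut target = vec![0.0f32; 4];
    let online = vec![1.0f32; 4];
    let momentum = 0.99; // hypothetical value, for illustration only
    for _ in 0..100 {
        ema_update(&mut target, &online, momentum);
    }
    // Toward a fixed online value of 1, the target after n steps is 1 - m^n.
    println!("target after 100 steps: {:.4}", target[0]);
}
```

Because `m` is close to 1, the target weights change slowly and average over many training steps, which is why they are the usual choice for inference.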
|
|
## Usage with brainjepa-rs (Rust)

```sh
# Get the code
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs

# Download the weights from this repo and place
# brainjepa.safetensors and gradient_mapping_450.csv in data/

# Run inference (CPU)
cargo run --release --bin infer -- \
  --weights data/brainjepa.safetensors \
  --gradient data/gradient_mapping_450.csv \
  --input data/fmri_sample.safetensors

# Run inference (GPU via wgpu: Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
  --weights data/brainjepa.safetensors \
  --gradient data/gradient_mapping_450.csv \
  --input data/fmri_sample.safetensors
```
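
The `--input` file is a safetensors file holding one parcellated fMRI sample. This card does not document the tensor name or layout that `infer` expects. The sketch below therefore assumes a single `F32` tensor of shape `[1, 1, 450, 160]` (the benchmark input shape) stored under the hypothetical key `fmri`, and writes it to the hypothetical path `data/my_fmri.safetensors`; check brainjepa-rs for the key it actually reads. To avoid adding a dependency, the sketch writes the safetensors layout by hand: an 8-byte little-endian header length, a JSON header, then the raw little-endian data.

```rust
use std::fs;

/// Serialize one little-endian F32 tensor in the safetensors layout:
/// [u64 LE header length][JSON header][raw tensor bytes].
/// Assumes `name` contains no characters that need JSON escaping.
fn safetensors_bytes(name: &str, shape: &[usize], data: &[f32]) -> Vec<u8> {
    assert_eq!(shape.iter().product::<usize>(), data.len(), "shape/data mismatch");
    let nbytes = data.len() * 4;
    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    let mut header = format!(
        "{{\"{}\":{{\"dtype\":\"F32\",\"shape\":[{}],\"data_offsets\":[0,{}]}}}}",
        name,
        dims.join(","),
        nbytes
    );
    // Pad the header with spaces so the tensor data starts on an 8-byte boundary.
    while header.len() % 8 != 0 {
        header.push(' ');
    }
    let mut out = Vec::with_capacity(8 + header.len() + nbytes);
    out.extend_from_slice(&(header.len() as u64).to_le_bytes());
    out.extend_from_slice(header.as_bytes());
    for v in data {
        out.extend_from_slice(&v.to_le_bytes());
    }
    out
}

fn main() -> std::io::Result<()> {
    let (rois, time) = (450, 160);
    // Placeholder signal; replace with a real parcellated, preprocessed time series.
    let data: Vec<f32> = (0..rois * time)
        .map(|i| ((i % time) as f32 * 0.1).sin())
        .collect();
    // "fmri" is a hypothetical tensor key.
    let bytes = safetensors_bytes("fmri", &[1, 1, rois, time], &data);
    fs::create_dir_all("data")?;
    fs::write("data/my_fmri.safetensors", bytes)
}
```

The resulting file can then be passed with `--input data/my_fmri.safetensors`, provided the key and shape match what brainjepa-rs expects.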
|
|
### Rust library

```rust
use brainjepa_rs::{BrainJepaEncoder, DataConfig, ModelConfig};
use burn::backend::NdArray;

// CPU backend (Burn's `ndarray` feature); another Burn backend such as wgpu
// can be substituted here.
type B = NdArray;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Default::default();
    let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
        "data/brainjepa.safetensors",
        "data/gradient_mapping_450.csv",
        &ModelConfig::default(),
        &DataConfig::default(),
        &device,
    )?;
    let result = encoder.encode_safetensors("data/fmri.safetensors")?;
    // result.embeddings: [4500, 768] float32
    Ok(())
}
```
|
|
## Usage with original Python code

These weights were converted from the original PyTorch checkpoint. To load the encoder weights into a model built with the original code:

```python
import torch
from safetensors.torch import load_file

tensors = load_file("brainjepa.safetensors")
# Keep only the target_encoder weights and strip their prefix
# (str.removeprefix requires Python 3.9+).
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}
# `model` is an encoder built with the original Brain-JEPA code, configured
# to match this checkpoint (ViT-Base, patch size 16, 450 ROIs).
model.load_state_dict(state_dict)
```
|
|
## Conversion

Weights were converted from the original PyTorch checkpoint using:

```sh
python scripts/convert_weights.py \
  --input jepa-ep300.pth.tar \
  --output brainjepa.safetensors
```
|
|
The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves the result in safetensors format.
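
The key handling amounts to a small renaming rule: drop a leading `module.` and prepend the component name. The sketch assumes each component is stored in the checkpoint as its own state dict with DDP-style keys such as `module.blocks.0.attn.qkv.weight`; `convert_key` and the example keys are illustrative, not the actual logic of `convert_weights.py`.

```rust
/// Map a checkpoint key to its safetensors key: drop a leading DDP
/// `module.` prefix (if present) and prepend the component name.
fn convert_key(component: &str, key: &str) -> String {
    let bare = key.strip_prefix("module.").unwrap_or(key);
    format!("{component}.{bare}")
}

fn main() {
    // Illustrative (component, checkpoint key) pairs.
    for (component, key) in [
        ("encoder", "module.blocks.0.attn.qkv.weight"),
        ("target_encoder", "module.norm.weight"),
        ("predictor", "blocks.5.mlp.fc2.bias"), // no prefix: kept as-is
    ] {
        println!("{} -> {}", key, convert_key(component, key));
    }
}
```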
|
|
## Benchmark

Tested on a Mac mini (M4 Pro, 14 cores, 64 GB).
Input: `[1, 1, 450, 160]` (single sample, ViT-Base, 86M params). Each time is the best of 3 encode runs.

| Backend | Encode time | Speedup vs PyTorch CPU |
|---|---|---|
| Rust: NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust: NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python: PyTorch (CPU) | 1,782 ms | 1.0x |
| Python: PyTorch MPS (GPU) | 581 ms | 3.1x |
| **Rust: wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
| **Rust: wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |
|
|
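The Rust wgpu GPU backends are about 7x faster than PyTorch MPS and about 21x faster than PyTorch CPU. The Rust CPU backends, by contrast, are roughly 12x to 16x slower than PyTorch CPU on this machine.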
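
The speedup column is the PyTorch CPU time divided by each backend's time, so values below 1 are slower than the baseline. This small sketch recomputes the column, and the MPS comparison, from the measured times in the table; `speedup` is just an illustrative helper.

```rust
/// Speedup relative to a baseline: baseline_ms / candidate_ms.
fn speedup(baseline_ms: f64, candidate_ms: f64) -> f64 {
    baseline_ms / candidate_ms
}

fn main() {
    let pytorch_cpu = 1_782.0;
    // Measured best-of-3 encode times from the table above, in milliseconds.
    let backends = [
        ("Rust NdArray + Rayon (CPU)", 28_778.0),
        ("Rust NdArray + Accelerate (CPU)", 21_092.0),
        ("PyTorch MPS (GPU)", 581.0),
        ("Rust wgpu f32 / Metal (GPU)", 83.0),
        ("Rust wgpu f16 / Metal (GPU)", 85.0),
    ];
    for (name, ms) in backends {
        println!("{name:<32} {:>6.2}x", speedup(pytorch_cpu, ms));
    }
    // PyTorch MPS relative to Rust wgpu f32.
    println!("wgpu f32 vs MPS: {:.1}x", speedup(581.0, 83.0));
}
```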
|
|
|
|
## Architecture details

| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
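
As a sanity check on the ~86M figure, the encoder's transformer parameters can be counted from this table. The sketch assumes a standard pre-norm ViT block (two LayerNorms, fused QKV, output projection, two-layer MLP) with a bias on every linear layer, plus the 1 x 16 patch embedding and a final LayerNorm. Parameters of the gradient-based positional embedding are excluded because their parameterization is not described here.

```rust
/// Parameters of a linear layer with bias.
fn linear(inp: usize, out: usize) -> usize {
    inp * out + out
}

/// Parameters of a LayerNorm (weight + bias).
fn layer_norm(dim: usize) -> usize {
    2 * dim
}

/// Encoder parameter count for a ViT with a 1 x `patch` temporal patch
/// embedding (one input channel). Positional-embedding parameters are excluded.
fn encoder_params(dim: usize, depth: usize, mlp_hidden: usize, patch: usize) -> usize {
    let block = layer_norm(dim)       // norm1
        + linear(dim, 3 * dim)        // attn.qkv
        + linear(dim, dim)            // attn.proj
        + layer_norm(dim)             // norm2
        + linear(dim, mlp_hidden)     // mlp.fc1
        + linear(mlp_hidden, dim);    // mlp.fc2
    // Conv kernel [dim, 1, 1, patch] plus bias has the same count as linear(patch, dim).
    let patch_embed = linear(patch, dim);
    patch_embed + depth * block + layer_norm(dim) // final norm
}

fn main() {
    let n = encoder_params(768, 12, 3072, 16);
    println!("encoder params: {n} (~{:.1}M)", n as f64 / 1e6);
}
```

This comes to 85,069,056 parameters (about 85.1M), close to the ~86M quoted above.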
|
|
## Source

Original paper and code:

> Zijian Dong, Ruilin Li, Yilei Wu, et al.
> **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
> NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)

- Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
- Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
- Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)
|
|