---
license: mit
language:
- en
tags:
- fmri
- neuroscience
- brain
- foundation-model
- vision-transformer
- jepa
- burn
- rust
datasets:
- ukbiobank
pipeline_tag: feature-extraction
library_name: brainjepa-rs
---
# Brain-JEPA (safetensors)
Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).
## Model description
Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:
- **Brain gradient positioning** for spatial (ROI) embeddings
- **Temporal patch embedding** via 1D convolution along time
- **JEPA architecture** (Joint Embedding Predictive Architecture)
The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
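The temporal patch embedding can be illustrated with a minimal pure-Python sketch. The `encoder.patch_embed.proj.weight` shape `[768, 1, 1, 16]` is consistent with a convolution whose kernel spans 1 ROI × 16 time points. This sketch assumes the stride equals the patch size, which matches the 4500 = 450 × (160 / 16) patch count, and it omits the bias and the positional embeddings. The `patch_embed` function and the toy sizes are hypothetical and only for illustration.

```python
from typing import List

Matrix = List[List[float]]


def patch_embed(series: Matrix, weight: Matrix, patch: int) -> Matrix:
    """Embed each ROI's time series as non-overlapping temporal patches.

    series: [n_rois][n_time] values.
    weight: [embed_dim][patch] projection (the [768, 1, 1, 16] kernel with
            its singleton dimensions squeezed out).
    Returns [n_rois * n_time // patch][embed_dim] tokens, ordered ROI by ROI.
    Bias and positional embeddings are omitted in this sketch.
    """
    tokens: Matrix = []
    for roi in series:
        if len(roi) % patch:
            raise ValueError("time length must be a multiple of the patch size")
        for start in range(0, len(roi), patch):
            window = roi[start:start + patch]
            # Each output channel is the dot product of one kernel row with the window.
            tokens.append([sum(w * x for w, x in zip(row, window)) for row in weight])
    return tokens


# Toy example: 2 ROIs x 8 time points, patch size 4, 3-dim embedding.
series = [[float(t) for t in range(8)], [1.0] * 8]
weight = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.25] * 4]
tokens = patch_embed(series, weight, patch=4)
print(len(tokens), len(tokens[0]))  # 4 tokens (2 ROIs x 2 windows), 3 dims each
```

With the real model the same arithmetic gives 450 ROIs × 10 windows = 4500 tokens of 768 dimensions.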
## Files
| File | Description | Shape info |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |
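Before loading the model it can help to confirm that the gradient file has the expected 450 × 30 shape. Below is a minimal standard-library sketch. It assumes the CSV contains only comma-separated numbers, optionally preceded by a single header row, which is skipped if it does not parse as numbers. `load_gradient` is a hypothetical helper, not part of brainjepa-rs.

```python
import csv
from typing import List


def load_gradient(path: str) -> List[List[float]]:
    """Read a gradient-mapping CSV into rows of floats (one row per ROI).

    Assumption: cells are plain numbers; a first row that does not parse
    as numbers is treated as a header and skipped.
    """
    rows: List[List[float]] = []
    with open(path, newline="") as f:
        for i, record in enumerate(csv.reader(f)):
            if not record:
                continue  # ignore blank lines
            try:
                rows.append([float(cell) for cell in record])
            except ValueError:
                if i == 0:
                    continue  # header row
                raise
    widths = {len(r) for r in rows}
    if len(widths) > 1:
        raise ValueError(f"ragged CSV: row widths {sorted(widths)}")
    return rows


# Usage (expected for this repo's file: 450 ROIs x 30 gradient components):
# grad = load_gradient("gradient_mapping_450.csv")
# assert len(grad) == 450 and len(grad[0]) == 30
```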
### Weight key structure
Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):
```
encoder.patch_embed.proj.weight [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight [768]
encoder.blocks.{i}.attn.qkv.weight [2304, 768]
encoder.blocks.{i}.attn.proj.weight [768, 768]
encoder.blocks.{i}.mlp.fc1.weight [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight [768, 3072]
encoder.norm.weight [768]
...
```
For inference, use `target_encoder.*` keys (EMA-smoothed weights from pretraining).
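To see which components the file contains without loading any tensor data, you can read just the safetensors header: the file starts with an 8-byte little-endian length, followed by a JSON object that maps each tensor name to its `dtype`, `shape` and `data_offsets` (plus an optional `__metadata__` entry). The sketch below uses only the Python standard library; `read_header` and `summarize_components` are hypothetical helper names.

```python
import json
import struct
from collections import Counter
from typing import Dict, List, Tuple


def read_header(path: str) -> Dict[str, dict]:
    """Return the JSON header of a .safetensors file (no tensor data is read)."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))  # u64 little-endian header size
        header = json.loads(f.read(n))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def summarize_components(path: str) -> Tuple[Counter, Dict[str, List[int]]]:
    """Count tensors per top-level prefix and collect target_encoder shapes."""
    header = read_header(path)
    counts = Counter(name.split(".", 1)[0] for name in header)
    target = {
        name.removeprefix("target_encoder."): info["shape"]
        for name, info in header.items()
        if name.startswith("target_encoder.")
    }
    return counts, target


# Usage:
# counts, target = summarize_components("brainjepa.safetensors")
# print(counts)  # tensors per component: encoder / predictor / target_encoder
# print(target["patch_embed.proj.weight"])
```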
## Usage with brainjepa-rs (Rust)
```sh
# Clone the repository
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs
# Download brainjepa.safetensors and gradient_mapping_450.csv from this repo into data/
# Run inference (CPU)
cargo run --release --bin infer -- \
--weights data/brainjepa.safetensors \
--gradient data/gradient_mapping_450.csv \
--input data/fmri_sample.safetensors
# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
--weights data/brainjepa.safetensors \
--gradient data/gradient_mapping_450.csv \
--input data/fmri_sample.safetensors
```
### Rust library
```rust
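use brainjepa_rs::{BrainJepaEncoder, DataConfig, ModelConfig};
// Assumption: CPU inference via Burn's NdArray backend (the `burn` crate with
// its `ndarray` feature); any other Burn backend can be substituted for `B`.
use burn::backend::{ndarray::NdArrayDevice, NdArray};

type B = NdArray<f32>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = NdArrayDevice::default();
    // Load the weights and the 450-ROI gradient mapping.
    let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
        "data/brainjepa.safetensors",
        "data/gradient_mapping_450.csv",
        &ModelConfig::default(),
        &DataConfig::default(),
        &device,
    )?;
    let result = encoder.encode_safetensors("data/fmri.safetensors")?;
    // result.embeddings: [4500, 768] float32 (450 ROIs x 10 temporal patches)
    Ok(())
}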
```
## Usage with original Python code
These weights were converted from the original PyTorch checkpoint. To use with the original code:
```python
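import torch
from safetensors.torch import load_file

# Load all tensors (encoder, predictor, target_encoder) from the converted file.
tensors = load_file("brainjepa.safetensors")

# Keep only the EMA target-encoder weights and strip their component prefix
# (str.removeprefix requires Python 3.9+).
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}

# `model` is the encoder built with the original repository's code
# (construction not shown here).
model.load_state_dict(state_dict)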
```
## Conversion
Weights were converted from the original PyTorch checkpoint using:
```sh
python scripts/convert_weights.py \
--input jepa-ep300.pth.tar \
--output brainjepa.safetensors
```
The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves in safetensors format.
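The key handling described above can be sketched in plain Python. This assumes the checkpoint stores one state dict per component under the keys `encoder`, `predictor` and `target_encoder` (the layout the output key prefixes suggest), next to non-weight entries such as the epoch. `flatten_checkpoint` is a hypothetical helper; the actual `scripts/convert_weights.py` may differ. The float32 cast is left out because it needs PyTorch.

```python
from typing import Any, Dict, Iterable

COMPONENTS = ("encoder", "predictor", "target_encoder")


def flatten_checkpoint(
    checkpoint: Dict[str, Any], components: Iterable[str] = COMPONENTS
) -> Dict[str, Any]:
    """Merge per-component state dicts into one flat, prefixed mapping.

    DDP-wrapped models store every key with a leading "module."; that prefix
    is dropped and replaced by "<component>." to match the safetensors keys.
    Non-weight entries (optimizer state, epoch, ...) are ignored.
    """
    flat: Dict[str, Any] = {}
    for comp in components:
        for key, value in checkpoint[comp].items():
            flat[f"{comp}.{key.removeprefix('module.')}"] = value
    return flat


# Toy checkpoint with placeholder strings instead of tensors.
ckpt = {
    "encoder": {"module.norm.weight": "w_enc"},
    "predictor": {"module.norm.weight": "w_pred"},
    "target_encoder": {"module.norm.weight": "w_tgt"},
    "epoch": 300,
}
print(sorted(flatten_checkpoint(ckpt)))
# ['encoder.norm.weight', 'predictor.norm.weight', 'target_encoder.norm.weight']
```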
## Benchmark
Tested on a Mac mini (M4 Pro, 14 cores, 64 GB).
Input: `[1, 1, 450, 160]` (batch × channel × ROIs × time points; a single sample through the ~86M-parameter ViT-Base encoder). Reported times are the best of 3 encode runs.
| Backend | Encode time | Speedup vs PyTorch CPU |
|---|---|---|
| Rust — NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust — NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python — PyTorch (CPU) | 1,782 ms | 1.0x |
| Python — PyTorch MPS (GPU) | 581 ms | 3.1x |
| **Rust — wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
| **Rust — wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |
The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster
than PyTorch CPU.
![benchmark](benchmark.png)
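The speedup column is the PyTorch CPU time divided by each backend's time. A quick arithmetic check using the best-of-3 timings from the table:

```python
# Best-of-3 encode times in milliseconds, copied from the benchmark table.
times_ms = {
    "Rust NdArray + Rayon (CPU)": 28_778,
    "Rust NdArray + Accelerate (CPU)": 21_092,
    "PyTorch (CPU)": 1_782,
    "PyTorch MPS (GPU)": 581,
    "Rust wgpu f32 / Metal (GPU)": 83,
    "Rust wgpu f16 / Metal (GPU)": 85,
}

# Speedup relative to the PyTorch CPU baseline (values > 1 are faster).
baseline = times_ms["PyTorch (CPU)"]
speedup = {name: baseline / t for name, t in times_ms.items()}

for name, s in speedup.items():
    print(f"{name:35s} {s:6.2f}x")

# GPU-to-GPU comparison quoted in the text: wgpu f32 vs PyTorch MPS.
gpu_ratio = times_ms["PyTorch MPS (GPU)"] / times_ms["Rust wgpu f32 / Metal (GPU)"]
print(f"wgpu f32 vs MPS: {gpu_ratio:.1f}x")
```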
## Architecture details
| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
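Several values in this table and in the weight-key listing follow from a few base hyperparameters. The consistency check below assumes the standard ViT layout (fused QKV projection, non-overlapping temporal patches); the `VitConfig` dataclass is illustrative and is not a brainjepa-rs type.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VitConfig:
    embed_dim: int = 768
    num_heads: int = 12
    mlp_ratio: int = 4
    patch_size: int = 16  # temporal patch length
    num_rois: int = 450
    num_timepoints: int = 160

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    @property
    def qkv_out(self) -> int:
        # Fused Q, K and V projections: rows of attn.qkv.weight.
        return 3 * self.embed_dim

    @property
    def mlp_hidden(self) -> int:
        return self.mlp_ratio * self.embed_dim

    @property
    def num_patches(self) -> int:
        # Each ROI contributes num_timepoints // patch_size tokens.
        return self.num_rois * (self.num_timepoints // self.patch_size)


cfg = VitConfig()
print(cfg.head_dim, cfg.qkv_out, cfg.mlp_hidden, cfg.num_patches)  # 64 2304 3072 4500
```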
## Source
Original paper and code:
> Zijian Dong, Ruilin Li, Yilei Wu, et al.
> **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
> NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)
- Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
- Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
- Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)