# 🧠 Multimodal Brain Encoder

A brain encoding model that predicts fMRI brain activity from multimodal inputs (images, text, audio).
## Architecture

| Component | Details |
|---|---|
| Feature Extractor | CLIP ViT-L/14 (`openai/clip-vit-large-patch14`) |
| Feature Layers | CLS tokens from layers 6, 12, 18, and 24, concatenated (4096-dim) |
| Brain Encoder | Deep network: 4096 → 2048 → 2048 → 1024 → N_voxels |
| ROI Heads | 5 functional-network-specific attention heads |
| Ridge Baseline | scikit-learn `RidgeCV` (Algonauts 2023 recipe) |
| Q&A System | Grounded LLM interpreter (Qwen2.5-72B) |
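The brain encoder head itself is a plain feed-forward network over the concatenated CLIP features. A minimal sketch with the dimensions from the table above (the activation, dropout, and argument names are assumptions, not the checkpoint's exact definition):

```python
import torch
import torch.nn as nn

class BrainEncoder(nn.Module):
    """Map 4096-dim CLIP features to per-voxel fMRI predictions (sketch)."""
    def __init__(self, in_dim=4096, n_voxels=47236, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 2048), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(2048, 2048), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(2048, 1024), nn.GELU(),
            nn.Linear(1024, n_voxels),   # one output per voxel
        )

    def forward(self, features):         # features: (batch, 4096)
        return self.net(features)        # -> (batch, n_voxels)
```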
## Training Data
- Dataset: Natural Scenes Dataset (NSD)
- Subject: subj01 (7T fMRI)
- Training samples: 2000 images with paired fMRI responses
- Validation: 200 images
- Voxels: ~47,236 (nsdgeneral mask)
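The `RidgeCV` baseline from the Architecture table fits a cross-validated linear map from features to voxel responses on this same split. A minimal sketch with placeholder arrays in the shapes listed above (the alpha grid is an assumption):

```python
import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(0)
X_train = rng.standard_normal((2000, 4096), dtype=np.float32)   # CLIP features (placeholder)
Y_train = rng.standard_normal((2000, 47236), dtype=np.float32)  # fMRI betas (placeholder)
X_val   = rng.standard_normal((200, 4096), dtype=np.float32)

# RidgeCV picks the regularization strength by cross-validation and
# fits all ~47k voxels jointly as a multi-output linear regression.
ridge = RidgeCV(alphas=np.logspace(1, 5, 9))
ridge.fit(X_train, Y_train)
Y_pred = ridge.predict(X_val)   # (200, 47236) predicted voxel responses
```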
## Brain Regions (24 ROIs)

| Network | Regions | Function |
|---|---|---|
| Early Visual | V1v, V1d, V2v, V2d, V3v, V3d, hV4 | Basic visual processing |
| Body Selective | EBA, FBA-1, FBA-2, mTL-bodies | Body/person perception |
| Face Selective | OFA, FFA-1, FFA-2, mTL-faces, aTL-faces | Face recognition |
| Place Selective | OPA, PPA, RSC | Scene/navigation |
| Word Selective | OWFA, VWFA-1, VWFA-2, mfs-words, mTL-words | Reading/text |
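The 5 ROI heads group voxels by these functional networks, and per-region summaries are simple statistics over each ROI's voxels. A minimal sketch of the ROI averaging step, assuming a hypothetical `roi_masks` dict of boolean voxel masks (NSD provides such masks per subject):

```python
import numpy as np

NETWORKS = {
    "early_visual":   ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"],
    "body_selective": ["EBA", "FBA-1", "FBA-2", "mTL-bodies"],
    "face_selective": ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"],
    "place_selective": ["OPA", "PPA", "RSC"],
    "word_selective":  ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"],
}

def roi_summaries(pred, roi_masks):
    """Mean/std of predicted activation per ROI.
    pred: (n_voxels,) array; roi_masks: {roi: boolean mask, (n_voxels,)}."""
    return {roi: (float(pred[m].mean()), float(pred[m].std()))
            for roi, m in roi_masks.items()}
```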
## How It Works

1. Input → CLIP ViT-L/14 multi-layer features (4096-dim); see the extraction sketch below
2. Brain Encoder → predicted fMRI voxel activations (~47k voxels)
3. ROI Analysis → per-region activation summaries with uncertainty
4. LLM Q&A → grounded interpretation that references only model outputs; see the prompt sketch at the end of this card
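Step 1 concatenates the CLS token from four intermediate layers of CLIP ViT-L/14 (hidden size 1024, so 4 × 1024 = 4096). A minimal sketch of that extraction using the Hugging Face `transformers` API (preprocessing details are assumptions):

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

name = "openai/clip-vit-large-patch14"
processor = CLIPImageProcessor.from_pretrained(name)
vision = CLIPVisionModel.from_pretrained(name).eval()

inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")
with torch.no_grad():
    out = vision(**inputs, output_hidden_states=True)

# hidden_states[0] is the patch embedding; hidden_states[i] is the output
# of transformer layer i, so indices 6/12/18/24 are all valid for ViT-L.
cls_tokens = [out.hidden_states[i][:, 0] for i in (6, 12, 18, 24)]
features = torch.cat(cls_tokens, dim=-1)   # (1, 4096)
```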
## References

- Allen et al. (2022). A massive 7T fMRI dataset to bridge cognitive neuroscience and artificial intelligence. *Nature Neuroscience*.
- Gifford et al. (2023). The Algonauts Project 2023 Challenge: How the Human Brain Makes Sense of Natural Scenes. *arXiv*.
- Radford et al. (2021). Learning Transferable Visual Models From Natural Language Supervision (CLIP). *ICML*.
- Adeli & Zelinsky (2025). Transformer Brain Encoders. arXiv:2505.17329.
## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download the trained checkpoint from the Hub.
model_path = hf_hub_download(repo_id="ryu34/multimodal-brain-encoder",
                             filename="best_model.pt")
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

model = BrainEncoder(**checkpoint['config'])   # class defined in this repo's code
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# `features`: (batch, 4096) tensor of concatenated CLIP CLS features.
with torch.no_grad():
    predictions = model(features)              # -> (batch, n_voxels)
```
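Downstream, the Q&A layer is constrained to cite only values the encoder actually produced: per-ROI summaries are serialized into the prompt. A minimal sketch of that grounding step (the prompt wording and `summaries` format are assumptions; the actual interpreter wraps Qwen2.5-72B):

```python
def build_grounded_prompt(summaries):
    """summaries: {roi_name: (mean, std)} from the ROI analysis step."""
    stats = "\n".join(f"- {roi}: mean={m:.3f}, std={s:.3f}"
                      for roi, (m, s) in summaries.items())
    return ("You are interpreting predictions from an fMRI encoding model.\n"
            "Answer using ONLY the ROI statistics below; do not speculate.\n\n"
            + stats)
```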