dawidtang commited on
Commit
fe3cbbb
·
verified ·
1 Parent(s): e856a53

Upload folder using huggingface_hub

Browse files
.keep ADDED
File without changes
README.md CHANGED
@@ -1,3 +1,123 @@
1
  ---
2
- license: cc-by-nc-4.0
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ tags:
3
+ - neuroscience
4
+ - fmri
5
+ - video
6
+ - v-jepa
7
+ - pytorch
8
+ library_name: pytorch
9
  ---
10
+
11
+ # V-JEPA2 Offline Encoder for Video-Evoked BOLD Responses
12
+
13
+ This repository contains a PyTorch checkpoint for a basic V-JEPA2-based offline encoder trained to predict video-evoked BOLD responses. The encoder is intended for research workflows involving neural response prediction and neural response-guided visual synthesis.
14
+
15
+ The checkpoint stores decoder weights and metadata for an offline encoder. This repository includes a custom `transformers.AutoModel` wrapper and does not require the original training codebase.
16
+
17
+ ## Files
18
+
19
+ - `vjepa2_offline_encoder.pth`: PyTorch checkpoint containing decoder weights, decoding-unit selection metadata, feature-extractor configuration, and registered attributes.
20
+ - `config.json`, `configuration_vjepa2_fmri_encoder.py`, `modeling_vjepa2_fmri_encoder.py`: custom Transformers files for `AutoModel` loading.
21
+ - `requirements.txt`: minimal Python dependencies.
22
+
23
+ ## Data
24
+
25
+ This checkpoint was trained using data from:
26
+
27
+ - **BOLD Moments Dataset (BMD)**: whole-brain fMRI responses to short naturalistic videos.
28
+ - **Social interaction video fMRI dataset from Emalie McMahon and collaborators**: fMRI responses to naturalistic two-person social action videos.
29
+
30
+ This repository does not include the underlying fMRI datasets or stimulus videos.
31
+
32
+ ## Input/Output Contract
33
+
34
+ The intended input is a short video clip corresponding to the training stimulus duration:
35
+
36
+ - **Input**: one 3-second RGB video clip, represented as a float tensor shaped `[B, T, C, H, W]` with values in `[0, 1]`.
37
+ - **Output**: one vector of predicted z-scored fMRI beta responses per video, shaped `[B, 20484]`.
38
+ - **Temporal dimension**: the output has no time dimension. Each 3-second video maps to a single predicted response vector.
39
+
40
+ This makes the encoder suitable for scoring or optimizing short generated videos against static target neural-response patterns.
41
+
42
+ The video-input path resizes frames to `224 x 224` and applies the ImageNet normalization used by the V-JEPA2 training pipeline. If you pass already-normalized V-JEPA2 inputs, call `model.predict_fmri(video, normalize=False)`.
43
+
44
+ ## Loading
45
+
46
+ This checkpoint can be loaded with `transformers.AutoModel` and `trust_remote_code=True`.
47
+
48
+ Example:
49
+
50
+ ```python
51
+ import torch
52
+ from transformers import AutoModel
53
+
54
+ model = AutoModel.from_pretrained(
55
+ "epfl-neuroai/vjepa2-enoder-basic",
56
+ trust_remote_code=True,
57
+ )
58
+ model.eval()
59
+
60
+ # Replace this with a preprocessed 3-second video tensor.
61
+ # Shape: [batch, frames, channels, height, width].
62
+ video = torch.zeros(1, 16, 3, 224, 224)
63
+
64
+ with torch.no_grad():
65
+ prediction = model.predict_fmri(video)
66
+
67
+ print(prediction.shape) # [1, 20484]
68
+ ```
69
+
70
+ For decoder-only debugging, the model can also run from precomputed V-JEPA2 layer features:
71
+
72
+ ```python
73
+ model = AutoModel.from_pretrained(
74
+ "epfl-neuroai/vjepa2-enoder-basic",
75
+ trust_remote_code=True,
76
+ load_vjepa=False,
77
+ )
78
+
79
+ features = [
80
+ torch.zeros(1, decoder.mean.shape[1])
81
+ for decoder in model.decoders
82
+ ]
83
+
84
+ with torch.no_grad():
85
+ prediction = model.forward_features(features)
86
+ ```
87
+
88
+ ## Citations
89
+
90
+ If you use this checkpoint, please cite the source datasets:
91
+
92
+ ```bibtex
93
+ @article{tang2025diverse,
94
+ title={Diverse perceptual representations across visual pathways emerge from a single objective},
95
+ author={Tang, Yingtian and Gokce, Abdulkadir and Al-Karkari, Khaled Jedoui and Yamins, Daniel and Schrimpf, Martin},
96
+ journal={bioRxiv},
97
+ pages={2025--07},
98
+ year={2025},
99
+ publisher={Cold Spring Harbor Laboratory}
100
+ }
101
+
102
+ @article{lahner2024modeling,
103
+ title={Modeling short visual events through the BOLD moments video fMRI dataset and metadata},
104
+ author={Lahner, Benjamin and Dwivedi, Kshitij and Iamshchinina, Polina and Graumann, Monika and Lascelles, Alex and Roig, Gemma and Gifford, Alessandro Thomas and Pan, Bowen and Jin, SouYoung and Ratan Murty, N Apurva and others},
105
+ journal={Nature communications},
106
+ volume={15},
107
+ number={1},
108
+ pages={6241},
109
+ year={2024},
110
+ publisher={Nature Publishing Group UK London}
111
+ }
112
+
113
+ @article{mcmahon2023hierarchical,
114
+ title={Hierarchical organization of social action features along the lateral visual pathway},
115
+ author={McMahon, Emalie and Bonner, Michael F and Isik, Leyla},
116
+ journal={Current Biology},
117
+ volume={33},
118
+ number={23},
119
+ pages={5035--5047},
120
+ year={2023},
121
+ publisher={Elsevier}
122
+ }
123
+ ```
config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "vjepa2_fmri_encoder",
3
+ "architectures": [
4
+ "VJEPA2FMRIEncoderModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_vjepa2_fmri_encoder.VJEPA2FMRIEncoderConfig",
8
+ "AutoModel": "modeling_vjepa2_fmri_encoder.VJEPA2FMRIEncoderModel"
9
+ },
10
+ "checkpoint_filename": "vjepa2_offline_encoder.pth",
11
+ "output_dim": 20484,
12
+ "input_duration_seconds": 3.0,
13
+ "input_format": "video_tensor_b_t_c_h_w",
14
+ "output_description": "z_scored_fmri_betas_no_time_dimension",
15
+ "vjepa_size": "large",
16
+ "load_vjepa": true,
17
+ "image_size": 224,
18
+ "normalize_input": true
19
+ }
configuration_vjepa2_fmri_encoder.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformers config for the V-JEPA2 fMRI encoder."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ class VJEPA2FMRIEncoderConfig(PretrainedConfig):
9
+ model_type = "vjepa2_fmri_encoder"
10
+
11
+ def __init__(
12
+ self,
13
+ checkpoint_filename: str = "vjepa2_offline_encoder.pth",
14
+ output_dim: int = 20484,
15
+ input_duration_seconds: float = 3.0,
16
+ input_format: str = "video_tensor_b_t_c_h_w",
17
+ output_description: str = "z_scored_fmri_betas_no_time_dimension",
18
+ vjepa_size: str = "large",
19
+ load_vjepa: bool = True,
20
+ image_size: int = 224,
21
+ normalize_input: bool = True,
22
+ **kwargs,
23
+ ) -> None:
24
+ super().__init__(**kwargs)
25
+ self.checkpoint_filename = checkpoint_filename
26
+ self.output_dim = int(output_dim)
27
+ self.input_duration_seconds = float(input_duration_seconds)
28
+ self.input_format = input_format
29
+ self.output_description = output_description
30
+ self.vjepa_size = vjepa_size
31
+ self.load_vjepa = bool(load_vjepa)
32
+ self.image_size = int(image_size)
33
+ self.normalize_input = bool(normalize_input)
modeling_vjepa2_fmri_encoder.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom AutoModel implementation for a basic V-JEPA2 fMRI encoder."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Iterable
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from transformers import PreTrainedModel
13
+
14
+ try:
15
+ from .configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig
16
+ except ImportError:
17
+ from configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig
18
+
19
+
20
+ class RidgeDecoder(nn.Module):
21
+ def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
22
+ super().__init__()
23
+ self.register_buffer("mean", state_dict["steps.1.mean"])
24
+ self.register_buffer("std", state_dict["steps.1.std"])
25
+ self.register_buffer("coef", state_dict["steps.2.regressor._coef"])
26
+ self.register_buffer("intercept", state_dict["steps.2.regressor._intercept"])
27
+
28
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
29
+ x = x.reshape(x.shape[0], -1)
30
+ x = (x - self.mean.to(device=x.device)) / self.std.to(device=x.device)
31
+ coef = self.coef.to(device=x.device)
32
+ x = x.to(dtype=coef.dtype)
33
+ return x @ coef.T + self.intercept.to(device=x.device)
34
+
35
+
36
+ class HookedFeatureExtractor:
37
+ def __init__(self, layer_names: Iterable[str], ret_type: str = "chw", spatial_pool: int = 14) -> None:
38
+ self.layer_names = list(layer_names)
39
+ self.ret_type = ret_type
40
+ self.spatial_pool = int(spatial_pool)
41
+ self.outputs: dict[str, torch.Tensor] = {}
42
+ self.hooks = []
43
+
44
+ @staticmethod
45
+ def _get_layer(model: nn.Module, layer_name: str) -> nn.Module:
46
+ layer: object = model
47
+ for part in layer_name.split("."):
48
+ layer = layer[int(part)] if part.isdigit() else getattr(layer, part)
49
+ if not isinstance(layer, nn.Module):
50
+ raise TypeError(f"{layer_name} did not resolve to a torch module")
51
+ return layer
52
+
53
+ def __call__(self, model: nn.Module, videos: torch.Tensor, **model_kwargs) -> list[torch.Tensor]:
54
+ self.outputs = {}
55
+ self.hooks = [
56
+ self._get_layer(model, name).register_forward_hook(
57
+ lambda _module, _inputs, output, name=name: self.outputs.__setitem__(name, output)
58
+ )
59
+ for name in self.layer_names
60
+ ]
61
+ try:
62
+ model(videos, **model_kwargs)
63
+ finally:
64
+ for hook in self.hooks:
65
+ hook.remove()
66
+ self.hooks = []
67
+ return [self._process_feature(self.outputs[name]) for name in self.layer_names]
68
+
69
+ def _process_feature(self, feature: torch.Tensor) -> torch.Tensor:
70
+ batch, _thw, channels = feature.shape
71
+ feature = feature.reshape(batch, -1, 14, 14, channels).permute(0, 1, 4, 2, 3)
72
+ if self.spatial_pool > 1:
73
+ batch, frames, channels, height, width = feature.shape
74
+ new_height = height // self.spatial_pool
75
+ new_width = width // self.spatial_pool
76
+ feature = feature.reshape(
77
+ batch,
78
+ frames,
79
+ channels,
80
+ new_height,
81
+ self.spatial_pool,
82
+ new_width,
83
+ self.spatial_pool,
84
+ )
85
+ feature = feature.permute(0, 1, 2, 3, 5, 4, 6).mean(dim=(-2, -1))
86
+ if self.ret_type == "chw":
87
+ return feature.mean(dim=1)
88
+ if self.ret_type == "tchw":
89
+ return feature
90
+ raise ValueError(f"Unsupported ret_type: {self.ret_type}")
91
+
92
+
93
+ class VJEPA2Backbone(nn.Module):
94
+ def __init__(self, size: str, image_size: int, normalize_input: bool) -> None:
95
+ super().__init__()
96
+ self.image_size = int(image_size)
97
+ self.normalize_input = bool(normalize_input)
98
+ self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1))
99
+ self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1))
100
+ hub_name = {
101
+ "large": "vjepa2_vit_large",
102
+ "huge": "vjepa2_vit_huge",
103
+ "giant": "vjepa2_vit_giant",
104
+ }[size]
105
+ backbone = torch.hub.load("facebookresearch/vjepa2", hub_name, pretrained=True)
106
+ self.backbone = backbone[0] if isinstance(backbone, (list, tuple)) else backbone
107
+
108
+ def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
109
+ if videos.ndim != 5:
110
+ raise ValueError(f"Expected video tensor shaped [B, T, C, H, W], got {tuple(videos.shape)}")
111
+ if videos.shape[2] != 3:
112
+ raise ValueError(f"Expected RGB video with 3 channels at dim 2, got {videos.shape[2]}")
113
+
114
+ videos = videos.float()
115
+ batch, frames, channels, height, width = videos.shape
116
+ if height != self.image_size or width != self.image_size:
117
+ videos = videos.reshape(batch * frames, channels, height, width)
118
+ videos = F.interpolate(
119
+ videos,
120
+ size=(self.image_size, self.image_size),
121
+ mode="bilinear",
122
+ align_corners=False,
123
+ )
124
+ videos = videos.reshape(batch, frames, channels, self.image_size, self.image_size)
125
+
126
+ normalize = self.normalize_input if normalize is None else bool(normalize)
127
+ if normalize:
128
+ videos = (videos - self.image_mean.to(device=videos.device, dtype=videos.dtype)) / self.image_std.to(
129
+ device=videos.device,
130
+ dtype=videos.dtype,
131
+ )
132
+ return self.backbone(videos.permute(0, 2, 1, 3, 4))
133
+
134
+
135
+ class VJEPA2FMRIEncoderModel(PreTrainedModel):
136
+ config_class = VJEPA2FMRIEncoderConfig
137
+ base_model_prefix = "vjepa2_fmri_encoder"
138
+ main_input_name = "videos"
139
+
140
+ def __init__(self, config: VJEPA2FMRIEncoderConfig) -> None:
141
+ super().__init__(config)
142
+ self.decoders = nn.ModuleList()
143
+ self.register_buffer("decoding_units", torch.empty(0, dtype=torch.long))
144
+ self.extractor: HookedFeatureExtractor | None = None
145
+ self.vjepa: VJEPA2Backbone | None = None
146
+
147
+ @classmethod
148
+ def from_pretrained(
149
+ cls,
150
+ pretrained_model_name_or_path: str | os.PathLike[str],
151
+ *model_args: Any,
152
+ config: VJEPA2FMRIEncoderConfig | None = None,
153
+ load_vjepa: bool | None = None,
154
+ vjepa_size: str | None = None,
155
+ normalize_input: bool | None = None,
156
+ **kwargs: Any,
157
+ ) -> "VJEPA2FMRIEncoderModel":
158
+ if model_args:
159
+ raise TypeError("Unexpected positional arguments for VJEPA2FMRIEncoderModel.from_pretrained")
160
+
161
+ revision = kwargs.pop("revision", None)
162
+ token = kwargs.pop("token", None)
163
+ cache_dir = kwargs.pop("cache_dir", None)
164
+ local_files_only = kwargs.pop("local_files_only", False)
165
+ for ignored in ("trust_remote_code", "state_dict", "ignore_mismatched_sizes", "adapter_kwargs", "weights_only"):
166
+ kwargs.pop(ignored, None)
167
+ if kwargs:
168
+ raise TypeError(f"Unsupported keyword argument(s): {', '.join(sorted(kwargs))}")
169
+
170
+ if config is None:
171
+ config = VJEPA2FMRIEncoderConfig.from_pretrained(
172
+ pretrained_model_name_or_path,
173
+ revision=revision,
174
+ token=token,
175
+ cache_dir=cache_dir,
176
+ local_files_only=local_files_only,
177
+ )
178
+
179
+ checkpoint_path = cls._resolve_checkpoint_path(
180
+ pretrained_model_name_or_path,
181
+ filename=config.checkpoint_filename,
182
+ revision=revision,
183
+ token=token,
184
+ cache_dir=cache_dir,
185
+ local_files_only=local_files_only,
186
+ )
187
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
188
+
189
+ model = cls(config)
190
+ model.decoders = nn.ModuleList([RidgeDecoder(state_dict) for state_dict in checkpoint["decoders_state_dict"]])
191
+ model.register_buffer("decoding_units", checkpoint["decoding_units"].long())
192
+ for name, value in checkpoint.get("registered_attrs", {}).items():
193
+ if torch.is_tensor(value):
194
+ model.register_buffer(name, value)
195
+
196
+ load_vjepa = config.load_vjepa if load_vjepa is None else bool(load_vjepa)
197
+ vjepa_size = config.vjepa_size if vjepa_size is None else vjepa_size
198
+ normalize_input = config.normalize_input if normalize_input is None else bool(normalize_input)
199
+ if load_vjepa:
200
+ extractor_config = checkpoint["extractor_config"]
201
+ model.extractor = HookedFeatureExtractor(
202
+ layer_names=extractor_config["layer_names"],
203
+ ret_type=extractor_config.get("ret_type", "chw"),
204
+ spatial_pool=extractor_config.get("spatial_pool", 14),
205
+ )
206
+ model.vjepa = VJEPA2Backbone(
207
+ size=vjepa_size,
208
+ image_size=config.image_size,
209
+ normalize_input=normalize_input,
210
+ )
211
+ model.eval()
212
+ return model
213
+
214
+ @staticmethod
215
+ def _resolve_checkpoint_path(
216
+ pretrained_model_name_or_path: str | os.PathLike[str],
217
+ *,
218
+ filename: str,
219
+ revision: str | None,
220
+ token: str | bool | None,
221
+ cache_dir: str | os.PathLike[str] | None,
222
+ local_files_only: bool,
223
+ ) -> str:
224
+ path = Path(pretrained_model_name_or_path)
225
+ if path.exists():
226
+ checkpoint_path = path / filename if path.is_dir() else path
227
+ if not checkpoint_path.exists():
228
+ raise FileNotFoundError(f"Missing checkpoint file: {checkpoint_path}")
229
+ return str(checkpoint_path)
230
+
231
+ from huggingface_hub import hf_hub_download
232
+
233
+ return hf_hub_download(
234
+ repo_id=str(pretrained_model_name_or_path),
235
+ filename=filename,
236
+ repo_type="model",
237
+ revision=revision,
238
+ token=token,
239
+ cache_dir=cache_dir,
240
+ local_files_only=local_files_only,
241
+ )
242
+
243
+ def forward_features(self, features: list[torch.Tensor]) -> torch.Tensor:
244
+ if len(features) != len(self.decoders):
245
+ raise ValueError(f"Expected {len(self.decoders)} feature tensors, got {len(features)}")
246
+ outputs = [decoder(feature) for decoder, feature in zip(self.decoders, features)]
247
+ output = torch.stack(outputs, dim=-1)
248
+ index = self.decoding_units.to(output.device).unsqueeze(0).unsqueeze(-1)
249
+ index = index.expand(output.shape[0], -1, -1)
250
+ return output.gather(dim=2, index=index).squeeze(-1)
251
+
252
+ def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
253
+ if self.vjepa is None or self.extractor is None:
254
+ raise RuntimeError("This model was loaded with load_vjepa=False.")
255
+ features = self.extractor(self.vjepa, videos, normalize=normalize)
256
+ return self.forward_features(features)
257
+
258
+ def predict_fmri(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
259
+ """Predict z-scored fMRI beta responses for videos shaped [B, T, C, H, W]."""
260
+
261
+ return self(videos, normalize=normalize)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ huggingface_hub
3
+ transformers
vjepa2_offline_encoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b2ec1499735098a1c97e67c84837241c349809f278d334f8c0e0c7b5ef1fe3b
3
+ size 2125320301