| """Custom AutoModel implementation for a basic V-JEPA2 fMRI encoder.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any, Iterable |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from transformers import PreTrainedModel |
|
|
| try: |
| from .configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig |
| except ImportError: |
| from configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig |
|
|
|
|
class RidgeDecoder(nn.Module):
    """Standardize flattened features and apply a stored linear map.

    The four tensors are read from ``state_dict`` under fixed keys and kept
    as buffers named ``mean``, ``std``, ``coef`` and ``intercept``.
    NOTE(review): the key names (``steps.1.*`` / ``steps.2.regressor.*``)
    look like an exported scaler + regressor pipeline -- confirm with the
    code that wrote the checkpoint.
    """

    # (buffer name, key in the incoming state dict), in registration order.
    _BUFFER_SOURCES = (
        ("mean", "steps.1.mean"),
        ("std", "steps.1.std"),
        ("coef", "steps.2.regressor._coef"),
        ("intercept", "steps.2.regressor._intercept"),
    )

    def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Register the decoder tensors found in ``state_dict`` as buffers.

        Raises:
            KeyError: if one of the expected keys is absent.
        """
        super().__init__()
        for buffer_name, source_key in self._BUFFER_SOURCES:
            self.register_buffer(buffer_name, state_dict[source_key])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``((flat(x) - mean) / std) @ coef.T + intercept``.

        ``x`` is flattened to ``(x.shape[0], -1)`` first; every buffer is
        moved to ``x``'s device before use.
        """
        device = x.device
        flat = x.reshape(x.shape[0], -1)
        centered = flat - self.mean.to(device=device)
        standardized = centered / self.std.to(device=device)
        weights = self.coef.to(device=device)
        # The matmul runs in the coefficient dtype.
        standardized = standardized.to(dtype=weights.dtype)
        projected = standardized @ weights.T
        return projected + self.intercept.to(device=device)
|
|
|
|
class HookedFeatureExtractor:
    """Capture activations of named submodules during one forward pass.

    Calling the extractor attaches a forward hook to every layer in
    ``layer_names``, runs the model once, removes the hooks again and returns
    one post-processed tensor per layer, in ``layer_names`` order.

    Each captured activation must be a 3-D tensor ``[batch, tokens, channels]``
    whose token axis can be split into ``frames x grid_size x grid_size``.
    NOTE(review): this assumes tokens are ordered frame-major, then row, then
    column -- confirm against the backbone's patch embedding.
    """

    def __init__(
        self,
        layer_names: Iterable[str],
        ret_type: str = "chw",
        spatial_pool: int = 14,
        *,
        grid_size: int = 14,
    ) -> None:
        """Store the extraction settings.

        Args:
            layer_names: Dotted paths of the submodules to hook (numeric parts
                index into containers). A single string is treated as one path.
            ret_type: ``"chw"`` averages over frames, ``"tchw"`` keeps them.
            spatial_pool: Side of the square average-pooling window applied to
                the token grid; values ``<= 1`` disable pooling.
            grid_size: Side of the square token grid per frame (default 14,
                the value that was previously hard-coded).
        """
        # A bare string is one dotted path, not an iterable of characters.
        if isinstance(layer_names, str):
            self.layer_names = [layer_names]
        else:
            self.layer_names = list(layer_names)
        self.ret_type = ret_type
        self.spatial_pool = int(spatial_pool)
        self.grid_size = int(grid_size)

    @staticmethod
    def _get_layer(model: nn.Module, layer_name: str) -> nn.Module:
        """Resolve a dotted ``layer_name`` under ``model``.

        Raises:
            TypeError: if the path resolves to something other than a module.
        """
        layer: object = model
        for part in layer_name.split("."):
            layer = layer[int(part)] if part.isdigit() else getattr(layer, part)
        if not isinstance(layer, nn.Module):
            raise TypeError(f"{layer_name} did not resolve to a torch module")
        return layer

    @staticmethod
    def _unwrap_output(output: Any) -> torch.Tensor:
        """Return the tensor from a hook output, taking element 0 of a list/tuple.

        Raises:
            ValueError: if a list/tuple output is empty.
            TypeError: if the (unwrapped) output is not a tensor.
        """
        if isinstance(output, (list, tuple)):
            if len(output) == 0:
                raise ValueError("Received an empty feature tuple.")
            output = output[0]
        if not torch.is_tensor(output):
            raise TypeError(f"Expected tensor feature output, got {type(output)!r}")
        return output

    def __call__(self, model: nn.Module, videos: torch.Tensor, **model_kwargs) -> list[torch.Tensor]:
        """Run ``model(videos, **model_kwargs)`` and return the hooked features.

        Raises:
            KeyError: if a hooked layer produced no output during the pass.
        """
        outputs: dict[str, torch.Tensor] = {}

        def make_hook(name: str):
            def hook(_module, _inputs, output) -> None:
                # Returning None leaves the layer's output unchanged.
                outputs[name] = self._unwrap_output(output)

            return hook

        hooks = []
        try:
            # Register inside the try block so that a failing lookup for a
            # later layer still removes the hooks already attached.
            for name in self.layer_names:
                layer = self._get_layer(model, name)
                hooks.append(layer.register_forward_hook(make_hook(name)))
            model(videos, **model_kwargs)
        finally:
            for hook in hooks:
                hook.remove()

        missing = [name for name in self.layer_names if name not in outputs]
        if missing:
            raise KeyError(f"No output was captured for layer(s): {', '.join(missing)}")
        return [self._process_feature(outputs[name]) for name in self.layer_names]

    def _process_feature(self, feature: torch.Tensor) -> torch.Tensor:
        """Reshape tokens to ``[B, T, C, H, W]``, pool spatially, reduce frames.

        Raises:
            ValueError: on a non 3-D input, a token count that does not fill
                whole ``grid_size x grid_size`` frames, a pooling window that
                does not divide the grid, or an unknown ``ret_type``.
        """
        if feature.ndim != 3:
            raise ValueError(
                f"Expected feature shaped [batch, tokens, channels], got {tuple(feature.shape)}"
            )
        batch, tokens, channels = feature.shape
        grid = self.grid_size
        per_frame = grid * grid
        if grid <= 0 or tokens == 0 or tokens % per_frame != 0:
            raise ValueError(
                f"Cannot split {tokens} tokens into frames of {grid}x{grid} patches"
            )
        frames = tokens // per_frame
        feature = feature.reshape(batch, frames, grid, grid, channels).permute(0, 1, 4, 2, 3)

        pool = self.spatial_pool
        if pool > 1:
            if grid % pool != 0:
                raise ValueError(
                    f"spatial_pool={pool} does not evenly divide the {grid}x{grid} token grid"
                )
            cells = grid // pool
            # Split each spatial axis into (cell, offset-within-cell), move the
            # two offset axes to the end and average them away.
            feature = feature.reshape(batch, frames, channels, cells, pool, cells, pool)
            feature = feature.permute(0, 1, 2, 3, 5, 4, 6).mean(dim=(-2, -1))

        if self.ret_type == "chw":
            return feature.mean(dim=1)
        if self.ret_type == "tchw":
            return feature
        raise ValueError(f"Unsupported ret_type: {self.ret_type}")
|
|
|
|
class LocalVJEPA2Backbone(nn.Module):
    """V-JEPA2 encoder built via ``torch.hub`` with weights from a local file.

    ``forward`` takes videos shaped ``[B, T, C, H, W]``, resizes frames to
    ``image_size`` when needed, optionally normalizes the RGB channels and
    feeds the encoder a ``[B, C, T, H, W]`` tensor.
    """

    # Accepted ``size`` values and the torch.hub entry point each maps to.
    _HUB_NAMES = {
        "large": "vjepa2_vit_large",
        "huge": "vjepa2_vit_huge",
        "giant": "vjepa2_vit_giant",
    }

    def __init__(self, size: str, image_size: int, normalize_input: bool, checkpoint_path: str) -> None:
        """Instantiate the encoder and load its ``target_encoder`` weights.

        Args:
            size: One of ``"large"``, ``"huge"`` or ``"giant"``.
            image_size: Square spatial size frames are resized to in ``forward``.
            normalize_input: Default for channel normalization in ``forward``.
            checkpoint_path: File readable by ``torch.load`` that contains a
                ``"target_encoder"`` state dict.

        Raises:
            KeyError: if ``size`` is unknown or the checkpoint has no
                ``"target_encoder"`` entry.
        """
        super().__init__()
        self.image_size = int(image_size)
        self.normalize_input = bool(normalize_input)
        # Per-channel RGB statistics, shaped to broadcast over [B, T, C, H, W].
        # NOTE(review): these are the usual ImageNet values, which suggests
        # inputs are expected in [0, 1] -- confirm with the caller.
        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1))

        try:
            hub_name = self._HUB_NAMES[size]
        except KeyError:
            raise KeyError(
                f"Unsupported V-JEPA2 size {size!r}; expected one of {sorted(self._HUB_NAMES)}"
            ) from None

        loaded = torch.hub.load("facebookresearch/vjepa2", hub_name, pretrained=False)
        # The hub entry point may return (encoder, predictor). Only the encoder
        # is kept, so no predictor weights are loaded.
        backbone = loaded[0] if isinstance(loaded, (list, tuple)) else loaded

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only point this at checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        result = backbone.load_state_dict(_clean_backbone_key(state_dict["target_encoder"]), strict=False)
        # strict=False hides key mismatches; surface them instead of silently
        # running a partly (or wholly) uninitialized encoder.
        self._warn_on_key_mismatch(result, checkpoint_path)
        self.backbone = backbone

    @staticmethod
    def _warn_on_key_mismatch(load_result: Any, source: Any) -> None:
        """Emit a ``UserWarning`` when a non-strict load skipped any keys."""
        import warnings

        missing = list(getattr(load_result, "missing_keys", ()))
        unexpected = list(getattr(load_result, "unexpected_keys", ()))
        if not missing and not unexpected:
            return
        warnings.warn(
            f"Non-strict weight load from {source}: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=3,
        )

    def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
        """Preprocess ``videos`` and return the encoder output.

        Args:
            videos: Tensor shaped ``[B, T, 3, H, W]``; it is cast to float32.
            normalize: Overrides the instance default when not ``None``.

        Raises:
            ValueError: if ``videos`` is not 5-D or dim 2 is not of size 3.
        """
        if videos.ndim != 5:
            raise ValueError(f"Expected video tensor shaped [B, T, C, H, W], got {tuple(videos.shape)}")
        if videos.shape[2] != 3:
            raise ValueError(f"Expected RGB video with 3 channels at dim 2, got {videos.shape[2]}")

        videos = videos.float()
        batch, frames, channels, height, width = videos.shape
        if height != self.image_size or width != self.image_size:
            # F.interpolate works on 4-D image batches: fold time into batch.
            videos = videos.reshape(batch * frames, channels, height, width)
            videos = F.interpolate(
                videos,
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            videos = videos.reshape(batch, frames, channels, self.image_size, self.image_size)

        normalize = self.normalize_input if normalize is None else bool(normalize)
        if normalize:
            videos = (videos - self.image_mean.to(device=videos.device, dtype=videos.dtype)) / self.image_std.to(
                device=videos.device,
                dtype=videos.dtype,
            )
        # The encoder is called with channels before time: [B, C, T, H, W].
        return self.backbone(videos.permute(0, 2, 1, 3, 4))
|
|
|
|
| def _clean_backbone_key(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: |
| cleaned = {} |
| for key, value in state_dict.items(): |
| key = key.replace("module.", "").replace("backbone.", "") |
| cleaned[key] = value |
| return cleaned |
|
|
|
|
class VJEPA2FMRIEncoderModel(PreTrainedModel):
    """Video-to-fMRI encoder: hooked V-JEPA2 features fed to linear decoders.

    One ``RidgeDecoder`` exists per hooked backbone layer. For every output
    unit, ``decoding_units`` holds the index of the decoder whose prediction
    is returned for that unit.
    """

    config_class = VJEPA2FMRIEncoderConfig
    base_model_prefix = "vjepa2_fmri_encoder"
    main_input_name = "videos"

    def __init__(self, config: VJEPA2FMRIEncoderConfig) -> None:
        """Create an empty shell; ``from_pretrained`` fills in the parts."""
        super().__init__(config)
        self.decoders = nn.ModuleList()
        self.register_buffer("decoding_units", torch.empty(0, dtype=torch.long))
        # Both stay None when the backbone is not loaded (load_vjepa=False).
        self.extractor: HookedFeatureExtractor | None = None
        self.vjepa: LocalVJEPA2Backbone | None = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike[str],
        *model_args: Any,
        config: VJEPA2FMRIEncoderConfig | None = None,
        load_vjepa: bool | None = None,
        vjepa_size: str | None = None,
        normalize_input: bool | None = None,
        **kwargs: Any,
    ) -> "VJEPA2FMRIEncoderModel":
        """Build the model from a local path or a Hub model repository.

        Args:
            pretrained_model_name_or_path: Existing local directory/file, or a
                Hub repo id used when no such path exists.
            config: Pre-built config; loaded from the same location if omitted.
            load_vjepa: Whether to build the backbone and feature extractor;
                defaults to ``config.load_vjepa``.
            vjepa_size: Backbone size; defaults to ``config.vjepa_size``.
            normalize_input: Backbone normalization default; falls back to
                ``config.normalize_input``.
            **kwargs: ``revision``, ``token``, ``cache_dir`` and
                ``local_files_only`` are used for file resolution.
                ``trust_remote_code``, ``state_dict``,
                ``ignore_mismatched_sizes``, ``adapter_kwargs`` and
                ``weights_only`` are accepted and ignored.

        Returns:
            The assembled model, switched to eval mode.

        Raises:
            TypeError: on positional ``model_args`` or any other keyword.
        """
        if model_args:
            raise TypeError("Unexpected positional arguments for VJEPA2FMRIEncoderModel.from_pretrained")

        revision = kwargs.pop("revision", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", False)
        for ignored in ("trust_remote_code", "state_dict", "ignore_mismatched_sizes", "adapter_kwargs", "weights_only"):
            kwargs.pop(ignored, None)
        if kwargs:
            raise TypeError(f"Unsupported keyword argument(s): {', '.join(sorted(kwargs))}")

        if config is None:
            config = VJEPA2FMRIEncoderConfig.from_pretrained(
                pretrained_model_name_or_path,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )

        checkpoint_path = cls._resolve_file_path(
            pretrained_model_name_or_path,
            filename=config.checkpoint_filename,
            revision=revision,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # the file may have been downloaded; only use trusted repositories.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        model = cls(config)
        model.decoders = nn.ModuleList([RidgeDecoder(state_dict) for state_dict in checkpoint["decoders_state_dict"]])
        # Replaces the empty placeholder buffer registered in __init__.
        model.register_buffer("decoding_units", checkpoint["decoding_units"].long())
        # Only tensor-valued extra attributes are attached; others are skipped.
        for name, value in checkpoint.get("registered_attrs", {}).items():
            if torch.is_tensor(value):
                model.register_buffer(name, value)

        load_vjepa = config.load_vjepa if load_vjepa is None else bool(load_vjepa)
        vjepa_size = config.vjepa_size if vjepa_size is None else vjepa_size
        normalize_input = config.normalize_input if normalize_input is None else bool(normalize_input)
        if load_vjepa:
            backbone_path = cls._resolve_file_path(
                pretrained_model_name_or_path,
                filename=config.backbone_filename,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )
            extractor_config = checkpoint["extractor_config"]
            model.extractor = HookedFeatureExtractor(
                layer_names=cls._resolve_layer_names(extractor_config),
                ret_type=extractor_config.get("ret_type", "chw"),
                spatial_pool=extractor_config.get("spatial_pool", 14),
            )
            model.vjepa = LocalVJEPA2Backbone(
                size=vjepa_size,
                image_size=config.image_size,
                normalize_input=normalize_input,
                checkpoint_path=backbone_path,
            )
        model.eval()
        return model

    @staticmethod
    def _resolve_file_path(
        pretrained_model_name_or_path: str | os.PathLike[str],
        *,
        filename: str,
        revision: str | None,
        token: str | bool | None,
        cache_dir: str | os.PathLike[str] | None,
        local_files_only: bool,
    ) -> str:
        """Return a local path for ``filename``.

        An existing directory yields ``<dir>/<filename>``. An existing file is
        returned as-is, whatever ``filename`` says. Anything else is treated
        as a Hub model repo id and fetched with ``hf_hub_download``.

        Raises:
            FileNotFoundError: if the directory exists but lacks ``filename``.
        """
        path = Path(pretrained_model_name_or_path)
        if path.exists():
            file_path = path / filename if path.is_dir() else path
            if not file_path.exists():
                raise FileNotFoundError(f"Missing file: {file_path}")
            return str(file_path)

        # Imported lazily so purely local loading does not need the Hub client.
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id=str(pretrained_model_name_or_path),
            filename=filename,
            repo_type="model",
            revision=revision,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )

    @staticmethod
    def _resolve_layer_names(extractor_config: dict[str, Any]) -> list[str]:
        """Return the hooked layer names from ``layer_names`` or, failing that, ``loi``.

        Raises:
            KeyError: if neither key holds a value.
        """
        layer_names = extractor_config.get("layer_names")
        if layer_names is None:
            layer_names = extractor_config.get("loi")
        if layer_names is None:
            raise KeyError("extractor_config must contain `layer_names` or `loi`.")
        # A single name stored as a plain string must not be split into chars.
        if isinstance(layer_names, str):
            return [layer_names]
        return list(layer_names)

    def forward_features(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Decode one feature tensor per decoder and select a decoder per unit.

        Decoder outputs are stacked on a new last axis; for each unit the entry
        at that unit's ``decoding_units`` index is gathered along that axis.
        NOTE(review): assumes ``decoding_units`` is 1-D with one entry per
        output unit -- confirm against the checkpoint writer.

        Raises:
            ValueError: if the number of features differs from the number of
                decoders.
        """
        if len(features) != len(self.decoders):
            raise ValueError(f"Expected {len(self.decoders)} feature tensors, got {len(features)}")
        outputs = [decoder(feature) for decoder, feature in zip(self.decoders, features)]
        output = torch.stack(outputs, dim=-1)
        # Shape the index as [1, units, 1] and repeat it over the batch.
        index = self.decoding_units.to(output.device).unsqueeze(0).unsqueeze(-1)
        index = index.expand(output.shape[0], -1, -1)
        return output.gather(dim=2, index=index).squeeze(-1)

    def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
        """Run backbone, hooks and decoders on ``videos``.

        ``normalize`` is forwarded to the backbone's ``forward``.

        Raises:
            RuntimeError: if the backbone or extractor is not set.
        """
        if self.vjepa is None or self.extractor is None:
            raise RuntimeError("This model was loaded with load_vjepa=False.")
        features = self.extractor(self.vjepa, videos, normalize=normalize)
        return self.forward_features(features)

    def predict_fmri(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor:
        """Predict z-scored fMRI beta responses for videos shaped [B, T, C, H, W]."""
        return self(videos, normalize=normalize)
|
|