"""Custom AutoModel implementation for a basic V-JEPA2 fMRI encoder.""" from __future__ import annotations import os from pathlib import Path from typing import Any, Iterable import torch import torch.nn.functional as F from torch import nn from transformers import PreTrainedModel try: from .configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig except ImportError: from configuration_vjepa2_fmri_encoder import VJEPA2FMRIEncoderConfig class RidgeDecoder(nn.Module): def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: super().__init__() self.register_buffer("mean", state_dict["steps.1.mean"]) self.register_buffer("std", state_dict["steps.1.std"]) self.register_buffer("coef", state_dict["steps.2.regressor._coef"]) self.register_buffer("intercept", state_dict["steps.2.regressor._intercept"]) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.reshape(x.shape[0], -1) x = (x - self.mean.to(device=x.device)) / self.std.to(device=x.device) coef = self.coef.to(device=x.device) x = x.to(dtype=coef.dtype) return x @ coef.T + self.intercept.to(device=x.device) class HookedFeatureExtractor: def __init__(self, layer_names: Iterable[str], ret_type: str = "chw", spatial_pool: int = 14) -> None: self.layer_names = list(layer_names) self.ret_type = ret_type self.spatial_pool = int(spatial_pool) @staticmethod def _get_layer(model: nn.Module, layer_name: str) -> nn.Module: layer: object = model for part in layer_name.split("."): layer = layer[int(part)] if part.isdigit() else getattr(layer, part) if not isinstance(layer, nn.Module): raise TypeError(f"{layer_name} did not resolve to a torch module") return layer @staticmethod def _unwrap_output(output: Any) -> torch.Tensor: if isinstance(output, (list, tuple)): if len(output) == 0: raise ValueError("Received an empty feature tuple.") output = output[0] if not torch.is_tensor(output): raise TypeError(f"Expected tensor feature output, got {type(output)!r}") return output def __call__(self, model: nn.Module, videos: torch.Tensor, **model_kwargs) -> list[torch.Tensor]: outputs: dict[str, torch.Tensor] = {} hooks = [ self._get_layer(model, name).register_forward_hook( lambda _module, _inputs, output, name=name: outputs.__setitem__(name, self._unwrap_output(output)) ) for name in self.layer_names ] try: model(videos, **model_kwargs) finally: for hook in hooks: hook.remove() return [self._process_feature(outputs[name]) for name in self.layer_names] def _process_feature(self, feature: torch.Tensor) -> torch.Tensor: batch, tokens, channels = feature.shape feature = feature.reshape(batch, -1, 14, 14, channels).permute(0, 1, 4, 2, 3) if self.spatial_pool > 1: batch, frames, channels, height, width = feature.shape new_height = height // self.spatial_pool new_width = width // self.spatial_pool feature = feature.reshape( batch, frames, channels, new_height, self.spatial_pool, new_width, self.spatial_pool, ) feature = feature.permute(0, 1, 2, 3, 5, 4, 6).mean(dim=(-2, -1)) if self.ret_type == "chw": return feature.mean(dim=1) if self.ret_type == "tchw": return feature raise ValueError(f"Unsupported ret_type: {self.ret_type}") class LocalVJEPA2Backbone(nn.Module): def __init__(self, size: str, image_size: int, normalize_input: bool, checkpoint_path: str) -> None: super().__init__() self.image_size = int(image_size) self.normalize_input = bool(normalize_input) self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1)) self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1)) hub_name = { "large": "vjepa2_vit_large", "huge": "vjepa2_vit_huge", "giant": "vjepa2_vit_giant", }[size] backbone = torch.hub.load("facebookresearch/vjepa2", hub_name, pretrained=False) backbone, predictor = backbone if isinstance(backbone, (list, tuple)) else (backbone, None) state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False) backbone.load_state_dict(_clean_backbone_key(state_dict["target_encoder"]), strict=False) if predictor is not None and "predictor" in state_dict: predictor.load_state_dict(_clean_backbone_key(state_dict["predictor"]), strict=False) self.backbone = backbone def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor: if videos.ndim != 5: raise ValueError(f"Expected video tensor shaped [B, T, C, H, W], got {tuple(videos.shape)}") if videos.shape[2] != 3: raise ValueError(f"Expected RGB video with 3 channels at dim 2, got {videos.shape[2]}") videos = videos.float() batch, frames, channels, height, width = videos.shape if height != self.image_size or width != self.image_size: videos = videos.reshape(batch * frames, channels, height, width) videos = F.interpolate( videos, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) videos = videos.reshape(batch, frames, channels, self.image_size, self.image_size) normalize = self.normalize_input if normalize is None else bool(normalize) if normalize: videos = (videos - self.image_mean.to(device=videos.device, dtype=videos.dtype)) / self.image_std.to( device=videos.device, dtype=videos.dtype, ) return self.backbone(videos.permute(0, 2, 1, 3, 4)) def _clean_backbone_key(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: cleaned = {} for key, value in state_dict.items(): key = key.replace("module.", "").replace("backbone.", "") cleaned[key] = value return cleaned class VJEPA2FMRIEncoderModel(PreTrainedModel): config_class = VJEPA2FMRIEncoderConfig base_model_prefix = "vjepa2_fmri_encoder" main_input_name = "videos" def __init__(self, config: VJEPA2FMRIEncoderConfig) -> None: super().__init__(config) self.decoders = nn.ModuleList() self.register_buffer("decoding_units", torch.empty(0, dtype=torch.long)) self.extractor: HookedFeatureExtractor | None = None self.vjepa: LocalVJEPA2Backbone | None = None @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | os.PathLike[str], *model_args: Any, config: VJEPA2FMRIEncoderConfig | None = None, load_vjepa: bool | None = None, vjepa_size: str | None = None, normalize_input: bool | None = None, **kwargs: Any, ) -> "VJEPA2FMRIEncoderModel": if model_args: raise TypeError("Unexpected positional arguments for VJEPA2FMRIEncoderModel.from_pretrained") revision = kwargs.pop("revision", None) token = kwargs.pop("token", None) cache_dir = kwargs.pop("cache_dir", None) local_files_only = kwargs.pop("local_files_only", False) for ignored in ("trust_remote_code", "state_dict", "ignore_mismatched_sizes", "adapter_kwargs", "weights_only"): kwargs.pop(ignored, None) if kwargs: raise TypeError(f"Unsupported keyword argument(s): {', '.join(sorted(kwargs))}") if config is None: config = VJEPA2FMRIEncoderConfig.from_pretrained( pretrained_model_name_or_path, revision=revision, token=token, cache_dir=cache_dir, local_files_only=local_files_only, ) checkpoint_path = cls._resolve_file_path( pretrained_model_name_or_path, filename=config.checkpoint_filename, revision=revision, token=token, cache_dir=cache_dir, local_files_only=local_files_only, ) checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) model = cls(config) model.decoders = nn.ModuleList([RidgeDecoder(state_dict) for state_dict in checkpoint["decoders_state_dict"]]) model.register_buffer("decoding_units", checkpoint["decoding_units"].long()) for name, value in checkpoint.get("registered_attrs", {}).items(): if torch.is_tensor(value): model.register_buffer(name, value) load_vjepa = config.load_vjepa if load_vjepa is None else bool(load_vjepa) vjepa_size = config.vjepa_size if vjepa_size is None else vjepa_size normalize_input = config.normalize_input if normalize_input is None else bool(normalize_input) if load_vjepa: backbone_path = cls._resolve_file_path( pretrained_model_name_or_path, filename=config.backbone_filename, revision=revision, token=token, cache_dir=cache_dir, local_files_only=local_files_only, ) extractor_config = checkpoint["extractor_config"] model.extractor = HookedFeatureExtractor( layer_names=cls._resolve_layer_names(extractor_config), ret_type=extractor_config.get("ret_type", "chw"), spatial_pool=extractor_config.get("spatial_pool", 14), ) model.vjepa = LocalVJEPA2Backbone( size=vjepa_size, image_size=config.image_size, normalize_input=normalize_input, checkpoint_path=backbone_path, ) model.eval() return model @staticmethod def _resolve_file_path( pretrained_model_name_or_path: str | os.PathLike[str], *, filename: str, revision: str | None, token: str | bool | None, cache_dir: str | os.PathLike[str] | None, local_files_only: bool, ) -> str: path = Path(pretrained_model_name_or_path) if path.exists(): file_path = path / filename if path.is_dir() else path if not file_path.exists(): raise FileNotFoundError(f"Missing file: {file_path}") return str(file_path) from huggingface_hub import hf_hub_download return hf_hub_download( repo_id=str(pretrained_model_name_or_path), filename=filename, repo_type="model", revision=revision, token=token, cache_dir=cache_dir, local_files_only=local_files_only, ) @staticmethod def _resolve_layer_names(extractor_config: dict[str, Any]) -> list[str]: layer_names = extractor_config.get("layer_names") if layer_names is None: layer_names = extractor_config.get("loi") if layer_names is None: raise KeyError("extractor_config must contain `layer_names` or `loi`.") return list(layer_names) def forward_features(self, features: list[torch.Tensor]) -> torch.Tensor: if len(features) != len(self.decoders): raise ValueError(f"Expected {len(self.decoders)} feature tensors, got {len(features)}") outputs = [decoder(feature) for decoder, feature in zip(self.decoders, features)] output = torch.stack(outputs, dim=-1) index = self.decoding_units.to(output.device).unsqueeze(0).unsqueeze(-1) index = index.expand(output.shape[0], -1, -1) return output.gather(dim=2, index=index).squeeze(-1) def forward(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor: if self.vjepa is None or self.extractor is None: raise RuntimeError("This model was loaded with load_vjepa=False.") features = self.extractor(self.vjepa, videos, normalize=normalize) return self.forward_features(features) def predict_fmri(self, videos: torch.Tensor, normalize: bool | None = None) -> torch.Tensor: """Predict z-scored fMRI beta responses for videos shaped [B, T, C, H, W].""" return self(videos, normalize=normalize)