amael-apple's picture
Initial commit
c20d7cc
"""Contains factory functions to build and load ViT.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations

import copy
import logging

import timm
import torch

from sharp.models.presets.vit import VIT_CONFIG_DICT, ViTConfig, ViTPreset
# Module-scoped logger, named after this module's import path.
LOGGER: logging.Logger = logging.getLogger(name=__name__)
class TimmViT(timm.models.VisionTransformer):
    """Vanilla ViT built on top of the TIMM ``VisionTransformer`` implementation.

    Compared to the parent class, ``forward`` also collects the outputs of the
    blocks listed in ``intermediate_features_ids``. It returns the final tokens
    reshaped to a 2D feature grid.
    """

    def __init__(self, config: ViTConfig):
        """Initialize ViT from TIMM implementation.

        Args:
            config: Architecture description forwarded to the TIMM constructor.
                ``config.mlp_mode == "glu"`` selects ``timm.layers.GluMlp``.
                Any other value selects ``timm.layers.Mlp``.
        """
        # Handle mlp layers.
        mlp_layer = timm.layers.GluMlp if config.mlp_mode == "glu" else timm.layers.Mlp
        super().__init__(
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            init_values=config.init_values,
            img_size=config.img_size,
            patch_size=config.patch_size,
            num_classes=config.num_classes,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            global_pool=config.global_pool,
            mlp_layer=mlp_layer,
        )
        # Required for extracting intermediate features.
        self.dim_in = config.in_chans
        # Indices into ``self.blocks`` whose outputs ``forward`` returns.
        # None disables collection.
        self.intermediate_features_ids = config.intermediate_features_ids

    def reshape_feature(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Discard prefix tokens and reshape the 1D token sequence to a 2D grid.

        Args:
            embeddings: Tensor of shape
                (batch, num_prefix_tokens + height * width, dim), where
                (height, width) is ``self.patch_embed.grid_size``.

        Returns:
            Feature map of shape (batch, dim, height, width).
        """
        batch_size, _, channel = embeddings.shape
        height, width = self.patch_embed.grid_size
        # Drop the ``num_prefix_tokens`` leading tokens (e.g. the class token).
        if self.num_prefix_tokens:
            embeddings = embeddings[:, self.num_prefix_tokens :, :]
        # (batch, height * width, dim) -> (batch, height, width, dim)
        #   -> (batch, dim, height, width)
        return embeddings.reshape(batch_size, height, width, channel).permute(0, 3, 1, 2)

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[int, torch.Tensor]]:
        """Run the ViT and collect intermediate block outputs.

        Adapted from timm ViT.

        Args:
            input_tensor: Image batch fed to ``self.patch_embed``.

        Returns:
            A tuple ``(features, intermediate_features)``:
            - ``features`` is the output after ``self.norm``, reshaped by
              ``reshape_feature`` to (batch, dim, height, width).
            - ``intermediate_features`` is a dict. It maps each block index
              listed in ``self.intermediate_features_ids`` to that block's raw
              output token sequence, with prefix tokens kept and no final norm.
        """
        wanted_ids = self.intermediate_features_ids
        intermediate_features: dict[int, torch.Tensor] = {}
        x = self.patch_embed(input_tensor)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        # NOTE(review): blocks run one at a time and this loop does not read
        # ``self.grad_checkpointing``. Confirm checkpointing is not needed here.
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if wanted_ids is not None and idx in wanted_ids:
                intermediate_features[idx] = x
        x = self.norm(x)
        return self.reshape_feature(x), intermediate_features

    def internal_resolution(self) -> int:
        """Return the internal image size of the network.

        When ``patch_embed.img_size`` is a tuple, only its first entry is returned.
        """
        img_size = self.patch_embed.img_size
        return img_size[0] if isinstance(img_size, tuple) else img_size
def create_vit(
    config: ViTConfig | None = None,
    preset: ViTPreset | None = "dinov2l16_384",
    intermediate_features_ids: list[int] | None = None,
) -> TimmViT:
    """Factory function for creating a ViT model.

    Args:
        config: User-defined architecture config. When given, it takes
            precedence over ``preset``.
        preset: Key into ``VIT_CONFIG_DICT``. Used only when ``config`` is None.
        intermediate_features_ids: Block indices whose outputs the model returns
            from ``forward``. This value always replaces the
            ``intermediate_features_ids`` of the selected config.

    Returns:
        A newly constructed ``TimmViT``.

    Raises:
        ValueError: If both ``config`` and ``preset`` are None.
        KeyError: If ``preset`` is not a key of ``VIT_CONFIG_DICT``.
    """
    if config is not None:
        LOGGER.info("Using user-defined config.")
    else:
        if preset is None:
            raise ValueError("User-defined config and preset cannot be both None.")
        LOGGER.info("Using preset ViT %s.", preset)
        config = VIT_CONFIG_DICT[preset]
    # Override the field on a shallow copy. Building a model must not mutate the
    # shared preset entry in VIT_CONFIG_DICT or the caller's config object.
    config = copy.copy(config)
    config.intermediate_features_ids = intermediate_features_ids
    model = TimmViT(config)
    LOGGER.debug(model)
    return model