amael-apple's picture
Initial commit
c20d7cc
raw
history blame
3.7 kB
"""Contains factory functions to build and load ViT.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations

import copy
import logging

import timm
import torch

from sharp.models.presets.vit import VIT_CONFIG_DICT, ViTConfig, ViTPreset

# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(name=__name__)


class TimmViT(timm.models.VisionTransformer):
    """Contains TIMM implementation for Vanilla ViT.

    Subclasses ``timm.models.VisionTransformer`` so that ``forward`` additionally
    returns the outputs of selected transformer blocks, and returns the final
    token sequence as a 2D feature map rather than a flat token sequence.
    """

    def __init__(self, config: ViTConfig):
        """Initialize ViT from TIMM implementation.

        Args:
            config: Architecture hyper-parameters forwarded to the timm
                constructor. ``config.mlp_mode == "glu"`` selects
                ``timm.layers.GluMlp`` for the block MLPs; any other value
                selects ``timm.layers.Mlp``.
        """
        # Handle mlp layers.
        mlp_layer = timm.layers.GluMlp if config.mlp_mode == "glu" else timm.layers.Mlp
        super().__init__(
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            init_values=config.init_values,
            img_size=config.img_size,
            patch_size=config.patch_size,
            num_classes=config.num_classes,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            global_pool=config.global_pool,
            mlp_layer=mlp_layer,
        )
        # Required for extracting intermediate features.
        self.dim_in = config.in_chans
        # Indices of blocks whose outputs ``forward`` collects; None disables collection.
        self.intermediate_features_ids = config.intermediate_features_ids

    def reshape_feature(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Discard prefix tokens and reshape the 1D token sequence to a 2D grid.

        Args:
            embeddings: Token tensor of shape
                (batch, num_prefix_tokens + height * width, dim).

        Returns:
            Feature map of shape (batch, dim, height, width), where
            ``(height, width)`` is ``self.patch_embed.grid_size``.
        """
        batch_size, _, channel = embeddings.shape
        height, width = self.patch_embed.grid_size
        # Remove class token (and any other prefix tokens).
        if self.num_prefix_tokens:
            embeddings = embeddings[:, self.num_prefix_tokens :, :]
        # Shape: (batch, height * width, dim) -> (batch, height, width, dim)
        #        -> (batch, dim, height, width)
        return embeddings.reshape(batch_size, height, width, channel).permute(0, 3, 1, 2)

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[int, torch.Tensor]]:
        """Override forwarding with intermediate features.

        Adapted from timm ViT: patch embedding, positional embedding, patch
        dropout, pre-norm, the transformer blocks, then the final norm.

        Args:
            input_tensor: Image batch passed to ``self.patch_embed``.

        Returns:
            A tuple ``(features, intermediate_features)``.

            - ``features``: output of the last block after ``self.norm``,
              reshaped by ``reshape_feature`` to (batch, dim, height, width).
            - ``intermediate_features``: a dict mapping each block index listed
              in ``self.intermediate_features_ids`` to that block's raw output
              token sequence. These outputs still include prefix tokens and are
              not normalized. The dict is empty when the ids are None.
        """
        intermediate_features: dict[int, torch.Tensor] = {}
        feature_ids = self.intermediate_features_ids
        x = self.patch_embed(input_tensor)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if feature_ids is not None and idx in feature_ids:
                intermediate_features[idx] = x
        x = self.norm(x)
        x = self.reshape_feature(x)
        return x, intermediate_features

    def internal_resolution(self) -> int:
        """Return the internal image size of the network.

        If ``self.patch_embed.img_size`` is stored as a tuple, its first entry
        is returned; otherwise the value is returned as-is.
        """
        img_size = self.patch_embed.img_size
        if isinstance(img_size, tuple):
            return img_size[0]
        return img_size


def create_vit(
    config: ViTConfig | None = None,
    preset: ViTPreset | None = "dinov2l16_384",
    intermediate_features_ids: list[int] | None = None,
) -> TimmViT:
    """Factory function for creating a ViT model.

    Args:
        config: User-defined configuration. When given, it takes precedence
            over ``preset``.
        preset: Key into ``VIT_CONFIG_DICT``. Used only when ``config`` is None.
        intermediate_features_ids: Block indices whose outputs
            ``TimmViT.forward`` should return. This value always replaces the
            ids stored on the selected config.

    Returns:
        A newly constructed ``TimmViT``.

    Raises:
        ValueError: If both ``config`` and ``preset`` are None.
        KeyError: If ``preset`` is not a key of ``VIT_CONFIG_DICT``.
    """
    if config is not None:
        LOGGER.info("Using user-defined config.")
    else:
        if preset is None:
            raise ValueError("User-defined config and preset cannot be both None.")
        LOGGER.info("Using preset ViT %s.", preset)
        config = VIT_CONFIG_DICT[preset]
    # Shallow-copy before overriding the ids. Otherwise this would mutate the
    # caller's config object, or the shared preset stored in VIT_CONFIG_DICT,
    # and leak state into later calls.
    config = copy.copy(config)
    config.intermediate_features_ids = intermediate_features_ids
    model = TimmViT(config)
    LOGGER.debug(model)
    return model