|
|
"""Contains factory functions to build and load ViT. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import logging |
|
|
|
|
|
import timm |
|
|
import torch |
|
|
|
|
|
from sharp.models.presets.vit import VIT_CONFIG_DICT, ViTConfig, ViTPreset |
|
|
|
|
|
LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class TimmViT(timm.models.VisionTransformer):
    """Vanilla ViT built on the TIMM ``VisionTransformer`` implementation.

    On top of the TIMM model, the forward pass returns the final features as a
    2D grid and additionally collects the outputs of selected blocks.
    """

    def __init__(self, config: ViTConfig):
        """Initialize ViT from TIMM implementation.

        Args:
            config: Model hyper-parameters; its fields are forwarded to the
                TIMM constructor.
        """
        # "glu" selects the gated MLP; every other mode uses the plain MLP.
        mlp_layer = timm.layers.GluMlp if config.mlp_mode == "glu" else timm.layers.Mlp

        super().__init__(
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            init_values=config.init_values,
            img_size=config.img_size,
            patch_size=config.patch_size,
            num_classes=config.num_classes,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            global_pool=config.global_pool,
            mlp_layer=mlp_layer,
        )

        self.dim_in = config.in_chans
        # Indices of the blocks whose outputs ``forward`` collects (None: collect nothing).
        self.intermediate_features_ids = config.intermediate_features_ids

    def reshape_feature(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Discard prefix (class) tokens and reshape 1D feature map to a 2D grid.

        Args:
            embeddings: Token embeddings with three dimensions
                (batch, tokens, channels).

        Returns:
            Tensor of shape (batch, channels, height, width), where height and
            width are taken from ``self.patch_embed.grid_size``.
        """
        # The token count is not needed; the grid size comes from the patch embedding.
        batch_size, _, channel = embeddings.shape
        height, width = self.patch_embed.grid_size

        # Prefix tokens have no spatial location, so drop them before the reshape.
        if self.num_prefix_tokens:
            embeddings = embeddings[:, self.num_prefix_tokens :, :]

        grid = embeddings.reshape(batch_size, height, width, channel)
        # Move channels in front of the spatial dimensions.
        return grid.permute(0, 3, 1, 2)

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[int, torch.Tensor]]:
        """Override forwarding with intermediate features.

        Adapted from timm ViT.

        Args:
            input_tensor: Input passed to the patch embedding.

        Returns:
            Output features reshaped to a 2D grid, and a dict mapping block
            index to that block's output for every index listed in
            ``intermediate_features_ids``. The intermediate outputs are stored
            exactly as produced by the block: they are neither passed through
            the final norm nor reshaped.
        """
        intermediate_features: dict[int, torch.Tensor] = {}
        feature_ids = self.intermediate_features_ids

        x = self.patch_embed(input_tensor)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if feature_ids is not None and idx in feature_ids:
                intermediate_features[idx] = x
        x = self.norm(x)

        x = self.reshape_feature(x)
        return x, intermediate_features

    def internal_resolution(self) -> int:
        """Return the internal image size of the network.

        When the patch embedding stores the image size as a tuple, its first
        entry is returned.
        """
        img_size = self.patch_embed.img_size
        return img_size[0] if isinstance(img_size, tuple) else img_size
|
|
|
|
|
|
|
|
def create_vit(
    config: ViTConfig | None = None,
    preset: ViTPreset | None = "dinov2l16_384",
    intermediate_features_ids: list[int] | None = None,
) -> TimmViT:
    """Factory function for creating a ViT model.

    Args:
        config: User-defined configuration. Takes precedence over ``preset``.
        preset: Key into ``VIT_CONFIG_DICT``; only used when ``config`` is None.
        intermediate_features_ids: Indices of the blocks whose outputs the
            model collects. Always replaces the value stored in the selected
            config, including when it is None.

    Returns:
        The constructed model.

    Raises:
        ValueError: If both ``config`` and ``preset`` are None.
    """
    # Function-scope import: keeps this block self-contained.
    import copy

    if config is not None:
        LOGGER.info("Using user-defined config.")
    elif preset is None:
        raise ValueError("User-defined config and preset cannot be both None.")
    else:
        LOGGER.info("Using preset ViT %s.", preset)
        config = VIT_CONFIG_DICT[preset]

    # Assign on a shallow copy: writing to the object itself would modify the
    # shared preset entry (or the caller's config) and leak into later calls.
    config = copy.copy(config)
    config.intermediate_features_ids = intermediate_features_ids

    model = TimmViT(config)
    LOGGER.debug(model)
    return model
|
|
|