# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
from enum import Enum
from typing import Sequence, Union

from dinov3.eval.dense.depth.models.embed import CenterPadding, StretchToMultiple
from torch import Tensor, nn

logger = logging.getLogger("fairvit")


class BackboneLayersSet(Enum):
    # Set of intermediate layers to take from the backbone
    LAST = "LAST"  # extracting only the last layer
    FOUR_LAST = "FOUR_LAST"  # extracting the last 4 layers
    FOUR_EVEN_INTERVALS = "FOUR_EVEN_INTERVALS"  # extracting outputs every 1/4 of the total number of blocks


def _get_backbone_out_indices(
    model: nn.Module,
    backbone_out_layers: Union[list[int], BackboneLayersSet] = BackboneLayersSet.FOUR_EVEN_INTERVALS,
) -> list[int]:
    """
    Get indices for output layers of the ViT backbone.

    For now there are 3 options available:
    BackboneLayersSet.LAST : only extract the last layer, used in segmentation tasks with a bn head.
    BackboneLayersSet.FOUR_LAST : extract the last 4 layers, used in segmentation (multiscale setting)
    BackboneLayersSet.FOUR_EVEN_INTERVALS : extract outputs every 1/4 of the total number of blocks

    Reference outputs in 'FOUR_EVEN_INTERVALS' mode :
    ViT/S (12 blocks): [2, 5, 8, 11]
    ViT/B (12 blocks): [2, 5, 8, 11]
    ViT/L (24 blocks): [5, 11, 17, 23] (correct), [4, 11, 17, 23] (incorrect)
    ViT/g (40 blocks): [9, 19, 29, 39]

    Args:
        model: backbone whose `n_blocks` attribute gives the number of blocks
            (1 is assumed when the attribute is missing).
        backbone_out_layers: either an explicit list of block indices, which is
            returned as-is (same object) after the range check, or a
            `BackboneLayersSet` member selecting one of the presets above.

    Returns:
        The list of block indices to extract.

    Raises:
        ValueError: if `backbone_out_layers` is neither a list nor a
            `BackboneLayersSet` member.
        AssertionError: if any index is not strictly below the number of blocks.
    """
    n_blocks = getattr(model, "n_blocks", 1)

    if isinstance(backbone_out_layers, list):
        out_indices = backbone_out_layers
    elif backbone_out_layers == BackboneLayersSet.LAST:
        out_indices = [n_blocks - 1]
    elif backbone_out_layers == BackboneLayersSet.FOUR_LAST:
        out_indices = list(range(n_blocks - 4, n_blocks))
    elif backbone_out_layers == BackboneLayersSet.FOUR_EVEN_INTERVALS:
        # XXX: Force (incorrect) out indices for backward-compatibility (ViT/L only)
        if n_blocks == 24:
            out_indices = [4, 11, 17, 23]
        else:
            out_indices = [i * (n_blocks // 4) - 1 for i in range(1, 5)]
    else:
        # Without this branch an unsupported value left `out_indices` unbound
        # and surfaced as an opaque UnboundLocalError further down.
        raise ValueError(
            f"Unsupported backbone_out_layers {backbone_out_layers!r}: "
            f"expected a list of block indices or a BackboneLayersSet member"
        )

    assert all(out_index < n_blocks for out_index in out_indices), (
        f"Backbone output indices {out_indices} must all be smaller than the number of blocks ({n_blocks})"
    )
    return out_indices


class PatchSizeAdaptationStrategy(Enum):
    # How input images are adapted to the size multiple expected by the backbone
    CENTER_PADDING = "center_padding"  # handled with CenterPadding
    STRETCH = "stretch"  # handled with StretchToMultiple
    NO_ADAPTATION = "never"  # input is passed through unchanged (nn.Identity)


class DinoVisionTransformerWrapper(nn.Module):
    """Wrapper exposing intermediate layer outputs of a frozen ViT backbone.

    The wrapper selects which backbone blocks to read, records the embedding
    dimension of each selected block, adapts input images to the size multiple
    required by the backbone, and disables gradients on the backbone.

    Attributes set in `__init__`:
        final_norm: value forwarded as `norm=` to `get_intermediate_layers`.
        backbone: the wrapped backbone module (gradients disabled).
        backbone_out_indices: indices of the backbone blocks that are extracted.
        embed_dims: embedding dimension of each extracted block.
        patch_size_adapter: module applied to the input before the backbone.
    """

    def __init__(
        self,
        backbone_model: nn.Module,
        backbone_out_layers: Union[str, list[int], BackboneLayersSet],
        use_backbone_norm: bool = False,
        adapt_to_patch_size: PatchSizeAdaptationStrategy = PatchSizeAdaptationStrategy.CENTER_PADDING,
    ):
        """
        Args:
            backbone_model: backbone providing `get_intermediate_layers`, and
                either `embed_dims` or (`embed_dim`, `n_blocks`), and either
                `input_pad_size` or `patch_size`.
            backbone_out_layers: explicit list of block indices, or the value
                (or member) of a `BackboneLayersSet` preset.
            use_backbone_norm: forwarded as `norm=` when extracting layers.
            adapt_to_patch_size: strategy used to adapt the input image size.

        Raises:
            ValueError: if `backbone_out_layers` is not a list and not a valid
                `BackboneLayersSet` value, or if `adapt_to_patch_size` is not a
                known strategy.
        """
        super().__init__()
        self.final_norm = use_backbone_norm
        self.backbone = backbone_model
        self.backbone_out_indices = _get_backbone_out_indices(
            self.backbone,
            backbone_out_layers=(
                backbone_out_layers
                if isinstance(backbone_out_layers, list)
                else BackboneLayersSet(backbone_out_layers)
            ),
        )

        # If the backbone does not define embed_dims, use [embed_dim] * n_blocks
        try:
            embed_dims = self.backbone.embed_dims
        except AttributeError:
            embed_dim = self.backbone.embed_dim
            n_blocks = self.backbone.n_blocks
            logger.warning(f"Backbone does not define embed_dims, using {[embed_dim] * n_blocks=} instead")
            embed_dims = [embed_dim] * n_blocks
        self.embed_dims: Sequence[int] = [embed_dims[idx] for idx in self.backbone_out_indices]

        # How to adapt input images to the patch size of the model?
        try:
            input_pad_size = self.backbone.input_pad_size
        except AttributeError:
            patch_size = self.backbone.patch_size
            logger.warning(f"Backbone does not define input_pad_size, using {patch_size=} instead")
            input_pad_size = patch_size

        if adapt_to_patch_size is PatchSizeAdaptationStrategy.CENTER_PADDING:
            self.patch_size_adapter = CenterPadding(input_pad_size)
        elif adapt_to_patch_size is PatchSizeAdaptationStrategy.STRETCH:
            self.patch_size_adapter = StretchToMultiple(input_pad_size)
        elif adapt_to_patch_size is PatchSizeAdaptationStrategy.NO_ADAPTATION:
            self.patch_size_adapter = nn.Identity()
        else:
            raise ValueError(f"Unknown value {adapt_to_patch_size=}")

        # Freeze backbone
        self.backbone.requires_grad_(False)

    def forward(
        self,
        x: Tensor,  # [B, rgb, H, W]
    ) -> list[tuple[Tensor, Tensor]]:
        """Adapt the input size, then return the selected intermediate backbone outputs."""
        x = self.patch_size_adapter(x)
        outputs = self.backbone.get_intermediate_layers(
            x,
            n=self.backbone_out_indices,
            reshape=True,
            return_class_token=True,
            norm=self.final_norm,
        )
        # List of (patch feats [B, C, h, w], class token [B, C])
        return outputs