|
|
|
|
|
|
|
|
|
|
| import logging
|
| from enum import Enum
|
| from typing import Sequence, Union
|
|
|
| from dinov3.eval.dense.depth.models.embed import CenterPadding, StretchToMultiple
|
| from torch import Tensor, nn
|
|
|
# Module-level logger for this file; registered under the fixed name "fairvit" (not __name__).
logger = logging.getLogger("fairvit")
|
|
|
|
|
class BackboneLayersSet(Enum):
    """Named presets for choosing which ViT blocks provide output features.

    Every member's value is identical to its name, so a configuration string such
    as ``"FOUR_LAST"`` converts directly with ``BackboneLayersSet("FOUR_LAST")``.
    """

    # Only the final block.
    LAST = "LAST"
    # The final four blocks.
    FOUR_LAST = "FOUR_LAST"
    # Four blocks taken at every quarter of the backbone depth.
    FOUR_EVEN_INTERVALS = "FOUR_EVEN_INTERVALS"
|
|
|
|
|
def _get_backbone_out_indices(
    model: nn.Module,
    backbone_out_layers: Union[
        list[int], tuple[int, ...], str, BackboneLayersSet
    ] = BackboneLayersSet.FOUR_EVEN_INTERVALS,
) -> list[int]:
    """
    Get indices for output layers of the ViT backbone.

    `backbone_out_layers` is either an explicit list/tuple of block indices, or one of the
    BackboneLayersSet presets (the enum member itself or its string value):
        BackboneLayersSet.LAST : only extract the last layer, used in segmentation tasks with a bn head.
        BackboneLayersSet.FOUR_LAST : extract the last 4 layers, used in segmentation (multiscale setting)
        BackboneLayersSet.FOUR_EVEN_INTERVALS : extract outputs every 1/4 of the total number of blocks

    Reference outputs in 'FOUR_EVEN_INTERVALS' mode :
        ViT/S (12 blocks): [2, 5, 8, 11]
        ViT/B (12 blocks): [2, 5, 8, 11]
        ViT/L (24 blocks): [4, 11, 17, 23] -- special-cased; evenly spaced would be [5, 11, 17, 23]
        ViT/g (40 blocks): [9, 19, 29, 39]

    The block count is read from `model.n_blocks`, defaulting to 1 when the attribute is missing.

    Args:
        model: backbone whose `n_blocks` attribute determines the preset indices.
        backbone_out_layers: explicit indices or a preset (see above).

    Returns:
        A new list of block indices. Explicit indices are copied, not aliased.

    Raises:
        ValueError: if `backbone_out_layers` is neither a list/tuple nor a valid preset, or if any
            index is greater than or equal to the number of blocks.
    """
    n_blocks = getattr(model, "n_blocks", 1)

    if isinstance(backbone_out_layers, (list, tuple)):
        out_indices = list(backbone_out_layers)
    else:
        # Accepts an enum member or its string value; the Enum lookup raises ValueError otherwise.
        preset = BackboneLayersSet(backbone_out_layers)
        if preset is BackboneLayersSet.LAST:
            out_indices = [n_blocks - 1]
        elif preset is BackboneLayersSet.FOUR_LAST:
            out_indices = list(range(n_blocks - 4, n_blocks))
        elif preset is BackboneLayersSet.FOUR_EVEN_INTERVALS:
            if n_blocks == 24:
                # XXX: the 24-block (ViT/L) case is deliberately pinned to [4, 11, 17, 23] rather than the
                # evenly spaced [5, 11, 17, 23]. Presumably downstream heads depend on these exact
                # indices -- confirm before changing.
                out_indices = [4, 11, 17, 23]
            else:
                step = n_blocks // 4
                out_indices = [i * step - 1 for i in range(1, 5)]
        else:
            raise ValueError(f"Unsupported backbone layer set: {preset}")

    # Explicit check rather than `assert`, so the validation is not stripped under `python -O`.
    out_of_range = [idx for idx in out_indices if idx >= n_blocks]
    if out_of_range:
        raise ValueError(
            f"Output layer indices {out_of_range} are out of range for a backbone with {n_blocks} blocks"
        )
    return out_indices
|
|
|
|
|
class PatchSizeAdaptationStrategy(Enum):
    """How inputs are adapted to the backbone's input pad size before the forward pass.

    The values are the strings accepted when constructing a member from configuration,
    e.g. ``PatchSizeAdaptationStrategy("stretch")``.
    """

    # Pad the input with a CenterPadding module.
    CENTER_PADDING = "center_padding"
    # Resize the input with a StretchToMultiple module.
    STRETCH = "stretch"
    # Leave the input unchanged (nn.Identity).
    NO_ADAPTATION = "never"
|
|
|
|
|
class DinoVisionTransformerWrapper(nn.Module):
    """Frozen ViT backbone that returns features from a selected set of intermediate blocks.

    At construction time the backbone parameters are frozen with ``requires_grad_(False)``.
    In ``forward`` the input first goes through a patch-size adapter (center padding,
    stretching, or identity). It is then passed to ``backbone.get_intermediate_layers`` with
    ``reshape=True`` and ``return_class_token=True``.

    Args:
        backbone_model: the ViT backbone. It must provide ``get_intermediate_layers``, and
            either ``embed_dims`` or both ``embed_dim`` and ``n_blocks``. It must also provide
            either ``input_pad_size`` or ``patch_size``.
        backbone_out_layers: explicit block indices (list or tuple), or a BackboneLayersSet
            member or its string value (e.g. ``"FOUR_LAST"``).
        use_backbone_norm: forwarded as ``norm=`` to ``get_intermediate_layers``.
        adapt_to_patch_size: a PatchSizeAdaptationStrategy member or its string value
            (``"center_padding"``, ``"stretch"``, ``"never"``).

    Raises:
        ValueError: if ``backbone_out_layers`` or ``adapt_to_patch_size`` is not a recognised value.
    """

    def __init__(
        self,
        backbone_model: nn.Module,
        backbone_out_layers: Union[str, list[int], tuple[int, ...], BackboneLayersSet],
        use_backbone_norm: bool = False,
        adapt_to_patch_size: Union[
            PatchSizeAdaptationStrategy, str
        ] = PatchSizeAdaptationStrategy.CENTER_PADDING,
    ):
        super().__init__()

        self.final_norm = use_backbone_norm
        self.backbone = backbone_model

        # Tuples are normalised to lists so they are treated as explicit indices. Anything else
        # must name a BackboneLayersSet preset; the Enum lookup raises ValueError otherwise.
        if isinstance(backbone_out_layers, (list, tuple)):
            out_layers = list(backbone_out_layers)
        else:
            out_layers = BackboneLayersSet(backbone_out_layers)
        self.backbone_out_indices = _get_backbone_out_indices(self.backbone, backbone_out_layers=out_layers)

        self.embed_dims: Sequence[int] = self._resolve_embed_dims()
        self.patch_size_adapter = self._build_patch_size_adapter(adapt_to_patch_size)

        # Freeze every backbone parameter.
        self.backbone.requires_grad_(False)

    def _resolve_embed_dims(self) -> list[int]:
        """Return the embedding dimension of each selected output block."""
        try:
            embed_dims = self.backbone.embed_dims
        except AttributeError:
            # Fallback: assume a uniform width across all blocks.
            embed_dim = self.backbone.embed_dim
            n_blocks = self.backbone.n_blocks
            logger.warning(f"Backbone does not define embed_dims, using {[embed_dim] * n_blocks=} instead")
            embed_dims = [embed_dim] * n_blocks
        return [embed_dims[idx] for idx in self.backbone_out_indices]

    def _build_patch_size_adapter(
        self,
        adapt_to_patch_size: Union[PatchSizeAdaptationStrategy, str],
    ) -> nn.Module:
        """Create the module applied to inputs before the backbone, according to the strategy."""
        # The pad size is resolved before the strategy is examined. As a result, the fallback
        # warning (or an AttributeError) occurs for every strategy, including NO_ADAPTATION.
        try:
            input_pad_size = self.backbone.input_pad_size
        except AttributeError:
            patch_size = self.backbone.patch_size
            logger.warning(f"Backbone does not define input_pad_size, using {patch_size=} instead")
            input_pad_size = patch_size

        # Accept either an enum member or its string value (e.g. from a config file).
        try:
            strategy = PatchSizeAdaptationStrategy(adapt_to_patch_size)
        except ValueError as err:
            raise ValueError(f"Unknown value {adapt_to_patch_size=}") from err

        if strategy is PatchSizeAdaptationStrategy.CENTER_PADDING:
            return CenterPadding(input_pad_size)
        if strategy is PatchSizeAdaptationStrategy.STRETCH:
            return StretchToMultiple(input_pad_size)
        if strategy is PatchSizeAdaptationStrategy.NO_ADAPTATION:
            return nn.Identity()
        raise ValueError(f"Unknown value {adapt_to_patch_size=}")

    def forward(
        self,
        x: Tensor,
    ) -> list[tuple[Tensor, Tensor]]:
        """Adapt ``x`` with the patch-size adapter, then return the selected intermediate features.

        Returns exactly what ``backbone.get_intermediate_layers`` returns for the configured block
        indices, called with ``reshape=True``, ``return_class_token=True`` and
        ``norm=self.final_norm``. This is presumably one (patch features, class token) pair per
        selected block -- confirm against the backbone implementation.
        """
        x = self.patch_size_adapter(x)
        return self.backbone.get_intermediate_layers(
            x,
            n=self.backbone_out_indices,
            reshape=True,
            return_class_token=True,
            norm=self.final_norm,
        )
|
|
|