# ------------------------------------------------------------------------ # RF-DETR # Copyright (c) 2025 Roboflow. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Modified from HuggingFace Dinov2 (https://github.com/huggingface/transformers) # Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. # ------------------------------------------------------------------------ import collections.abc import math from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, torch_int, ) from transformers.utils.backbone_utils import BackboneMixin from transformers.configuration_utils import PretrainedConfig from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices logger = logging.get_logger(__name__) # Base docstring _CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base" # General docstring _CONFIG_FOR_DOC = "WindowedDinov2WithRegistersConfig" class WindowedDinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DINOv2 with Registers [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. mlp_ratio (`int`, *optional*, defaults to 4): Ratio of the hidden size of the MLPs relative to the `hidden_size`. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported. hidden_dropout_prob (`float`, *optional*, defaults to 0.0): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the layer normalization layers. image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. num_channels (`int`, *optional*, defaults to 3): The number of input channels. qkv_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the queries, keys and values. layerscale_value (`float`, *optional*, defaults to 1.0): Initial value to use for layer scale. drop_path_rate (`float`, *optional*, defaults to 0.0): Stochastic depth rate per sample (when applied in the main path of residual layers). use_swiglu_ffn (`bool`, *optional*, defaults to `False`): Whether to use the SwiGLU feedforward neural network. num_register_tokens (`int`, *optional*, defaults to 4): Number of register tokens to use. out_features (`List[str]`, *optional*): If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. (depending on how many stages the model has). If unset and `out_indices` is set, will default to the corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the same order as defined in the `stage_names` attribute. out_indices (`List[int]`, *optional*): If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. If unset and `out_features` is unset, will default to the last stage. Must be in the same order as defined in the `stage_names` attribute. apply_layernorm (`bool`, *optional*, defaults to `True`): Whether to apply layer normalization to the feature maps in case the model is used as backbone. reshape_hidden_states (`bool`, *optional*, defaults to `True`): Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, seq_len, hidden_size)`. Example: ```python >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel >>> # Initializing a Dinov2WithRegisters base style configuration >>> configuration = Dinov2WithRegistersConfig() >>> # Initializing a model (with random weights) from the base style configuration >>> model = Dinov2WithRegistersModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "dinov2_with_registers" def __init__( self, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, mlp_ratio=4, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, layer_norm_eps=1e-6, image_size=224, patch_size=16, num_channels=3, qkv_bias=True, layerscale_value=1.0, drop_path_rate=0.0, use_swiglu_ffn=False, num_register_tokens=4, out_features=None, out_indices=None, apply_layernorm=True, reshape_hidden_states=True, num_windows=1, window_block_indexes=None, gradient_checkpointing=False, **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.mlp_ratio = mlp_ratio self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.qkv_bias = qkv_bias self.layerscale_value = layerscale_value self.drop_path_rate = drop_path_rate self.use_swiglu_ffn = use_swiglu_ffn self.num_register_tokens = num_register_tokens self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] self._out_features, self._out_indices = get_aligned_output_features_output_indices( out_features=out_features, out_indices=out_indices, stage_names=self.stage_names ) self.apply_layernorm = apply_layernorm self.reshape_hidden_states = reshape_hidden_states self.num_windows = num_windows self.window_block_indexes = list(range(num_hidden_layers)) if window_block_indexes is None else window_block_indexes self.gradient_checkpointing = gradient_checkpointing class Dinov2WithRegistersPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a Transformer. """ def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: num_channels = pixel_values.shape[1] if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." f" Expected {self.num_channels} but got {num_channels}." ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings class WindowedDinov2WithRegistersEmbeddings(nn.Module): """ Construct the CLS token, mask token, register tokens, position and patch embeddings. """ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) if config.num_register_tokens > 0 else None self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.patch_size = config.patch_size self.config = config def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility with the original implementation. Adapted from: - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 # Skip interpolation for matching dimensions (unless tracing) if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings # Handle class token and patch embeddings separately class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] # Calculate new dimensions height = height // self.config.patch_size width = width // self.config.patch_size # Reshape for interpolation sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) # Store original dtype for restoration after interpolation target_dtype = patch_pos_embed.dtype # Interpolate at float32 precision patch_pos_embed = nn.functional.interpolate( patch_pos_embed.to(dtype=torch.float32), size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor mode="bicubic", align_corners=False, antialias=True, ).to(dtype=target_dtype) # Validate output dimensions if not tracing if not torch.jit.is_tracing(): if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: raise ValueError("Width or height does not match with the interpolated position embeddings") # Reshape back to original format patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) # Combine class and patch embeddings return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embeddings.projection.weight.dtype embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) if bool_masked_pos is not None: embeddings = torch.where( bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings ) # add the [CLS] token to the embedded patch tokens cls_tokens = self.cls_token.expand(batch_size, -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) # add positional encoding to each token embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) if self.config.num_windows > 1: # reshape for windows num_h_patches = height // self.config.patch_size num_w_patches = width // self.config.patch_size cls_token_with_pos_embed = embeddings[:, :1] pixel_tokens_with_pos_embed = embeddings[:, 1:] pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(batch_size, num_h_patches, num_w_patches, -1) num_w_patches_per_window = num_w_patches // self.config.num_windows num_h_patches_per_window = num_h_patches // self.config.num_windows num_windows = self.config.num_windows windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(batch_size, num_windows, num_h_patches_per_window, num_windows, num_h_patches_per_window, -1) windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5) windowed_pixel_tokens = windowed_pixel_tokens.reshape(batch_size * num_windows ** 2, num_h_patches_per_window * num_w_patches_per_window, -1) windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows ** 2, 1, 1) embeddings = torch.cat((windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1) # add register tokens embeddings = torch.cat( (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 ) if self.config.num_register_tokens > 0 else embeddings embeddings = self.dropout(embeddings) return embeddings class Dinov2WithRegistersSelfAttention(nn.Module): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Normalize the attention scores to probabilities. attention_probs = nn.functional.softmax(attention_scores, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__(config) self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions ) mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, head_mask, self.attention_probs_dropout_prob if self.training else 0.0, is_causal=False, scale=None, ) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) return context_layer, None class Dinov2WithRegistersSelfOutput(nn.Module): """ The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the layernorm applied before each block. """ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states class Dinov2WithRegistersAttention(nn.Module): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() self.attention = Dinov2WithRegistersSelfAttention(config) self.output = Dinov2WithRegistersSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads ) # Prune linear layers self.attention.query = prune_linear_layer(self.attention.query, index) self.attention.key = prune_linear_layer(self.attention.key, index) self.attention.value = prune_linear_layer(self.attention.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__(config) self.attention = Dinov2WithRegistersSdpaSelfAttention(config) class Dinov2WithRegistersLayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: return hidden_state * self.lambda1 def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize output = input.div(keep_prob) * random_tensor return output class Dinov2WithRegistersDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return drop_path(hidden_states, self.drop_prob, self.training) def extra_repr(self) -> str: return "p={}".format(self.drop_prob) class Dinov2WithRegistersMLP(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) self.fc1 = nn.Linear(in_features, hidden_features, bias=True) if isinstance(config.hidden_act, str): self.activation = ACT2FN[config.hidden_act] else: self.activation = config.hidden_act self.fc2 = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.fc1(hidden_state) hidden_state = self.activation(hidden_state) hidden_state = self.fc2(hidden_state) return hidden_state class Dinov2WithRegistersSwiGLUFFN(nn.Module): def __init__(self, config) -> None: super().__init__() in_features = out_features = config.hidden_size hidden_features = int(config.hidden_size * config.mlp_ratio) hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) self.weights_out = nn.Linear(hidden_features, out_features, bias=True) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.weights_in(hidden_state) x1, x2 = hidden_state.chunk(2, dim=-1) hidden = nn.functional.silu(x1) * x2 return self.weights_out(hidden) DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = { "eager": Dinov2WithRegistersAttention, "sdpa": Dinov2WithRegistersSdpaAttention, } class WindowedDinov2WithRegistersLayer(nn.Module): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() self.num_windows = config.num_windows self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config) self.layer_scale1 = Dinov2WithRegistersLayerScale(config) self.drop_path = ( Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() ) self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_swiglu_ffn: self.mlp = Dinov2WithRegistersSwiGLUFFN(config) else: self.mlp = Dinov2WithRegistersMLP(config) self.layer_scale2 = Dinov2WithRegistersLayerScale(config) def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, run_full_attention: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: assert head_mask is None, "head_mask is not supported for windowed attention" assert not output_attentions, "output_attentions is not supported for windowed attention" shortcut = hidden_states if run_full_attention: # reshape x to remove windows B, HW, C = hidden_states.shape num_windows_squared = self.num_windows ** 2 hidden_states = hidden_states.view(B // num_windows_squared, num_windows_squared * HW, C) self_attention_outputs = self.attention( self.norm1(hidden_states), # in Dinov2WithRegisters, layernorm is applied before self-attention head_mask, output_attentions=output_attentions, ) attention_output = self_attention_outputs[0] if run_full_attention: # reshape x to add windows back B, HW, C = hidden_states.shape num_windows_squared = self.num_windows ** 2 # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C) attention_output = attention_output.view(B * num_windows_squared, HW // num_windows_squared, C) attention_output = self.layer_scale1(attention_output) outputs = self_attention_outputs[1:] # add self attentions if we output attention weights # first residual connection hidden_states = self.drop_path(attention_output) + shortcut # in Dinov2WithRegisters, layernorm is also applied after self-attention layer_output = self.norm2(hidden_states) layer_output = self.mlp(layer_output) layer_output = self.layer_scale2(layer_output) # second residual connection layer_output = self.drop_path(layer_output) + hidden_states outputs = (layer_output,) + outputs return outputs class WindowedDinov2WithRegistersEncoder(nn.Module): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([WindowedDinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = config.gradient_checkpointing def forward( self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if i > int(self.config.out_features[-1][5:]): # early stop if we have reached the last output feature break run_full_attention = i not in self.config.window_block_indexes layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, output_attentions, run_full_attention, ) else: layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, run_full_attention) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class WindowedDinov2WithRegistersPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = WindowedDinov2WithRegistersConfig base_model_prefix = "dinov2_with_registers" main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] _supports_sdpa = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues module.weight.data = nn.init.trunc_normal_( module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, WindowedDinov2WithRegistersEmbeddings): module.position_embeddings.data = nn.init.trunc_normal_( module.position_embeddings.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.position_embeddings.dtype) module.cls_token.data = nn.init.trunc_normal_( module.cls_token.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.cls_token.dtype) _EXPECTED_OUTPUT_SHAPE = [1, 257, 768] DINOV2_WITH_REGISTERS_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.preprocess`] for details. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for pre-training. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.", DINOV2_WITH_REGISTERS_START_DOCSTRING, ) class WindowedDinov2WithRegistersModel(WindowedDinov2WithRegistersPreTrainedModel): def __init__(self, config: WindowedDinov2WithRegistersConfig): super().__init__(config) self.config = config self.embeddings = WindowedDinov2WithRegistersEmbeddings(config) self.encoder = WindowedDinov2WithRegistersEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC, modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, ) def forward( self, pixel_values: Optional[torch.Tensor] = None, bool_masked_pos: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) encoder_outputs = self.encoder( embedding_output, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = sequence_output[:, 0, :] if not return_dict: head_outputs = (sequence_output, pooled_output) return head_outputs + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) # Image classification docstring _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.preprocess`] for details. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( """ Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token) e.g. for ImageNet. """, DINOV2_WITH_REGISTERS_START_DOCSTRING, ) class WindowedDinov2WithRegistersForImageClassification(WindowedDinov2WithRegistersPreTrainedModel): def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None: super().__init__(config) self.num_labels = config.num_labels self.dinov2_with_registers = WindowedDinov2WithRegistersModel(config) # Classifier head self.classifier = ( nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() ) # Initialize weights and apply final processing self.post_init() @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_IMAGE_CLASS_CHECKPOINT, output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) def forward( self, pixel_values: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple, ImageClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.dinov2_with_registers( pixel_values, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] # batch_size, sequence_length, hidden_size cls_token = sequence_output[:, 0] patch_tokens = sequence_output[:, 1:] linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) logits = self.classifier(linear_input) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return ImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( """ Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer. """, DINOV2_WITH_REGISTERS_START_DOCSTRING, ) class WindowedDinov2WithRegistersBackbone(WindowedDinov2WithRegistersPreTrainedModel, BackboneMixin): def __init__(self, config: WindowedDinov2WithRegistersConfig): super().__init__(config) super()._init_backbone(config) self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] self.embeddings = WindowedDinov2WithRegistersEmbeddings(config) self.encoder = WindowedDinov2WithRegistersEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.num_register_tokens = config.num_register_tokens # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: return self.embeddings.patch_embeddings @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> BackboneOutput: """ Returns: Examples: Returns: Examples: ```python >>> from transformers import AutoImageProcessor, AutoBackbone >>> import torch >>> from PIL import Image >>> import requests >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") >>> model = AutoBackbone.from_pretrained( ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] ... ) >>> inputs = processor(image, return_tensors="pt") >>> outputs = model(**inputs) >>> feature_maps = outputs.feature_maps >>> list(feature_maps[-1].shape) [1, 768, 16, 16] ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions embedding_output = self.embeddings(pixel_values) outputs = self.encoder( embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict ) hidden_states = outputs.hidden_states if return_dict else outputs[1] feature_maps = () for stage, hidden_state in zip(self.stage_names, hidden_states): if stage in self.out_features: if self.config.apply_layernorm: hidden_state = self.layernorm(hidden_state) if self.config.reshape_hidden_states: hidden_state = hidden_state[:, self.num_register_tokens + 1 :] # this was actually a bug in the original implementation that we copied here, # cause normally the order is height, width batch_size, _, height, width = pixel_values.shape patch_size = self.config.patch_size num_h_patches = height // patch_size num_w_patches = width // patch_size if self.config.num_windows > 1: # undo windowing num_windows_squared = self.config.num_windows ** 2 B, HW, C = hidden_state.shape num_h_patches_per_window = num_h_patches // self.config.num_windows num_w_patches_per_window = num_w_patches // self.config.num_windows hidden_state = hidden_state.reshape(B // num_windows_squared, num_windows_squared * HW, C) hidden_state = hidden_state.view(B // num_windows_squared, self.config.num_windows, self.config.num_windows, num_h_patches_per_window, num_w_patches_per_window, C) hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5) hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1) hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() feature_maps += (hidden_state,) if not return_dict: if output_hidden_states: output = (feature_maps,) + outputs[1:] else: output = (feature_maps,) + outputs[2:] return output return BackboneOutput( feature_maps=feature_maps, hidden_states=outputs.hidden_states if output_hidden_states else None, attentions=outputs.attentions if output_attentions else None, ) __all__ = [ "WindowedDinov2WithRegistersPreTrainedModel", "WindowedDinov2WithRegistersModel", "WindowedDinov2WithRegistersForImageClassification", "WindowedDinov2WithRegistersBackbone", ]