Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This software may be used and distributed in accordance with | |
| # the terms of the DINOv3 License Agreement. | |
| from typing import Callable, List, Optional | |
| import torch | |
| from torch import Tensor, nn | |
| from ..utils import cat_keep_shapes, uncat_with_shapes | |
| from .attention import CausalSelfAttention, SelfAttention | |
| from .ffn_layers import Mlp | |
| from .layer_scale import LayerScale # , DropPath | |
| torch._dynamo.config.automatic_dynamic_shapes = False | |
| torch._dynamo.config.accumulated_cache_size_limit = 1024 | |
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: an attention residual followed by an MLP residual.

    Each branch computes ``x + ls(branch(norm(x)))``. When the module is training
    and ``drop_path > 0``, each branch is instead evaluated on a random subset of
    ``max(int(B * (1 - drop_path)), 1)`` batch elements. Its output is added back
    into only those rows with ``torch.index_add``, scaled by ``B / subset_size``.

    ``forward`` accepts either one tensor or a list of tensors; both are routed
    through ``_forward_list`` so the two input forms share one implementation.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        """Build the block.

        Args:
            dim: Width passed to both norms, the attention module and the MLP.
            num_heads: Forwarded to ``attn_class``.
            ffn_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: Forwarded to ``attn_class``.
            proj_bias: Forwarded to ``attn_class``.
            ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
            drop: Forwarded as ``proj_drop`` to the attention and ``drop`` to the MLP.
            attn_drop: Forwarded to ``attn_class``.
            init_values: If truthy, each branch output goes through a ``LayerScale``
                built with this value; otherwise ``nn.Identity`` is used.
            drop_path: Batch-level drop ratio used while training (see class docstring).
            act_layer: Activation factory forwarded to ``ffn_layer``.
            norm_layer: Factory for ``norm1`` and ``norm2``.
            attn_class: Attention module factory.
            ffn_layer: Feed-forward module factory.
            mask_k_bias: Forwarded to ``attn_class``.
            device: Forwarded to the attention, MLP and LayerScale submodules.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        # A falsy init_values (None or 0) disables LayerScale on the branch.
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        # Approximate fraction of batch elements skipped by each residual branch while training.
        self.sample_drop_ratio = drop_path

    # Must be a staticmethod: callers use ``self._maybe_index_rope(rope, indices)``,
    # and without the decorator ``self`` would bind to ``rope`` and the call would
    # raise TypeError.
    @staticmethod
    def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None:
        """Select the rope rows that belong to the sampled batch elements.

        Args:
            rope: ``(sin, cos)`` pair, or ``None``.
            indices: Batch indices of the sampled subset.

        Returns:
            ``None`` if ``rope`` is ``None``. ``(sin[indices], cos[indices])`` when
            the tensors are 4-D (leading batch dimension). Otherwise ``rope``
            unchanged, because it is shared by the whole batch.
        """
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            # The rope embedding has a batch dimension (differs per batch element): index into it.
            return sin[indices], cos[indices]  # [batch, heads, patches, embed_dim]
        # No batch dimension: nothing to index.
        return sin, cos  # [heads, patches, embed_dim] or [patches, embed_dim]

    def _forward(self, x: Tensor, rope=None) -> Tensor:
        """
        This is the reference implementation for a single tensor, matching what is done below for a list.
        We call the list op on [x] instead of this function.
        """
        b, _, _ = x.shape
        sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1)
        residual_scale_factor = b / sample_subset_size
        if self.training and self.sample_drop_ratio > 0.0:
            # Attention branch on a random subset of the batch, scattered back with rescaling.
            indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size]
            x_subset_1 = x[indices_1]
            rope_subset = self._maybe_index_rope(rope, indices_1)
            residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset)
            x_attn = torch.index_add(
                x,
                dim=0,
                source=self.ls1(residual_1),
                index=indices_1,
                alpha=residual_scale_factor,
            )
            # MLP branch on an independently drawn subset.
            indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size]
            x_subset_2 = x_attn[indices_2]
            residual_2 = self.mlp(self.norm2(x_subset_2))
            x_ffn = torch.index_add(
                x_attn,
                dim=0,
                source=self.ls2(residual_2),
                index=indices_2,
                alpha=residual_scale_factor,
            )
        else:
            x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
            x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
        return x_ffn

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """Apply the block to every tensor in ``x_list``.

        In the stochastic-depth training path, the sampled tokens of all inputs are
        concatenated so each norm runs as one elementwise op. Torch-compile memory
        planning can hide the overhead of the concat. The attention and MLP are
        then called through their ``forward_list`` methods.

        Args:
            x_list: Input tensors; their first dimension is treated as the batch.
            rope_list: One rope entry (or ``None``) per input, or ``None`` for no rope at all.

        Returns:
            One output tensor per input, in the same order.
        """
        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
        residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)]
        if self.training and self.sample_drop_ratio > 0.0:
            # --- attention branch on a random subset of each input ---
            indices_1_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)]
            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)
            x_attn_list = [
                torch.index_add(
                    x,
                    dim=0,
                    source=self.ls1(residual_1),
                    index=indices_1,
                    alpha=residual_scale_factor,
                )
                for x, residual_1, indices_1, residual_scale_factor in zip(
                    x_list, residual_1_list, indices_1_list, residual_scale_factors
                )
            ]
            # --- MLP branch on an independently drawn subset ---
            indices_2_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_flat = self.norm2(flattened)
            norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)
            residual_2_list = self.mlp.forward_list(norm2_list)
            x_ffn = [
                torch.index_add(
                    x_attn,
                    dim=0,
                    source=self.ls2(residual_2),
                    index=indices_2,
                    alpha=residual_scale_factor,
                )
                for x_attn, residual_2, indices_2, residual_scale_factor in zip(
                    x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
                )
            ]
        else:
            # rope_list defaults to None; zip(x_list, None) would raise TypeError,
            # so treat None as "no rope for any input".
            ropes = rope_list if rope_list is not None else [None] * len(x_list)
            x_ffn = []
            for x, rope in zip(x_list, ropes):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn.append(x_attn + self.ls2(self.mlp(self.norm2(x_attn))))
        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> Tensor | List[Tensor]:
        """Run the block on one tensor or on a list of tensors.

        Args:
            x_or_x_list: A single tensor, or a list of tensors.
            rope_or_rope_list: Rope for the single tensor, or a list with one rope
                per tensor. ``None`` means no rope.

        Returns:
            A tensor for tensor input; a list of tensors for list input.

        Raises:
            AssertionError: If ``x_or_x_list`` is neither a Tensor nor a list.
        """
        if isinstance(x_or_x_list, Tensor):
            # Go through the list op (rather than the equivalent ``_forward`` reference)
            # so single-tensor and list inputs share one implementation.
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        if isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None] * len(x_or_x_list)
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError(f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}")
| class CausalSelfAttentionBlock(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int, | |
| ffn_ratio: float = 4.0, | |
| ls_init_value: Optional[float] = None, | |
| is_causal: bool = True, | |
| act_layer: Callable = nn.GELU, | |
| norm_layer: Callable = nn.LayerNorm, | |
| dropout_prob: float = 0.0, | |
| ): | |
| super().__init__() | |
| self.dim = dim | |
| self.is_causal = is_causal | |
| self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() | |
| self.attention_norm = norm_layer(dim) | |
| self.attention = CausalSelfAttention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob) | |
| self.ffn_norm = norm_layer(dim) | |
| ffn_hidden_dim = int(dim * ffn_ratio) | |
| self.feed_forward = Mlp( | |
| in_features=dim, | |
| hidden_features=ffn_hidden_dim, | |
| drop=dropout_prob, | |
| act_layer=act_layer, | |
| ) | |
| self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() | |
| def init_weights( | |
| self, | |
| init_attn_std: float | None = None, | |
| init_proj_std: float | None = None, | |
| init_fc_std: float | None = None, | |
| factor: float = 1.0, | |
| ) -> None: | |
| init_attn_std = init_attn_std or (self.dim**-0.5) | |
| init_proj_std = init_proj_std or init_attn_std * factor | |
| init_fc_std = init_fc_std or (2 * self.dim) ** -0.5 | |
| self.attention.init_weights(init_attn_std, init_proj_std) | |
| self.attention_norm.reset_parameters() | |
| nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std) | |
| nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std) | |
| self.ffn_norm.reset_parameters() | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| ): | |
| x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal)) | |
| x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn))) | |
| return x_ffn | |