|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import logging
|
| from typing import Callable, List, Any, Tuple, Dict
|
|
|
| import torch
|
| from torch import nn, Tensor
|
|
|
| from .attention import Attention, MemEffAttention
|
| from .drop_path import DropPath
|
| from .layer_scale import LayerScale
|
| from .mlp import Mlp
|
|
|
|
|
# Module logger, registered under the fixed name "dinov2".
logger = logging.getLogger("dinov2")


# Gate asserted by NestedTensorBlock.forward before it takes the list-input
# (nested tensor) path. It is hard-coded to False here, so that path always
# fails its assertion in this file as written.
XFORMERS_AVAILABLE = False
|
|
|
|
|
class Block(nn.Module):
    """Transformer block made of an attention branch and a feed-forward branch.

    Each branch runs on a normalised copy of its input, is optionally rescaled
    by ``LayerScale`` and is then added back to the input as a residual.
    While training, the residuals can be randomly dropped (stochastic depth).

    Args:
        dim: Feature width passed to the norm layers, the attention module,
            the feed-forward module and ``LayerScale``.
        num_heads: Forwarded to ``attn_class``.
        mlp_ratio: The feed-forward hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to
            ``ffn_layer`` as ``drop``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: When truthy, both branches are wrapped in
            ``LayerScale(dim, init_values=init_values)``; otherwise
            ``nn.Identity`` is used.
        drop_path: Stochastic-depth rate. Used to build the ``DropPath``
            modules and also stored as ``sample_drop_ratio``.
        act_layer: Forwarded to ``ffn_layer``.
        norm_layer: Called as ``norm_layer(dim)`` to build both norms.
        attn_class: Factory for the attention module.
        ffn_layer: Factory for the feed-forward module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()

        # Attention branch: norm -> attention -> (layer scale) -> (drop path).
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch: norm -> MLP -> (layer scale) -> (drop path).
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Rate used by forward() to choose between the stochastic-depth variants.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and feed-forward residual branches to ``x``.

        Three regimes are selected from ``self.training`` and
        ``self.sample_drop_ratio``:

        * training with a rate above 0.1: each residual is computed on a
          random subset of the batch via ``drop_add_residual_stochastic_depth``;
        * training with a rate in (0, 0.1]: the full residual is computed and
          passed through the branch's ``DropPath`` module;
        * otherwise: plain residual addition.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # NOTE(review): presumably the subset-based path only pays for its
            # indexing overhead at larger drop rates, hence the 0.1 threshold.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own DropPath module; the feed-forward branch
            # previously went through drop_path1, leaving drop_path2 unused.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
|
|
|
|
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual that is computed on a random subset of the batch only.

    A random selection of ``max(int(b * (1 - sample_drop_ratio)), 1)`` batch
    rows is passed through ``residual_func``. The result is added back onto
    those rows, multiplied by ``b / subset_size``; every other row is returned
    unchanged.

    Args:
        x: Input whose shape must unpack into exactly three dimensions; the
            first one is treated as the batch.
        residual_func: Maps the selected rows to a residual with the same
            number of rows.
        sample_drop_ratio: Fraction of the batch to leave out.

    Returns:
        A new tensor shaped like ``x``.
    """
    # Unpacking into three names also insists on a rank-3 input.
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch_size, device=x.device)[:keep_count]

    # Only the selected rows go through the (expensive) branch.
    branch_out = residual_func(x[kept_rows]).flatten(1)

    # Scale so that fewer contributing rows carry proportionally more weight.
    rescale = batch_size / keep_count

    # Scatter-add the scaled residual back onto the rows it was computed from.
    summed = torch.index_add(
        x.flatten(1),
        0,
        kept_rows,
        branch_out.to(dtype=x.dtype),
        alpha=rescale,
    )
    return summed.view_as(x)
|
|
|
|
|
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the random batch rows to keep and the matching residual scale.

    Args:
        x: Tensor whose shape must unpack into exactly three dimensions; the
            first one is treated as the batch.
        sample_drop_ratio: Fraction of the batch to leave out.

    Returns:
        ``(brange, residual_scale_factor)``: the first
        ``max(int(b * (1 - sample_drop_ratio)), 1)`` entries of a random
        permutation of ``range(b)`` (on ``x``'s device), and
        ``b / len(brange)``.
    """
    # Unpacking into three names also insists on a rank-3 input.
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch_size, device=x.device)
    return shuffled[:keep_count], batch_size / keep_count
|
|
|
|
|
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` onto the rows of ``x`` selected by ``brange``.

    Args:
        x: Base tensor.
        brange: Index tensor naming the rows of ``x`` that ``residual``
            belongs to.
        residual: Residual values, cast to ``x.dtype`` before being added.
        residual_scale_factor: Multiplier applied to the residual (``alpha``).
        scaling_vector: Optional extra scaling handed to ``scaled_index_add``.

    Returns:
        Without ``scaling_vector``: a new tensor holding ``x`` flattened from
        dim 1 onwards with the scaled residual added (callers reshape it).
        With ``scaling_vector``: whatever ``scaled_index_add`` returns.
    """
    if scaling_vector is not None:
        # NOTE(review): scaled_index_add is not defined or imported in the
        # visible part of this file - confirm it is in scope before relying
        # on this path.
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )

    flat_x = x.flatten(1)
    flat_residual = residual.flatten(1)
    return torch.index_add(flat_x, 0, brange, flat_residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
|
|
|
|
# Attention-bias objects already built by get_attn_bias_and_cat, keyed by the
# tuple of (batch size, sequence length) pairs of the tensors they cover.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate a list of tensors into one batch-of-one sequence and return it
    together with the matching attention bias, which is built once per
    combination of shapes and then served from ``attn_bias_cache``.

    When ``branges`` is given, only the rows it selects from each tensor are
    gathered (via ``index_select_cat``) before concatenation, and the batch
    sizes used for the bias are the lengths of those index tensors.

    NOTE(review): ``fmha`` and ``index_select_cat`` are not defined or imported
    in the visible part of this file - confirm they are in scope.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))

    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per sample, in list order.
        seqlens = [seq_len for count, seq_len in all_shapes for _ in range(count)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        # Fold each tensor's batch into its sequence dimension, then join
        # everything along that dimension.
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
|
|
|
|
|
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Subset-based stochastic depth for a list of tensors handled as one batch.

    For every tensor a random subset of batch rows is chosen. The chosen rows
    of all tensors are concatenated, pushed through ``residual_func`` in a
    single call, split back per tensor and added onto the rows they came from.

    Args:
        x_list: Input tensors.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of each tensor's batch to leave out.
        scaling_vector: Optional scaling forwarded to ``add_residual``.

    Returns:
        One tensor per input, each shaped like its input. (The previous
        ``-> Tensor`` annotation did not match the list actually returned.)
    """
    # 1) Draw the random row subset and its residual scale for every tensor.
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) Gather the selected rows into one concatenated tensor and fetch the
    #    attention bias describing its layout.
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) Run the branch once, then split the result back per input tensor.
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))

    # 4) Add every residual onto the rows it was computed from.
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
|
|
|
|
class NestedTensorBlock(Block):
    """``Block`` whose ``forward`` also accepts a list of tensors.

    A single ``Tensor`` is handled by ``Block.forward``. A ``list`` of tensors
    is concatenated and processed in one pass by ``forward_nested``, which
    requires ``self.attn`` to be a ``MemEffAttention`` and
    ``XFORMERS_AVAILABLE`` to be true.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Layer scale is left out of the residual functions on this path:
            # its gamma is handed to the residual-add step as scaling_vector.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check the module whose gamma is read; this used to test
                # self.ls1 while reading self.ls2.gamma.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Concatenate all inputs, run both branches once, then split the
            # result back per input using the attention bias.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type.

        Args:
            x_or_x_list: A ``Tensor`` (handled by ``Block.forward``) or a
                ``list`` of tensors (handled by ``forward_nested``).

        Raises:
            AssertionError: If a list is given while ``XFORMERS_AVAILABLE`` is
                false, or if the input is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(
                f"Expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
            )
|
|
|