import logging
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


# The nested-tensor code path below relies on xFormers kernels
# (fmha.BlockDiagonalMask, scaled_index_add, index_select_cat).
try:
    from xformers.ops import fmha, index_select_cat, scaled_index_add

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


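# Pre-norm transformer block: LayerNorm -> attention -> LayerScale -> residual,
# then LayerNorm -> MLP -> LayerScale -> residual, with optional stochastic
# depth (DropPath) on both residual branches.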
class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead of batch-level sampling is only worth it for a drop rate above 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


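# Batch-level stochastic depth: instead of zeroing the residual branch per sample
# (as DropPath does), run the branch only on a random subset of the batch and
# scatter the result back with index_add, scaling by b / subset_size so the residual
# keeps the same expected magnitude. Dropped samples skip the branch's compute entirely.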
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


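# The two helpers below split drop_add_residual_stochastic_depth into its sampling
# step (get_branges_scales) and its scatter-add step (add_residual) so the list
# variant further down can run a single fused attention call in between.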
def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


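# Block-diagonal attention masks only depend on the per-tensor (batch, seq_len)
# shapes, so they are built once per shape combination and cached here.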
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Index-select each tensor in x_list (when branges is given), concatenate them into a
    single batch-size-1 sequence, and return the matching block-diagonal attn_bias from
    the cache.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


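# List variant of drop_add_residual_stochastic_depth: subsample each tensor in the
# list, concatenate the survivors into one sequence, run residual_func once under a
# block-diagonal attention bias, then split and scatter each residual back into its
# original tensor.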
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    # 1) generate a random subset of indices (and scale factor) per tensor
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get the attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func once and split the result back per tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


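# Block that can also consume a list of token tensors with different sequence lengths,
# packing them into a single xFormers block-diagonal attention call.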
class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
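

# Minimal usage sketch (hypothetical shapes, not part of the module):
#
#   blk = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention, init_values=1e-5, drop_path=0.1)
#   x = torch.randn(2, 197, 384)   # (batch, tokens, dim)
#   y = blk(x)                     # plain tensor path: output has the same shape as x
#   # With xFormers installed, a list of token tensors with different sequence
#   # lengths can be passed instead; it is processed in a single block-diagonal
#   # attention call and a list of tensors with the original shapes is returned:
#   # y_list = blk([torch.randn(2, 197, 384), torch.randn(2, 257, 384)])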