|
|
|
|
|
|
|
|
| import os
|
|
|
| from typing import Callable, List, Any, Tuple, Dict
|
|
|
| import torch
|
|
|
| from torch import nn, Tensor
|
|
|
| from .attention import Attention, MemEffAttention
|
| from .drop_path import DropPath
|
| from .layer_scale import LayerScale
|
| from .mlp import Mlp
|
|
|
# xFormers is opt-out: defining the XFORMERS_DISABLED environment variable
# (with any value) skips the import attempt entirely.
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ

# Pessimistic default; flipped only when the import below succeeds.
# NOTE: the name ``fmha`` is bound only in that case.
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import fmha
    except ImportError:
        pass
    else:
        XFORMERS_AVAILABLE = True
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm residual block with an attention branch and a feed-forward branch.

    Each branch computes ``norm -> sublayer -> (optional) LayerScale`` and the
    result is added back onto the branch input. During training the branches
    may additionally be dropped stochastically, see :meth:`forward`.

    Args:
        dim: Feature size handed to ``norm_layer``, ``attn_class``,
            ``LayerScale`` and (as ``in_features``) to ``ffn_layer``.
        num_heads: Forwarded to ``attn_class``.
        mlp_ratio: The feed-forward hidden size is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to ``ffn_layer``
            as ``drop``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: When truthy, each branch output goes through
            ``LayerScale(dim, init_values=init_values)``; otherwise through
            ``nn.Identity``.
        drop_path: Stochastic-depth rate. Also stored as
            ``self.sample_drop_ratio``.
        act_layer: Forwarded to ``ffn_layer``.
        norm_layer: Factory called as ``norm_layer(dim)`` for both norms.
        attn_class: Factory for the attention sublayer.
        ffn_layer: Factory for the feed-forward sublayer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Rate used by forward() to pick the stochastic-depth strategy.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, return_attention=False) -> Tensor:
        """
        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Apply the attention and feed-forward residual branches to ``x``.

        Stochastic depth is only active in training mode:

        * ``sample_drop_ratio > 0.1``: each branch is evaluated on a random
          subset of the batch via ``drop_add_residual_stochastic_depth``.
        * ``0.0 < sample_drop_ratio <= 0.1``: each branch output is passed
          through its own ``DropPath`` module.
        * otherwise: plain residual additions.

        Args:
            x: Input tensor. The batch-subset path unpacks ``x.shape`` into
                three dimensions.
            return_attention: When true, ``self.attn`` is additionally called
                with ``return_attn=True`` on the normalized *input* and its
                result is returned alongside the block output.

        Returns:
            The block output, or the tuple ``(output, attn)`` when
            ``return_attention`` is true.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if return_attention:
            # Separate attention call on the untouched input; the residual
            # computation below runs attention again on its own.
            attn = self.attn(self.norm1(x), return_attn=True)

        if self.training and self.sample_drop_ratio > 0.1:
            # NOTE(review): the 0.1 cut-off presumably reflects where the
            # batch-subset path starts to pay off -- confirm with the authors.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own DropPath module; it previously
            # reused drop_path1, leaving drop_path2 dead.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)

        if return_attention:
            return x, attn

        return x
|
|
|
|
|
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Apply a residual branch to a random subset of the batch.

    A random selection of batch rows is drawn, ``residual_func`` is evaluated
    on those rows only, and its output is added back onto the same rows,
    multiplied by ``batch / subset_size``. Rows that were not drawn are
    returned unchanged.

    Args:
        x: Three-dimensional input; the first dimension is the batch.
        residual_func: Branch to evaluate on the selected rows.
        sample_drop_ratio: Fraction of the batch to leave out. At least one
            row is always kept.

    Returns:
        A new tensor shaped like ``x``.
    """
    batch, _seq, _dim = x.shape  # also enforces a 3-D input

    # 1) draw the rows that keep their residual branch
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]

    # 2) evaluate the branch on those rows only
    branch_out = residual_func(x.index_select(0, kept_rows)).flatten(1)

    # 3) scatter the result back, scaled up to compensate for the skipped rows
    scale = batch / keep
    summed = x.flatten(1).index_add(
        0, kept_rows, branch_out.to(dtype=x.dtype), alpha=scale
    )
    return summed.view_as(x)
|
|
|
|
|
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the batch rows to keep and the matching residual scale.

    Args:
        x: Three-dimensional tensor; only its batch size and device are used.
        sample_drop_ratio: Fraction of the batch to leave out. At least one
            row is always kept.

    Returns:
        ``(brange, scale)``: a random selection of distinct batch indices and
        the factor ``batch / len(brange)``.
    """
    batch, _seq, _dim = x.shape  # also enforces a 3-D input
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch, device=x.device)
    return shuffled[:keep], batch / keep
|
|
|
|
|
def add_residual(x, brange, residual, residual_scale_factor, ls=None):
    """Add a scaled residual onto the batch rows ``brange`` of ``x``.

    The addition is out-of-place: ``x`` itself is never modified. The former
    in-place ``index_add_`` wrote through to the caller's tensor and fails on
    leaf tensors that require grad.

    Args:
        x: Base tensor; the first dimension is the batch.
        brange: 1-D index tensor of the rows of ``x`` that receive a residual.
        residual: Residual values, one entry along dim 0 per index in
            ``brange``.
        residual_scale_factor: Multiplier applied to the residual (``alpha``
            of ``index_add``).
        ls: Optional callable applied to the residual (after casting it to
            ``x.dtype``) before it is added.

    Returns:
        A new tensor. Without ``ls`` it is flattened to ``(batch, -1)``; with
        ``ls`` it keeps the shape of ``x``.
    """
    if ls is None:
        return torch.index_add(
            x.flatten(1),
            0,
            brange,
            residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    return torch.index_add(
        x,
        0,
        brange,
        ls(residual.to(dtype=x.dtype)),
        alpha=residual_scale_factor,
    )
|
|
|
|
|
# Attention masks keyed by the tuple of (batch_size, seq_len) pairs they were
# built for, so repeated calls with the same layout reuse the mask.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """Concatenate a list of tensors into one sequence and fetch its mask.

    Args:
        x_list: Tensors whose dim 0 is the batch and dim 1 the sequence.
        branges: Optional list of 1-D index tensors, one per entry of
            ``x_list``. When given, only those batch rows are concatenated.

    Returns:
        ``(attn_bias, cat_tensors)``. ``attn_bias`` is the cached
        ``fmha.BlockDiagonalMask`` built with one sequence length per
        (selected) sample; ``cat_tensors`` has a leading dimension of 1 with
        all (selected) samples laid out back to back along dim 1.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [idx.shape[0] for idx in branges]

    all_shapes = tuple(
        (batch, x.shape[1]) for batch, x in zip(batch_sizes, x_list)
    )

    if all_shapes not in attn_bias_cache:
        # One sequence length per sample, in concatenation order.
        seqlens = [seqlen for batch, seqlen in all_shapes for _ in range(batch)]
        mask = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        # Stored on the mask object alongside the sequence lengths.
        mask._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = mask

    if branges is None:
        # Fold every tensor's batch into its sequence axis, then join them.
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        # Select the requested rows, flatten them, and rebuild the feature axis.
        flat_inputs = [x.flatten(1) for x in x_list]
        picked = [
            flat.index_select(0, idx).reshape(-1)
            for flat, idx in zip(flat_inputs, branges)
        ]
        cat_tensors = torch.cat(picked, dim=0).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
|
|
|
|
|
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Apply a residual branch to random batch subsets of several tensors at once.

    For every tensor in ``x_list`` a random subset of batch rows is drawn.
    The selected rows of all tensors are concatenated, ``residual_func`` runs
    once on the concatenation, and the result is split and added back onto
    the selected rows of each tensor.

    Args:
        x_list: Input tensors (three-dimensional, batch first).
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of each batch to leave out.
        scaling_vector: Optional callable forwarded to ``add_residual`` as
            ``ls``; applied to the residual before the addition.

    Returns:
        One tensor per input, each shaped like its input. (The annotation
        used to claim a single ``Tensor``.)
    """
    # 1) draw the kept rows and the matching scale for every tensor
    branges_scales = [
        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
    ]
    branges = [brange for brange, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) gather the kept rows of all tensors into one sequence plus its mask
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) run the branch once, then split the output back per input tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))

    # 4) add each residual onto the rows it was computed from
    return [
        add_residual(
            x, brange, residual, residual_scale_factor, scaling_vector
        ).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
|
|
|
|
class NestedTensorBlock(Block):
    """``Block`` that accepts either a single tensor or a list of tensors.

    A single tensor is handled by ``Block.forward``. A list is concatenated
    into one sequence with an ``fmha.BlockDiagonalMask`` and processed in one
    pass; this path needs xFormers and a ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run

        Returns:
            The processed tensors, one per entry of ``x_list`` (in eval mode
            this is whatever ``attn_bias.split`` yields).

        Raises:
            AssertionError: If ``self.attn`` is not a ``MemEffAttention``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(self.attn, MemEffAttention):
            raise AssertionError(
                "forward_nested requires the attention module to be a MemEffAttention"
            )

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is deliberately left out of the residual functions
            # here: it is handed over as ``scaling_vector`` and applied when
            # the residual is added back.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1 if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself (this previously tested ls1 by mistake).
                scaling_vector=self.ls2 if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list

        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        attn_bias, x = get_attn_bias_and_cat(x_list)
        x = x + attn_residual_func(x, attn_bias=attn_bias)
        x = x + ffn_residual_func(x)
        return attn_bias.split(x)

    def forward(self, x_or_x_list, return_attention=False):
        """
        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Dispatch on the input type.

        Args:
            x_or_x_list: A ``Tensor`` (handled by ``Block.forward``) or a
                ``list`` of tensors (handled by ``forward_nested``).
            return_attention: Only supported for a single tensor.

        Raises:
            NotImplementedError: ``return_attention`` with a list input.
            AssertionError: xFormers is unavailable for a list input, or the
                input is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list, return_attention)

        if isinstance(x_or_x_list, list):
            if return_attention:
                raise NotImplementedError(
                    "return_attention not supported for nested tensors"
                )
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; otherwise a missing xFormers would only surface
            # later as an undefined ``fmha``.
            if not XFORMERS_AVAILABLE:
                raise AssertionError("Please install xFormers for nested tensors usage")
            return self.forward_nested(x_or_x_list)

        raise AssertionError(
            f"Expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
        )
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run one plain block and one nested block on random input
    # and print the output shapes.

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        _device = "cuda"
    elif torch.backends.mps.is_available():
        _device = "mps"
    else:
        _device = "cpu"

    # Single-tensor path; drop_path=0.3 exercises the batch-subset branch
    # because a freshly built module is in training mode.
    block = Block(dim=64, num_heads=8, drop_path=0.3).to(_device)
    x = torch.randn(10, 16, 64, device=_device)
    output = block(x)
    print(output.shape)

    # List-of-tensors path.
    nested_block = NestedTensorBlock(dim=64, num_heads=8, attn_class=MemEffAttention)
    nested_block = nested_block.to(_device)
    nested_x = [torch.randn(10, 16, 64, device=_device) for _ in range(2)]
    nested_output = nested_block(nested_x)
    print([o.shape for o in nested_output])
|
|
|