# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import nn, Tensor

from .attention import Attention, CrossAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

# xFormers helpers are only needed by the nested-tensor path below; degrade gracefully if absent.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, index_select_cat, scaled_index_add
        XFORMERS_AVAILABLE = True
    else:
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Block)")
class Block(nn.Module):
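    """Standard pre-norm transformer block: self-attention followed by an MLP, each
    wrapped with optional LayerScale and stochastic depth (DropPath). The attention
    layer may consume an optional position input (``pos``), e.g. for rotary embeddings."""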
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
qk_norm=qk_norm,
fused_attn=fused_attn,
rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor, pos=None) -> Tensor:
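        """Apply the attention and MLP residual branches; during training, batch-level
        stochastic depth is used once the drop-path rate makes the overhead worthwhile."""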
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
pos=pos,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
class CrossBlock(nn.Module):
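    """Decoder-style block: self-attention over the query stream, cross-attention from the
    queries to a key/value stream, then an MLP. Returns the updated queries together with
    the (unchanged) key/value tensor."""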
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
cross_attn_class: Callable[..., nn.Module] = CrossAttention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True,
rope=None,
) -> None:
super().__init__()
self.qnorm1 = norm_layer(dim)
self.kvnorm1 = norm_layer(dim)
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.cross_attn = cross_attn_class(
dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
drop=drop, bias=ffn_bias,
)
self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path3 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
    def forward(self, q: Tensor, kv: Tensor, q_pos=None, kv_pos=None) -> Tuple[Tensor, Tensor]:
        """
        Self-attention over ``q``, cross-attention from ``q`` to ``kv``, then the MLP.
        ``q_pos`` / ``kv_pos`` are optional position inputs forwarded to the attention
        layers (e.g. for rotary embeddings). Returns ``(x, kv)`` with ``kv`` unchanged.
        """
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos))
def cross_attn_residual_func(q: Tensor, kv: Tensor, q_pos=None, kv_pos=None) -> Tensor:
return self.ls2(self.cross_attn(q=self.qnorm1(q), kv=self.kvnorm1(kv), q_pos=q_pos, kv_pos=kv_pos))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls3(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
q = drop_add_residual_stochastic_depth(
q, residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
pos=q_pos,
)
            x = drop_add_residual_stochastic_depth_cross(
                q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos,
                # the helper calls residual_func with keyword arguments (kv=, q_pos=, kv_pos=),
                # so pass the residual function directly rather than a positional-only lambda
                residual_func=cross_attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
x = drop_add_residual_stochastic_depth(
x, residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
q = q + self.drop_path1(attn_residual_func(q, pos=q_pos))
x = q + self.drop_path2(cross_attn_residual_func(q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos))
x = x + self.drop_path3(ffn_residual_func(x))
else:
q = q + attn_residual_func(q, pos=q_pos)
x = q + cross_attn_residual_func(q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos)
x = x + ffn_residual_func(x)
return x, kv
class CrossBlock2(nn.Module):
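    """Variant of CrossBlock that also self-attends the key/value stream (via ``attn2``)
    before the cross-attention step. Returns ``(x, kv)`` with both streams updated."""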
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
cross_attn_class: Callable[..., nn.Module] = CrossAttention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True,
rope=None,
) -> None:
super().__init__()
self.qnorm1 = norm_layer(dim)
self.kvnorm1 = norm_layer(dim)
self.norm1 = norm_layer(dim)
self.norm3 = norm_layer(dim)
self.attn = attn_class(
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.attn2 = attn_class(
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.cross_attn = cross_attn_class(
dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=drop, qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
drop=drop, bias=ffn_bias,
)
self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path3 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.ls4 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path4 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
    def forward(self, q: Tensor, kv: Tensor, q_pos=None, kv_pos=None) -> Tuple[Tensor, Tensor]:
        """
        Self-attention over ``q`` and over ``kv`` (separately), cross-attention from ``q``
        to ``kv``, then the MLP. Returns ``(x, kv)`` where both streams have been updated.
        """
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos))
def attn_residual_func2(x: Tensor, pos=None) -> Tensor:
return self.ls4(self.attn2(self.norm3(x), pos=pos))
def cross_attn_residual_func(q: Tensor, kv: Tensor, q_pos=None, kv_pos=None) -> Tensor:
return self.ls2(self.cross_attn(q=self.qnorm1(q), kv=self.kvnorm1(kv), q_pos=q_pos, kv_pos=kv_pos))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls3(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
q = drop_add_residual_stochastic_depth(
q, residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
pos=q_pos,
)
kv = drop_add_residual_stochastic_depth(
kv, residual_func=attn_residual_func2,
sample_drop_ratio=self.sample_drop_ratio,
pos=kv_pos,
)
            x = drop_add_residual_stochastic_depth_cross(
                q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos,
                # the helper calls residual_func with keyword arguments (kv=, q_pos=, kv_pos=),
                # so pass the residual function directly rather than a positional-only lambda
                residual_func=cross_attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
x = drop_add_residual_stochastic_depth(
x, residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
q = q + self.drop_path1(attn_residual_func(q, pos=q_pos))
kv = kv + self.drop_path4(attn_residual_func2(kv, pos=kv_pos))
x = q + self.drop_path2(cross_attn_residual_func(q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos))
x = x + self.drop_path3(ffn_residual_func(x))
else:
q = q + attn_residual_func(q, pos=q_pos)
kv = kv + attn_residual_func2(kv, pos=kv_pos)
x = q + cross_attn_residual_func(q=q, kv=kv, q_pos=q_pos, kv_pos=kv_pos)
x = x + ffn_residual_func(x)
return x, kv
def drop_add_residual_stochastic_depth_cross(
q: Tensor,
kv: Tensor,
    residual_func: Callable[..., Tensor],
sample_drop_ratio: float = 0.0,
q_pos=None,
kv_pos=None,
) -> Tensor:
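    """Batch-level stochastic depth for a cross-attention residual branch.

    A random subset of the batch is kept, the residual is computed only on that subset,
    scaled by ``b / subset_size`` to preserve its expectation, and added back into ``q``
    at the sampled indices. ``kv`` is only consumed, never modified.
    """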
# 1) extract subset using permutation
b, n, d = q.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=q.device))[:sample_subset_size]
q_subset = q[brange]
kv_subset = kv[brange]
    # 2) apply residual_func to get residual
    # if necessary, restrict the (rope) position inputs to the sampled subset
    q_pos = q_pos[brange] if q_pos is not None else None
    kv_pos = kv_pos[brange] if kv_pos is not None else None
    residual = residual_func(q_subset, kv=kv_subset, q_pos=q_pos, kv_pos=kv_pos)
    q_flat = q.flatten(1)  # index_add targets the full batch, not the subset
    residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
q_plus_residual = torch.index_add(q_flat, 0, brange, residual.to(dtype=q.dtype), alpha=residual_scale_factor)
return q_plus_residual.view_as(q)
def drop_add_residual_stochastic_depth(
x: Tensor,
    residual_func: Callable[..., Tensor],
sample_drop_ratio: float = 0.0,
pos=None,
) -> Tensor:
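    """Batch-level stochastic depth: compute the residual on a random subset of the batch,
    rescale it by ``b / subset_size`` so its expectation is unchanged, and add it back
    into ``x`` at the sampled indices.
    """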
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
if pos is not None:
        # if necessary, restrict the (rope) position input to the sampled subset
pos = pos[brange]
residual = residual_func(x_subset, pos=pos)
else:
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
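    """Add ``residual`` (scaled by ``residual_scale_factor``) into ``x`` at the rows in
    ``brange``; when a LayerScale ``scaling_vector`` is given, use xFormers' fused
    ``scaled_index_add`` instead."""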
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> List[Tensor]:
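    """List variant of batch-level stochastic depth: each tensor in ``x_list`` gets its own
    sampled subset, the subsets are concatenated and processed in one fused call (using a
    block-diagonal attention bias), and the residuals are added back per tensor.
    """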
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
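    """Block variant that processes a list of token sequences of different lengths as one
    concatenated sequence, using an xFormers block-diagonal attention bias so samples do not
    attend across list entries. Requires xFormers and a MemEffAttention-style attention."""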
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
        x_list contains a list of token tensors (possibly with different sequence lengths)
        to nest together and run as a single concatenated batch
"""
        assert isinstance(self.attn, MemEffAttention)  # MemEffAttention is expected from .attention, as in the referenced DINOv2 code
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
if not XFORMERS_AVAILABLE:
raise AssertionError("xFormers is required for using nested tensors")
return self.forward_nested(x_or_x_list)
else:
            raise AssertionError(f"unsupported input type: {type(x_or_x_list)}")
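

if __name__ == "__main__":
    # Minimal smoke test (a sketch): it assumes the sibling .attention / .mlp modules accept
    # the constructor arguments used above and that Attention / CrossAttention tolerate
    # pos=None when no rope module is configured.
    blk = Block(dim=64, num_heads=4)
    x = torch.randn(2, 16, 64)
    print(blk(x).shape)  # expected: torch.Size([2, 16, 64])

    cross = CrossBlock(dim=64, num_heads=4)
    q = torch.randn(2, 8, 64)
    kv = torch.randn(2, 16, 64)
    out, kv_out = cross(q, kv)
    print(out.shape, kv_out.shape)  # expected: torch.Size([2, 8, 64]) torch.Size([2, 16, 64])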