himipo's picture
first
11aa70b
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
from typing import Callable, List, Optional
import torch
from torch import Tensor, nn
from ..utils import cat_keep_shapes, uncat_with_shapes
from .attention import CausalSelfAttention, SelfAttention
from .ffn_layers import Mlp
from .layer_scale import LayerScale # , DropPath
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.accumulated_cache_size_limit = 1024
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: an attention residual followed by an MLP residual.

    Each branch computes ``x + ls(branch(norm(x)))``. While training with ``drop_path > 0``,
    each residual branch is evaluated only on a random subset of
    ``max(int(b * (1 - drop_path)), 1)`` batch elements. Its output is added back into the
    full batch with ``torch.index_add`` and scaled by ``b / subset_size`` (sample-level
    stochastic depth). ``forward`` accepts either a single tensor or a list of tensors.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads, forwarded to ``attn_class``.
        ffn_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Passed as ``proj_drop`` to ``attn_class`` and as ``drop`` to ``ffn_layer``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: Initial LayerScale value. A falsy value (``None``/``0``) replaces
            LayerScale with ``nn.Identity``.
        drop_path: Fraction of the batch skipped by each residual branch while training.
        act_layer: Activation constructor passed to ``ffn_layer``.
        norm_layer: Normalization constructor, called as ``norm_layer(dim)``.
        attn_class: Attention module constructor. It must also provide ``forward_list``.
        ffn_layer: MLP module constructor. It must also provide ``forward_list``.
        mask_k_bias: Forwarded to ``attn_class``.
        device: Forwarded to the attention, MLP and LayerScale constructors.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        # NOTE(review): the norm layers are built without `device`, unlike the other submodules.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        # A falsy init_values (None or 0) disables LayerScale entirely.
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        # Fraction of the batch skipped by each residual branch during training.
        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None:
        """Select the batch rows ``indices`` of a ``(sin, cos)`` rope pair when it is batched.

        Returns ``None`` when ``rope`` is ``None``. Only 4-D ``sin``/``cos`` are indexed, with
        dim 0 treated as the batch. Lower-rank embeddings are shared by every batch element and
        are returned unchanged.
        """
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            # Per-sample rope embedding: keep only the rows of the sampled batch elements.
            return sin[indices], cos[indices]  # [batch, heads, patches, embed_dim]
        # No batch dimension, nothing to index.
        return sin, cos  # [heads, patches, embed_dim] or [patches, embed_dim]

    def _forward(self, x: Tensor, rope=None) -> Tensor:
        """
        This is the reference implementation for a single tensor, matching what is done below for a list.
        We call the list op on [x] instead of this function.
        """
        b, _, _ = x.shape
        sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1)
        # Rescale so the expected residual contribution matches the full-batch case.
        residual_scale_factor = b / sample_subset_size

        if self.training and self.sample_drop_ratio > 0.0:
            indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_1 = x[indices_1]
            rope_subset = self._maybe_index_rope(rope, indices_1)
            residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset)

            # Rows not in indices_1 pass through unchanged (their residual is dropped).
            x_attn = torch.index_add(
                x,
                dim=0,
                source=self.ls1(residual_1),
                index=indices_1,
                alpha=residual_scale_factor,
            )

            # The MLP branch draws its own independent subset.
            indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_2 = x_attn[indices_2]
            residual_2 = self.mlp(self.norm2(x_subset_2))

            x_ffn = torch.index_add(
                x_attn,
                dim=0,
                source=self.ls2(residual_2),
                index=indices_2,
                alpha=residual_scale_factor,
            )
        else:
            x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
            x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))

        return x_ffn

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """
        This list operator concatenates the tokens from the list of inputs together to save
        on the elementwise operations. Torch-compile memory-planning allows hiding the overhead
        related to concat ops.

        Args:
            x_list: Input tensors, one output is produced per input.
            rope_list: Optional per-input rope entries (each ``None`` or a ``(sin, cos)`` pair).
                ``None`` means no rope for any input.

        Raises:
            ValueError: if ``rope_list`` is given and its length differs from ``x_list``.
        """
        if rope_list is not None and len(rope_list) != len(x_list):
            # zip() below would otherwise silently truncate and drop outputs.
            raise ValueError(
                f"rope_list has {len(rope_list)} entries but x_list has {len(x_list)}; they must match"
            )

        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
        residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)]

        if self.training and self.sample_drop_ratio > 0.0:
            indices_1_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)]

            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list

            # Normalize all inputs in one call on the concatenated tokens, then split back.
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)

            x_attn_list = [
                torch.index_add(
                    x,
                    dim=0,
                    source=self.ls1(residual_1),
                    index=indices_1,
                    alpha=residual_scale_factor,
                )
                for x, residual_1, indices_1, residual_scale_factor in zip(
                    x_list, residual_1_list, indices_1_list, residual_scale_factors
                )
            ]

            indices_2_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]

            x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_flat = self.norm2(flattened)
            norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)

            residual_2_list = self.mlp.forward_list(norm2_list)

            x_ffn = [
                torch.index_add(
                    x_attn,
                    dim=0,
                    source=self.ls2(residual_2),
                    index=indices_2,
                    alpha=residual_scale_factor,
                )
                for x_attn, residual_2, indices_2, residual_scale_factor in zip(
                    x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
                )
            ]
        else:
            # zip() needs one rope entry per input; the default None means "no rope anywhere".
            if rope_list is None:
                rope_list = [None] * len(x_list)
            x_out = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
                x_out.append(x_ffn)
            x_ffn = x_out

        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> Tensor | List[Tensor]:
        """Apply the block to a single tensor or to a list of tensors.

        Args:
            x_or_x_list: A ``Tensor`` or a ``list`` of tensors.
            rope_or_rope_list: For a tensor input, its rope entry (or ``None``). For a list
                input, a list with one rope entry per tensor, or ``None`` for no rope.

        Returns:
            A ``Tensor`` for tensor input, or a list of tensors for list input.

        Raises:
            AssertionError: if ``x_or_x_list`` is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            # Route through the list op so both input kinds share one implementation.
            # Reference equivalent: self._forward(x_or_x_list, rope=rope_or_rope_list)
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        if isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None] * len(x_or_x_list)
            # Reference equivalent: [self._forward(x, rope=rope) for x, rope in zip(x_or_x_list, rope_or_rope_list)]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError(f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}")
class CausalSelfAttentionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
ffn_ratio: float = 4.0,
ls_init_value: Optional[float] = None,
is_causal: bool = True,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
dropout_prob: float = 0.0,
):
super().__init__()
self.dim = dim
self.is_causal = is_causal
self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
self.attention_norm = norm_layer(dim)
self.attention = CausalSelfAttention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob)
self.ffn_norm = norm_layer(dim)
ffn_hidden_dim = int(dim * ffn_ratio)
self.feed_forward = Mlp(
in_features=dim,
hidden_features=ffn_hidden_dim,
drop=dropout_prob,
act_layer=act_layer,
)
self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
def init_weights(
self,
init_attn_std: float | None = None,
init_proj_std: float | None = None,
init_fc_std: float | None = None,
factor: float = 1.0,
) -> None:
init_attn_std = init_attn_std or (self.dim**-0.5)
init_proj_std = init_proj_std or init_attn_std * factor
init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
self.attention.init_weights(init_attn_std, init_proj_std)
self.attention_norm.reset_parameters()
nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
self.ffn_norm.reset_parameters()
def forward(
self,
x: torch.Tensor,
):
x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
return x_ffn