|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable |
|
|
|
|
|
from torch import Tensor, nn |
|
|
|
|
|
from .attention import Attention, LayerScale, Mlp |
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm residual block: an attention branch then a feed-forward branch.

    Each branch normalises its input, runs its sub-layer, passes the result
    through either a ``LayerScale`` or an identity, and adds it back onto the
    running tensor.

    Args:
        dim: Size handed to both norm layers, to ``attn_class`` (positionally),
            to ``ffn_layer`` as ``in_features`` and to ``LayerScale``.
        num_heads: Forwarded to ``attn_class``.
        mlp_ratio: The feed-forward hidden size is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to
            ``ffn_layer`` as ``drop``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: When truthy, both branches get a
            ``LayerScale(dim, init_values=init_values)``; otherwise (``None``,
            ``0`` ...) they get ``nn.Identity``.
        drop_path: Accepted for signature compatibility; not read by this
            block (``sample_drop_ratio`` is always set to ``0.0``).
        act_layer: Forwarded to ``ffn_layer``.
        norm_layer: Factory called as ``norm_layer(dim)`` for each branch.
        attn_class: Factory for the attention sub-layer.
        ffn_layer: Factory for the feed-forward sub-layer.
        qk_norm: Forwarded to ``attn_class``.
        rope: Forwarded to ``attn_class`` unchanged.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()

        def build_layer_scale() -> nn.Module:
            # Truthiness test on purpose: None and 0 both disable scaling.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        # Sub-modules are created in a fixed order (norm1, attn, ls1, norm2,
        # mlp, ls2) so registration order stays stable.

        # --- attention branch ---
        self.norm1 = norm_layer(dim)
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )
        self.attn = attn_class(dim, **attn_kwargs)
        self.ls1 = build_layer_scale()

        # --- feed-forward branch ---
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = build_layer_scale()

        # Fixed at zero irrespective of the ``drop_path`` argument.
        self.sample_drop_ratio = 0.0

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        """Run both residual branches and return the updated tensor.

        Args:
            x: Input tensor.
            pos: Passed to the attention sub-layer as ``pos``.
            attn_mask: Passed to the attention sub-layer as ``attn_mask``.

        Returns:
            ``x`` plus the attention branch output, plus the feed-forward
            branch output computed from that intermediate sum.
        """
        # Attention branch: norm -> attention -> scale -> residual add.
        attn_out = self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)
        x = x + self.ls1(attn_out)

        # Feed-forward branch, applied to the already-updated tensor.
        ffn_out = self.mlp(self.norm2(x))
        x = x + self.ls2(ffn_out)
        return x
|
|
|