| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from typing import Callable |
| from torch import Tensor, nn |
|
|
| from .attention import Attention, LayerScale, Mlp |
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: an attention sub-layer, then a feed-forward sub-layer.

    ``forward`` evaluates::

        x = x + ls1(attn(norm1(x), pos=pos, attn_mask=attn_mask))
        x = x + ls2(mlp(norm2(x)))

    Args:
        dim: Channel width. Used by ``norm_layer(dim)`` for both norms, as the
            first positional argument of ``attn_class``, as ``in_features`` of
            ``ffn_layer``, and by each ``LayerScale``.
        num_heads: Forwarded to ``attn_class``.
        mlp_ratio: The feed-forward hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to ``ffn_layer``
            as ``drop``.
        attn_drop: Forwarded to ``attn_class`` as ``attn_drop``.
        init_values: If truthy, each residual branch is wrapped in
            ``LayerScale(dim, init_values=init_values)``; otherwise an
            ``nn.Identity`` is used.
        drop_path: Accepted for signature compatibility but never read here;
            ``sample_drop_ratio`` is fixed at ``0.0`` and ``forward`` always
            adds both residual branches.
        act_layer: Forwarded to ``ffn_layer`` as ``act_layer``.
        norm_layer: Factory called as ``norm_layer(dim)`` for ``norm1``/``norm2``.
        attn_class: Factory for the attention sub-layer.
        ffn_layer: Factory for the feed-forward sub-layer.
        qk_norm: Forwarded to ``attn_class``.
        rope: Forwarded to ``attn_class`` as ``rope``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()

        def make_layer_scale() -> nn.Module:
            # Truthiness test, not ``is not None``: 0 / 0.0 also yields Identity.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        # Keep the creation order norm1, attn, ls1, norm2, mlp, ls2: it fixes the
        # submodule registration order and, if any factory initialises weights
        # randomly, the order in which RNG draws are consumed.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )
        self.ls1 = make_layer_scale()

        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = make_layer_scale()

        # Fixed at 0.0 regardless of ``drop_path``; nothing in this block reads it.
        self.sample_drop_ratio = 0.0

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        """Apply the attention residual, then the feed-forward residual.

        Args:
            x: Input tokens; presumably (batch, tokens, dim) — TODO confirm
                against ``attn_class``/``ffn_layer``.
            pos: Passed unchanged to ``self.attn`` as ``pos``.
            attn_mask: Passed unchanged to ``self.attn`` as ``attn_mask``.

        Returns:
            ``x`` with both residual branches added.
        """
        attn_branch = self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
        x = x + attn_branch

        ffn_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + ffn_branch
|
|