from typing import Callable

from torch import Tensor, nn

from .attention import SelfAttention
from .ffn_layers import SwiGLUFFN


class SelfAttentionBlock(nn.Module):
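    """Pre-norm transformer block: self-attention followed by a feed-forward layer.

    If ``dual_attention=True``, a second, independently parameterized attention
    module (``nope_attn``) is created; callers select it with ``use_nope_attn``
    in ``compute_attention_output``.
    """
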
    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        ffn_drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = SwiGLUFFN,
        dual_attention: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self._dual_attention = bool(dual_attention)
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            proj_drop=attn_drop,
            device=device,
        )
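        # Optional second attention branch with its own weights; "NoPE" here
        # presumably means attention without (rotary) positional embeddings.
        # It is only instantiated when dual_attention=True.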
        self.nope_attn = (
            attn_class(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                proj_drop=attn_drop,
                device=device,
            )
            if self._dual_attention
            else None
        )
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            drop=ffn_drop,
            bias=ffn_bias,
            device=device,
        )

    def compute_attention_output(
        self,
        x: Tensor,
        rope: tuple[Tensor, Tensor] | None = None,
        attn_mask: Tensor | None = None,
        use_nope_attn: bool = False,
    ) -> Tensor:
        """Compute self-attention output without adding the residual."""
        if use_nope_attn:
            if self.nope_attn is None:
                raise RuntimeError("nope_attn is not initialized; instantiate with dual_attention=True")
            attn_module = self.nope_attn
        else:
            attn_module = self.attn
        return attn_module(self.norm1(x), attn_mask=attn_mask, rope=rope)

    def forward_with_attention_output(self, x: Tensor, attn_output: Tensor) -> Tensor:
        """Apply residual connection + FFN given a precomputed attention output."""
        x_attn = x + attn_output
        x_ffn = x_attn + self.mlp(self.norm2(x_attn))
        return x_ffn

    def forward(
        self,
        x: Tensor,
        rope: tuple[Tensor, Tensor] | None = None,
        attn_mask: Tensor | None = None,
    ) -> Tensor:
        """Forward for batched tensor inputs only."""
        attn_output = self.compute_attention_output(x, rope=rope, attn_mask=attn_mask)
        return self.forward_with_attention_output(x, attn_output)
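

# Minimal usage sketch (hyperparameters and shapes below are illustrative
# assumptions, not taken from the original configuration):
#
#     import torch
#     block = SelfAttentionBlock(dim=768, num_heads=12, dual_attention=True)
#     x = torch.randn(2, 196, 768)  # (batch, tokens, dim)
#     y = block(x)                  # attention + FFN with residual connections
#     attn_out = block.compute_attention_output(x, use_nope_attn=True)
#     y_nope = block.forward_with_attention_output(x, attn_out)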