from typing import Callable

import torch
from torch import Tensor, nn

from .attention import SelfAttention
from .ffn_layers import SwiGLUFFN


class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward
    network, each wrapped in a residual connection. With ``dual_attention=True``
    a second attention module (``nope_attn``) is created, selected by passing
    ``use_nope_attn=True`` to :meth:`compute_attention_output`.
    """

def __init__(
self,
dim: int,
num_heads: int,
ffn_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
ffn_drop: float = 0.0,
attn_drop: float = 0.0,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = SelfAttention,
ffn_layer: Callable[..., nn.Module] = SwiGLUFFN,
dual_attention: bool = False,
device: torch.device | None = None,
) -> None:
super().__init__()
self._dual_attention = bool(dual_attention)
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
proj_drop=attn_drop,
device=device,
)
        # Optional second attention module, used only when
        # compute_attention_output(..., use_nope_attn=True) is called.
        self.nope_attn: nn.Module | None = None
        if self._dual_attention:
            self.nope_attn = attn_class(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                proj_drop=attn_drop,
                device=device,
            )
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * ffn_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
drop=ffn_drop,
bias=ffn_bias,
device=device,
)

    def compute_attention_output(
self,
x: Tensor,
rope: tuple[Tensor, Tensor] | None = None,
attn_mask: Tensor | None = None,
use_nope_attn: bool = False,
) -> Tensor:
"""Compute self-attention output without adding the residual."""
if use_nope_attn:
if self.nope_attn is None:
raise RuntimeError("nope_attn is not initialized; instantiate with dual_attention=True")
attn_module = self.nope_attn
else:
attn_module = self.attn
return attn_module(self.norm1(x), attn_mask=attn_mask, rope=rope)

    def forward_with_attention_output(self, x: Tensor, attn_output: Tensor) -> Tensor:
"""Apply residual connection + FFN given a precomputed attention output."""
x_attn = x + attn_output
x_ffn = x_attn + self.mlp(self.norm2(x_attn))
return x_ffn

    def forward(
        self,
        x: Tensor,
        rope: tuple[Tensor, Tensor] | None = None,
        attn_mask: Tensor | None = None,
    ) -> Tensor:
"""Forward for batched tensor inputs only."""
attn_output = self.compute_attention_output(x, rope=rope, attn_mask=attn_mask)
        return self.forward_with_attention_output(x, attn_output)
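

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes that
    # SelfAttention and SwiGLUFFN accept the constructor arguments used above
    # and that the attention module can be called with rope=None. Shapes are
    # illustrative only.
    block = SelfAttentionBlock(dim=64, num_heads=4, dual_attention=True)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)

    # Standard path: forward() fuses attention + residual + FFN.
    out = block(x)
    print(out.shape)  # torch.Size([2, 16, 64])

    # Split path: compute the secondary-branch attention output explicitly,
    # then apply the residual connection and FFN.
    attn_out = block.compute_attention_output(x, use_nope_attn=True)
    out_nope = block.forward_with_attention_output(x, attn_out)
    print(out_nope.shape)  # torch.Size([2, 16, 64])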