# Uploaded via huggingface_hub (commit 2a1ba39, verified) — account: st24hour
from typing import Callable, List, Optional
import torch
from torch import Tensor, nn
from .attention import SelfAttention
from .ffn_layers import SwiGLUFFN
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a feed-forward network.

    Layout per forward pass::

        x -> x + attn(norm1(x)) -> (+ mlp(norm2(.)))

    When ``dual_attention`` is True, a second, independently-parameterized
    attention module (``nope_attn``) is created with the exact same
    configuration as ``attn``; callers select it via
    ``compute_attention_output(..., use_nope_attn=True)``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        ffn_drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = SwiGLUFFN,
        dual_attention: bool = False,
        device=None,
    ) -> None:
        """Build norms, attention module(s), and the FFN.

        Args:
            dim: Embedding width of the token stream.
            num_heads: Number of attention heads.
            ffn_ratio: FFN hidden size as a multiple of ``dim``.
            qkv_bias / proj_bias / ffn_bias: Bias switches forwarded to the
                attention and FFN factories.
            ffn_drop: Dropout rate forwarded to the FFN.
            attn_drop: Dropout rate forwarded to the attention module.
            norm_layer: Factory for the two pre-norm layers.
            attn_class: Factory for the attention module(s).
            ffn_layer: Factory for the feed-forward network.
            dual_attention: If True, also instantiate ``nope_attn``.
            device: Passed through to every submodule factory.
        """
        super().__init__()
        self._dual_attention = bool(dual_attention)

        def build_attention() -> nn.Module:
            # Both attention modules (primary and "nope") share one configuration.
            return attn_class(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                # NOTE(review): attn_drop is forwarded as proj_drop here —
                # confirm against attn_class's signature that this is intended.
                proj_drop=attn_drop,
                device=device,
            )

        self.norm1 = norm_layer(dim)
        self.attn = build_attention()
        self.nope_attn = build_attention() if self._dual_attention else None

        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=hidden_features,
            drop=ffn_drop,
            bias=ffn_bias,
            device=device,
        )

    def compute_attention_output(
        self,
        x: Tensor,
        rope: tuple[Tensor, Tensor] | None = None,
        attn_mask: Tensor | None = None,
        use_nope_attn: bool = False,
    ) -> Tensor:
        """Compute self-attention output without adding the residual.

        Raises:
            RuntimeError: If ``use_nope_attn`` is requested but the block was
                constructed with ``dual_attention=False``.
        """
        if use_nope_attn and self.nope_attn is None:
            raise RuntimeError("nope_attn is not initialized; instantiate with dual_attention=True")
        module = self.nope_attn if use_nope_attn else self.attn
        return module(self.norm1(x), attn_mask=attn_mask, rope=rope)

    def forward_with_attention_output(self, x: Tensor, attn_output: Tensor) -> Tensor:
        """Apply residual connection + FFN given a precomputed attention output."""
        residual = x + attn_output
        return residual + self.mlp(self.norm2(residual))

    def forward(self, x: Tensor, rope: tuple[Tensor, Tensor] | None = None, attn_mask: Tensor | None = None) -> Tensor:
        """Forward for batched tensor inputs only."""
        return self.forward_with_attention_output(
            x,
            self.compute_attention_output(x, rope=rope, attn_mask=attn_mask),
        )