"""Transformer block for the dense SAGE model.""" from __future__ import annotations from typing import Optional import torch from torch import nn from model.attention import GQAAttention from model.config import ModelConfig from model.mlp import SwiGLUMLP from model.rmsnorm import RMSNorm class TransformerBlock(nn.Module): """Pre-norm transformer block with attention and SwiGLU.""" def __init__(self, config: ModelConfig): super().__init__() self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.attn = GQAAttention(config) self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.mlp = SwiGLUMLP(config) def forward( self, hidden_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Forward pass with residual connections.""" attn_output, present = self.attn(self.norm1(hidden_states), cos, sin, past_key_value=past_key_value) hidden_states = hidden_states + attn_output hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states, present