""" model/block.py Single Transformer Block (pre-norm LLaMA-style). Pre-Norm vs Post-Norm: GPT-2 (post-norm): x = x + Attention(LayerNorm(x)) <- less stable LLaMA (pre-norm): x = LayerNorm(x); x = x + Attention(x) <- more stable We use PRE-NORM with RMSNorm for training stability at scale. Block structure: x -> RMSNorm -> CausalSelfAttention -> (+residual) -> RMSNorm -> SwiGLU MLP -> (+residual) -> output Note: Residual connections bypass both norm and sublayer, which allows gradients to flow directly to earlier layers during backprop. """ import torch import torch.nn as nn from model.config import ModelConfig from model.norm import RMSNorm from model.attention import CausalSelfAttention from model.mlp import SwiGLU class TransformerBlock(nn.Module): def __init__(self, config: ModelConfig): super().__init__() # Pre-attention norm self.norm_attn = RMSNorm(config.d_model) # Causal self-attention with RoPE self.attn = CausalSelfAttention(config) # Pre-FFN norm self.norm_mlp = RMSNorm(config.d_model) # SwiGLU feed-forward self.mlp = SwiGLU(config) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x : (B, T, d_model) Returns: x : (B, T, d_model) """ # Attention sub-layer with residual x = x + self.attn(self.norm_attn(x)) # FFN sub-layer with residual x = x + self.mlp(self.norm_mlp(x)) return x # ------------------------------------------------------------------ # # QUICK CHECK # ------------------------------------------------------------------ # if __name__ == "__main__": from model.config import SLLM_100M cfg = SLLM_100M block = TransformerBlock(cfg) n = sum(p.numel() for p in block.parameters()) print(f"Block params : {n/1e6:.3f}M") B, T = 2, 64 x = torch.randn(B, T, cfg.d_model) out = block(x) print(f"Input shape : {x.shape}") print(f"Output shape : {out.shape}") assert out.shape == x.shape, "Shape mismatch!" print("PASS")