| """ |
| model/block.py |
| |
| Single Transformer Block (pre-norm LLaMA-style). |
| |
| Pre-Norm vs Post-Norm: |
| Post-norm (original Transformer, GPT-1): x = LayerNorm(x + Attention(x)) <- less stable |
| Pre-norm (GPT-2, LLaMA): x = x + Attention(Norm(x)) <- more stable |
| |
| We use PRE-NORM with RMSNorm for training stability at scale. |
| |
| Block structure: |
| x -> RMSNorm -> CausalSelfAttention -> (+residual) |
| -> RMSNorm -> SwiGLU MLP -> (+residual) |
| -> output |
| |
| Note: Residual connections bypass both norm and sublayer, which allows |
| gradients to flow directly to earlier layers during backprop. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from model.config import ModelConfig |
| from model.norm import RMSNorm |
| from model.attention import CausalSelfAttention |
| from model.mlp import SwiGLU |
|
|
|
|
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer built from two residual sublayers.

    Each sublayer normalises its input with RMSNorm, applies the transform
    (causal self-attention first, then the SwiGLU MLP) and adds the result
    back onto the un-normalised input.
    """

    def __init__(self, config: ModelConfig):
        """Create the two norm/sublayer pairs from ``config``.

        Args:
            config : model configuration; ``config.d_model`` sizes both
                     RMSNorm layers and the whole object is handed to the
                     attention and MLP constructors.
        """
        super().__init__()

        width = config.d_model

        # Attention sublayer: normalise, then attend.
        self.norm_attn = RMSNorm(width)
        self.attn = CausalSelfAttention(config)

        # Feed-forward sublayer: normalise, then SwiGLU MLP.
        self.norm_mlp = RMSNorm(width)
        self.mlp = SwiGLU(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sublayers, each with a residual connection.

        Args:
            x : (B, T, d_model)

        Returns:
            Tensor of shape (B, T, d_model)
        """
        # The residual adds use the raw stream, not the normalised tensor,
        # and are out-of-place so the caller's tensor is left untouched.
        attn_out = self.attn(self.norm_attn(x))
        hidden = x + attn_out

        mlp_out = self.mlp(self.norm_mlp(hidden))
        return hidden + mlp_out
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build a single block from the SLLM_100M config, report its
    # parameter count, and confirm a forward pass keeps the input shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    block = TransformerBlock(cfg)

    param_count = 0
    for param in block.parameters():
        param_count += param.numel()
    print(f"Block params : {param_count/1e6:.3f}M")

    batch_size = 2
    seq_len = 64
    x = torch.randn(batch_size, seq_len, cfg.d_model)
    out = block(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    assert out.shape == x.shape, "Shape mismatch!"
    print("PASS")
|
|