"""
model/norm.py

RMSNorm — Root Mean Square Layer Normalization.
Used in LLaMA-style transformers instead of standard LayerNorm.

Key differences from LayerNorm:
  - No mean subtraction (centering)
  - No bias term
  - Only re-scales with a single learned gain vector (weight)
  - ~40% faster in practice (no mean computation)

Formula:
    RMSNorm(x) = x / RMS(x) * weight
    where RMS(x) = sqrt( mean(x^2) + eps )
"""

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Scale the last axis of a tensor to unit RMS, then apply a learned gain.

    Attributes:
        eps    : constant added to the mean square before the reciprocal
                 square root is taken.
        weight : learnable per-feature gain of shape (d_model,), initialised
                 to ones.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        """
        Args:
            d_model : hidden dimension (size of the last axis of the input)
            eps     : small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Gain starts at 1 so a freshly built module only normalises.
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by its RMS along the last axis (no gain).

        ``x`` has shape (..., d_model). The reduction keeps the last axis as
        size 1 so the scale broadcasts back over ``x``.
        """
        mean_square = torch.pow(x, 2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and multiply by the learned gain.

        The normalisation itself runs in float32 for stability; the
        normalised tensor is then cast back to the dtype of ``x`` before the
        multiplication by ``weight``.
        """
        normalized = self._norm(x.float())
        normalized = normalized.type_as(x)
        return normalized * self.weight


# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
def _quick_check() -> None:
    """Run RMSNorm on random input and print a small sanity report."""
    torch.manual_seed(0)

    batch, seq_len, hidden = 2, 16, 768
    sample = torch.randn(batch, seq_len, hidden)
    layer = RMSNorm(hidden)
    result = layer(sample)

    print(f"Input shape : {sample.shape}")
    print(f"Output shape : {result.shape}")
    print(f"Output dtype : {result.dtype}")

    # Verify: each vector should have approximately unit RMS after the norm
    # (the gain is still all ones at this point).
    rms_before = sample.pow(2).mean(dim=-1).sqrt()
    rms_after = result.pow(2).mean(dim=-1).sqrt()
    print(f"RMS before norm : {rms_before.mean():.3f}")
    print(f"RMS after norm : {rms_after.mean():.3f} (weight=1 so should be ~1.0)")

    target = torch.ones_like(rms_after)
    is_unit_rms = torch.allclose(rms_after, target, atol=1e-4)
    print("PASS" if is_unit_rms else "FAIL")


if __name__ == "__main__":
    _quick_check()