"""
model/norm.py

RMSNorm — Root Mean Square Layer Normalization.
Used in LLaMA-style transformers instead of standard LayerNorm.

Key differences from LayerNorm:
  - No mean subtraction (centering)
  - No bias term
  - Only re-scales with a single learned gain vector (weight)
  - ~40% faster in practice (no mean computation)

Formula:
    RMSNorm(x) = x / RMS(x) * weight
    where RMS(x) = sqrt( mean(x^2) + eps )
"""
|
|
import torch
import torch.nn as nn
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization over the last axis.

    Each vector along the last axis is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then scaled element-wise by a learned
    gain vector. There is no mean subtraction and no bias term.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        """
        Args:
            d_model : hidden dimension (size of last axis of input)
            eps     : small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Gain is initialised to ones, so a fresh layer only normalizes.
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last axis.

        ``eps`` is added to the mean of squares *before* the reciprocal
        square root, so an all-zero vector maps to zeros rather than NaN.
        """
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last axis and apply the learned gain.

        Args:
            x : input tensor whose last axis must broadcast against
                ``weight`` (i.e. have size ``d_model``)

        Returns:
            Tensor with the same shape as ``x``. The normalization itself is
            computed in float32 and cast back to ``x``'s dtype; the final
            multiplication by ``weight`` then follows PyTorch type promotion,
            so a float16 input with a float32 ``weight`` yields float32.
        """
        # Upcast before squaring so low-precision inputs are reduced in
        # float32, then return to the caller's dtype before applying the gain.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
|
|
|
|
| |
| |
| |
|
|
# Smoke test: run this file directly to check shapes, dtype and that the
# per-vector RMS of the output is ~1 when the gain is still all ones.
if __name__ == "__main__":
    torch.manual_seed(0)  # fixed seed so the printed numbers are reproducible
    B, T, D = 2, 16, 768
    x = torch.randn(B, T, D)
    norm = RMSNorm(D)

    out = norm(x)
    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    print(f"Output dtype : {out.dtype}")

    # RMS is taken over the last axis, giving one value per (B, T) position.
    rms_before = x.pow(2).mean(dim=-1).sqrt()
    rms_after = out.pow(2).mean(dim=-1).sqrt()
    print(f"RMS before norm : {rms_before.mean():.3f}")
    print(f"RMS after norm : {rms_after.mean():.3f} (weight=1 so should be ~1.0)")
    # eps makes the output RMS very slightly below 1, hence the tolerance.
    print("PASS" if torch.allclose(rms_after, torch.ones_like(rms_after), atol=1e-4) else "FAIL")
|
|