# sllm/model/norm.py
# geeteshcodes's picture
# Initial commit
# 7f974df verified
"""
model/norm.py
RMSNorm — Root Mean Square Layer Normalization.
Used in LLaMA-style transformers instead of standard LayerNorm.
Key difference from LayerNorm:
- No mean subtraction (centering)
- No bias term
- Only re-scales with a single learned gain vector (weight)
- ~40% faster in practice (no mean computation)
Formula:
RMSNorm(x) = x / RMS(x) * weight
where RMS(x) = sqrt( mean(x^2) + eps )
"""
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-feature gain.

    Every vector along the last axis of the input is multiplied by
    ``1 / sqrt(mean(x^2) + eps)`` and then scaled element-wise by ``weight``.
    There is no mean subtraction and no bias term.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        """
        Args:
            d_model : hidden dimension (size of last axis of input)
            eps     : small constant added to mean(x^2) for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable gain, starts at 1

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last axis (no gain applied).

        Args:
            x : tensor whose last axis is the one being normalised

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # keepdim=True keeps a trailing size-1 axis so the factor broadcasts
        # back over the last dimension of x.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last axis and apply the learned gain.

        Args:
            x : tensor whose last axis has size ``d_model``

        Returns:
            Tensor with the same shape as ``x``. The normalised values are
            cast back to ``x``'s dtype before the gain is applied; the final
            multiply then follows PyTorch type promotion, so the result can be
            wider than ``x`` when ``weight`` is stored in a wider dtype.
        """
        # Compute in float32 for stability with half-precision inputs, but do
        # not downcast float64: doing the norm in float32 would silently throw
        # away precision for double inputs (e.g. during gradcheck).
        compute_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
        normed = self._norm(x.to(compute_dtype)).type_as(x)
        return normed * self.weight
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Fixed seed so the smoke test prints the same numbers on every run.
    torch.manual_seed(0)

    # Batch size, sequence length and hidden width of the random test input.
    batch, seq_len, hidden = 2, 16, 768
    x = torch.randn(batch, seq_len, hidden)
    layer = RMSNorm(hidden)
    out = layer(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    print(f"Output dtype : {out.dtype}")

    # Per-vector RMS over the last axis. The gain is still at its initial
    # value of 1 here, so the output RMS should be very close to 1.
    mean_sq_in = x.pow(2).mean(dim=-1)
    mean_sq_out = out.pow(2).mean(dim=-1)
    rms_in = mean_sq_in.sqrt()
    rms_out = mean_sq_out.sqrt()
    print(f"RMS before norm : {rms_in.mean():.3f}")
    print(f"RMS after norm : {rms_out.mean():.3f} (weight=1 so should be ~1.0)")

    unit_rms = torch.ones_like(rms_out)
    if torch.allclose(rms_out, unit_rms, atol=1e-4):
        print("PASS")
    else:
        print("FAIL")