""" RMSNorm implementation for SLM. Pre-norm architecture for stable FP16 training and better quantization. """ import torch import torch.nn as nn class RMSNorm(nn.Module): """Root Mean Square Layer Normalization. RMSNorm is computationally simpler than LayerNorm as it doesn't compute mean statistics. This makes it: - Faster to compute - More stable in FP16 - Better for quantization Reference: https://arxiv.org/abs/1910.07467 """ def __init__(self, hidden_size: int, eps: float = 1e-6): """Initialize RMSNorm. Args: hidden_size: The size of the hidden dimension eps: Small constant for numerical stability """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply RMS normalization. Args: x: Input tensor of shape [..., hidden_size] Returns: Normalized tensor of same shape """ # Compute RMS: sqrt(mean(x^2)) # Use float32 for numerical stability, then cast back input_dtype = x.dtype x = x.float() variance = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return (self.weight * x).to(input_dtype)