|
|
""" |
|
|
RMSNorm implementation for SLM. |
|
|
Pre-norm architecture for stable FP16 training and better quantization. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Unlike LayerNorm, RMSNorm skips the mean-centering step and rescales
    by the root-mean-square of the activations alone, which makes it:

    - cheaper to compute (no mean statistic),
    - more stable under FP16 training,
    - friendlier to post-training quantization.

    Reference: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """Set up the learnable scale and stability constant.

        Args:
            hidden_size: Size of the last (feature) dimension to normalize.
            eps: Small constant added to the mean square for numerical safety.
        """
        super().__init__()
        # Per-feature gain, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` by its RMS over the last dimension and apply the gain.

        Args:
            x: Input tensor of shape [..., hidden_size].

        Returns:
            Tensor of the same shape and dtype as `x`.
        """
        orig_dtype = x.dtype
        # Compute statistics in float32 regardless of input precision;
        # the result is cast back to the caller's dtype at the end.
        h = x.float()
        inv_rms = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalized = h * inv_rms
        return (self.weight * normalized).to(orig_dtype)
|
|
|