# PebbleLM-117M / src/model/normalization.py
# Provenance (from Hugging Face file page): uploaded by nameissakthi,
# commit "Add model architecture code" (27871e7).
"""
RMSNorm implementation for SLM.
Pre-norm architecture for stable FP16 training and better quantization.
"""
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (no mean-centering).

    Unlike LayerNorm, RMSNorm skips the mean-subtraction step and rescales
    by the root-mean-square of the activations alone, which makes it:
    - cheaper to compute (one statistic instead of two),
    - more stable under FP16 training,
    - friendlier to post-training quantization.

    Reference: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """Create an RMSNorm layer.

        Args:
            hidden_size: Size of the last (feature) dimension to normalize.
            eps: Small constant added to the mean-of-squares for numerical
                stability.
        """
        super().__init__()
        # Learned per-feature scale, initialized to identity (all ones).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` by its RMS over the last dimension.

        Args:
            x: Input tensor of shape [..., hidden_size].

        Returns:
            Tensor of the same shape and dtype as ``x``.
        """
        original_dtype = x.dtype
        # Accumulate statistics in float32 so FP16/BF16 inputs don't
        # overflow or lose precision in the squared sum.
        h = x.to(torch.float32)
        mean_square = torch.mean(h * h, dim=-1, keepdim=True)
        normalized = h * torch.rsqrt(mean_square + self.eps)
        # Apply the learned scale, then cast back to the caller's dtype.
        return (self.weight * normalized).to(original_dtype)