sage / model /rmsnorm.py
sage002's picture
feat: rewrite SAGE 1B architecture and replace legacy repo contents
ef18673 verified
raw
history blame contribute delete
785 Bytes
"""RMSNorm implementation used by SAGE."""
from __future__ import annotations
import torch
from torch import nn
class RMSNorm(nn.Module):
    """Root mean square normalization with float32 accumulation.

    Each vector along the last dimension is divided by the square root of
    its mean squared value (plus ``eps``) and then scaled element-wise by a
    learnable ``weight`` that is initialised to ones.

    Args:
        dim: Size of the last dimension of the inputs; also the length of
            the learnable ``weight`` vector.
        eps: Constant added to the mean square before the reciprocal square
            root, keeping the result finite for all-zero rows.
    """

    def __init__(self, dim: int, eps: float = 1.0e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the last dimension and cast back to the input dtype.

        Args:
            x: Tensor with at least two dimensions; normalization runs over
                the last one.

        Returns:
            A tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.ndim < 2:
            raise ValueError("RMSNorm expects at least 2 dimensions.")
        # Upcast once and reuse, so the statistics are accumulated in float32
        # even when the activations arrive in a reduced-precision dtype.
        x_fp32 = x.float()
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = x_fp32 * torch.rsqrt(variance + self.eps)
        scaled = normalized.to(dtype=x.dtype) * self.weight
        # Multiplying by a weight of a wider dtype (e.g. a float32 parameter
        # with bfloat16 activations) promotes the product; cast again so the
        # output dtype always matches the input. This is a no-op when the
        # dtypes already agree.
        return scaled.to(dtype=x.dtype)