StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
"""Normalization layers for Gamma Space Model."""
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal of its root-mean-square over the last
    dimension and then by a learned per-feature gain. Unlike standard layer
    normalization there is no mean subtraction and no bias term.
    Inspired by T5's implementation.

    Args:
        d_model: Size of the last (feature) dimension; also the length of
            the learned ``weight`` vector.
        eps: Constant added to the mean of squares before the square root
            so the division stays finite for all-zero inputs.
    """

    # Dtypes whose range/precision is too small to square activations safely.
    _LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 so the layer starts as pure
        # RMS scaling.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMSNorm: ``x * weight / sqrt(mean(x ** 2, dim=-1) + eps)``.

        Args:
            x: Tensor whose last dimension broadcasts against ``weight``,
                e.g. ``(batch, seq_len, d_model)`` or ``(batch, d_model)``.

        Returns:
            Tensor with the same shape as ``x``. Its dtype is the usual
            promotion of ``x.dtype`` and ``weight.dtype``.
        """
        # Dtype the plain `x * weight / rms` expression would produce; the
        # result is cast back to it so callers see no dtype change.
        out_dtype = torch.result_type(x, self.weight)

        # Squaring float16 values overflows to inf once |x| exceeds ~255,
        # which would make the RMS inf and the output silently zero.
        # Accumulate half-precision inputs in float32 instead.
        if x.dtype in self._LOW_PRECISION_DTYPES:
            x = x.float()

        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return (x * self.weight / rms).to(out_dtype)
class LayerNorm(nn.Module):
    """Thin wrapper that delegates to ``torch.nn.LayerNorm``.

    Normalizes over a last dimension of size ``d_model``. The wrapped
    module is stored as ``self.norm``.

    Args:
        d_model: Size of the feature dimension to normalize over.
        eps: Numerical-stability constant forwarded to
            ``torch.nn.LayerNorm`` (this wrapper defaults to ``1e-6``).
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape=d_model, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized by the wrapped ``nn.LayerNorm``."""
        normalized = self.norm(x)
        return normalized