# AETHER-Micro-0.5B / normalization.py
# Uploaded by Be2Jay
# Upload AETHER-Micro 0.5B Phase 1 checkpoint (Step 57000)
# Commit de40e7d (verified)
#!/usr/bin/env python3
"""
AETHER-Micro Normalization
RMSNorm implementation
"""
import torch
import torch.nn as nn
class AETHERMicroRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Divides the input by the root-mean-square of its last dimension and
    multiplies the result by a learnable per-feature ``weight``. There is no
    mean subtraction and no bias term.

    Reference: https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Number of elements in the learnable ``weight`` vector,
            which is initialised to ones. It is multiplied against the
            normalized input, so it must broadcast with the input's last
            dimension.
        eps: Constant added to the mean square before taking the reciprocal
            square root, so an all-zero input does not divide by zero.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` over its last dimension.

        Args:
            hidden_states: Input tensor; normalization runs over dim ``-1``.

        Returns:
            A tensor of the same shape as the input. The normalized values
            are cast back to the input dtype before being multiplied by
            ``weight``, so the final dtype follows PyTorch type promotion
            between ``weight`` and the input dtype.
        """
        input_dtype = hidden_states.dtype
        # Do the reduction in float32 regardless of the incoming dtype.
        hidden_states = hidden_states.to(torch.float32)
        # Despite the name, this is the uncentred mean of squares (RMS^2),
        # not a variance: no mean is subtracted.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self) -> str:
        """Return the weight shape and epsilon for ``repr``/``print(model)``."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"