File size: 771 Bytes
de40e7d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | #!/usr/bin/env python3
"""
AETHER-Micro Normalization
RMSNorm 구현
"""
import torch
import torch.nn as nn
class AETHERMicroRMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization
Reference: https://arxiv.org/abs/1910.07467
"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
|