| #!/usr/bin/env python3 | |
| """ | |
| AETHER-Micro Normalization | |
| RMSNorm implementation | |
| """ | |
| import torch | |
| import torch.nn as nn | |
class AETHERMicroRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Scales the input by the reciprocal root-mean-square taken over its last
    dimension, then multiplies by a learnable per-feature gain. No mean is
    subtracted and there is no bias term.

    Reference: https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the gain parameter and store the stabilizing epsilon.

        Args:
            hidden_size: Size passed to ``torch.ones`` for the gain; presumably
                the length of the last dimension of the tensors given to
                ``forward`` (TODO confirm against callers).
            eps: Constant added to the mean square before the reciprocal
                square root is taken.
        """
        super().__init__()
        # The gain is initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``hidden_states`` RMS-normalized over its last dimension.

        The statistics are computed in float32. The normalized tensor is cast
        back to the dtype of the input *before* it is multiplied by the gain,
        so the dtype of the result follows the usual promotion between
        ``self.weight`` and the input dtype.
        """
        original_dtype = hidden_states.dtype
        # Accumulate the squared mean in float32 regardless of the input dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalized