| | """Channel-wise RMSNorm for NCHW tensors.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| |
|
| |
|
class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 (or wider) reduction for numerical stability.

    Normalizes across the channel dimension (dim 1) independently at every
    other position: for an input of shape ``(N, C, *)`` each vector
    ``x[n, :, ...]`` is divided by its root-mean-square over ``C``.
    Supports optional per-channel affine weight and bias.

    Args:
        channels: Number of channels ``C``; sizes the affine parameters.
        eps: Added to the mean square before the reciprocal square root.
        affine: If ``True``, learn a per-channel ``weight`` (initialized to 1)
            and ``bias`` (initialized to 0). Otherwise both are registered as
            ``None`` and no affine transform is applied.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def extra_repr(self) -> str:
        """Summarize the configuration for ``repr(module)``."""
        return f"{self.channels}, eps={self._eps}, affine={self.weight is not None}"

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` by its RMS over dim 1, then apply the optional affine.

        Args:
            x: Tensor of shape ``(N, C, *)``. Inputs with fewer than two
                dimensions have no channel axis and are returned unchanged.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If the module is affine and ``x.shape[1]`` differs
                from ``channels``. Without this check a size-1 channel axis
                would silently broadcast against the weight.
        """
        if x.dim() < 2:
            # No channel dimension to normalize over; pass through unchanged.
            return x
        if self.weight is not None and x.shape[1] != self.channels:
            raise ValueError(
                f"expected {self.channels} channels at dim 1, got {x.shape[1]} "
                f"(input shape {tuple(x.shape)})"
            )

        # Upcast BEFORE squaring. Squaring in fp16 overflows to inf for
        # |x| > ~255, which would zero the whole output. Never compute in less
        # than float32, and keep float64 inputs in float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xf = x.to(compute_dtype)
        ms = xf.square().mean(dim=1, keepdim=True)
        y = xf * torch.rsqrt(ms + self._eps)

        # Broadcast per-channel parameters as (1, C, 1, ..., 1).
        shape = (1, -1) + (1,) * (x.dim() - 2)
        if self.weight is not None:
            y = y * self.weight.view(shape).to(dtype=compute_dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=compute_dtype)
        return y.to(dtype=x.dtype)
| |
|