| from __future__ import annotations |
|
|
| from collections.abc import Sequence |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
# Public API of this module: the names exported by a star-import.
__all__ = [
    "ChannelWiseRMSNorm",
    "GlobalRMSNorm",
    "GroupNormF32",
    "LayerNorm",
    "LayerNorm2d",
    "RMSNorm",
    "global_rms_norm",
    "row_norm",
]
|
|
|
|
# Reduced-precision floating dtypes that ``_cast_to_float32`` upcasts to float32.
_HALF_PRECISION_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)
|
|
|
|
def _cast_to_float32(x: Tensor) -> tuple[Tensor, torch.dtype]:
    """Upcast half-precision tensors to float32 and report the input dtype.

    Args:
        x: Tensor to prepare for float32 computation.

    Returns:
        A ``(tensor, original_dtype)`` pair. ``tensor`` is ``x.float()`` when
        ``x.dtype`` is listed in ``_HALF_PRECISION_DTYPES``; otherwise it is
        ``x`` itself, untouched. ``original_dtype`` is always ``x.dtype``.
    """
    original = x.dtype
    needs_upcast = original in _HALF_PRECISION_DTYPES
    return (x.float() if needs_upcast else x), original
|
|
|
|
def _restore_dtype(x: Tensor, dtype: torch.dtype) -> Tensor:
    """Return ``x`` in ``dtype``, converting only when the dtypes differ.

    Args:
        x: Tensor whose dtype may need to be changed.
        dtype: Target dtype.

    Returns:
        ``x`` itself when it already has ``dtype``; otherwise ``x.to(dtype)``.
    """
    if x.dtype != dtype:
        return x.to(dtype)
    return x
|
|
|
|
| class RMSNorm(nn.Module): |
| """Thin wrapper around ``torch.nn.RMSNorm`` that preserves our API. |
| |
| - Keeps an ``_eps`` attribute used by tests. |
| - Maps ``affine`` -> ``elementwise_affine``. |
| - Delegates all compute to the native implementation. |
| |
| Notes on precision |
| - PyTorch ≥ 2.8 computes RMSNorm reductions in ``opmath`` dtype |
| (float32 for float16/bfloat16) internally, then restores the input dtype. |
| """ |
|
|
| def __init__(self, dim: int, eps: float = 1e-6, affine: bool = True) -> None: |
| super().__init__() |
| self._eps: float = float(eps) |
| self._impl: nn.RMSNorm = nn.RMSNorm( |
| dim, eps=self._eps, elementwise_affine=affine |
| ) |
| self._dim: int = int(dim) |
|
|
| @property |
| def weight(self) -> Tensor | None: |
| return self._impl.weight |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| """Apply RMSNorm while avoiding dtype-mismatch warnings under AMP. |
| |
| When inputs are bfloat16/float16 under autocast and the stored affine |
| weight is float32 (common when model weights remain FP32), PyTorch emits |
| a warning about mismatched dtypes and disables the fused path. |
| |
| We pass a view of the weight cast to the input dtype into the functional |
| RMSNorm to enable the fused implementation without changing the |
| parameter storage dtype (which remains FP32 for stability). |
| """ |
| |
| w: Tensor | None = self._impl.weight |
| w_cast = w.to(dtype=x.dtype) if w is not None else None |
| |
| return F.rms_norm(x, (self._dim,), w_cast, self._eps) |
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """``torch.nn.LayerNorm`` with a ``1e-6`` default epsilon and an ``_eps`` field.

    Notes on precision
    - Native LayerNorm kernels accumulate statistics in ``opmath`` dtype
      (float32 for float16/bfloat16) before casting results back.

    Args:
        normalized_shape: Either a single int or a sequence of extents; every
            entry of a sequence is coerced with ``int()`` before being handed
            to the base class.
        eps: Stabilizer passed to the base class and mirrored in ``_eps``.
        elementwise_affine: Forwarded unchanged to the base class.
    """

    def __init__(
        self,
        normalized_shape: int | Sequence[int],
        eps: float = 1e-6,
        elementwise_affine: bool = True,
    ) -> None:
        if isinstance(normalized_shape, int):
            shape: int | list[int] = normalized_shape
        else:
            shape = [int(extent) for extent in normalized_shape]
        super().__init__(shape, eps=eps, elementwise_affine=elementwise_affine)
        self._eps: float = float(eps)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the base-class layer normalization to ``x`` unchanged."""
        return super().forward(x)
|
|
|
|
| |
|
|
|
|
class GroupNormF32(nn.GroupNorm):
    """``torch.nn.GroupNorm`` with a ``1e-6`` default epsilon and an ``_eps`` field.

    Notes on precision
    - Native GroupNorm uses ``opmath`` accumulation (float32 for
      float16/bfloat16) for statistics and fused scale/bias math; results
      are cast back to the input dtype.
    - The ``F32`` suffix is historical: nothing here casts the input. The
      class only overrides the constructor and inherits ``forward`` as-is.

    Args:
        num_groups: Number of channel groups, forwarded to the base class.
        num_channels: Number of input channels, forwarded to the base class.
        eps: Stabilizer passed to the base class and mirrored in ``_eps``.
        affine: Whether learnable per-channel scale and bias are created.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-6,
        affine: bool = True,
    ) -> None:
        super().__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=affine,
        )
        # Plain-float copy of epsilon, kept alongside the base-class ``eps``.
        self._eps: float = float(eps)
|
|
|
|
class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm for channel-second tensors such as NCHW.

    - Normalizes across dimension 1 independently at every other position,
      without reshaping. The mean square is accumulated in float32, while the
      elementwise scaling is done in the input dtype for throughput.
    - Supports optional per-channel affine ``weight`` and ``bias``.

    Args:
        channels: Expected size of dimension 1 of the input.
        eps: Stabilizer added to the mean square before the reciprocal root.
        affine: Whether learnable ``weight`` (ones) and ``bias`` (zeros) of
            shape ``[channels]`` are created. When false, both are registered
            as ``None``.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        self.affine: bool = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` to unit RMS over dimension 1, then apply the affine.

        Args:
            x: Input tensor. Inputs with fewer than two dimensions have no
                channel axis and are returned unchanged.

        Returns:
            A tensor shaped like ``x``.

        Raises:
            ValueError: If ``x.size(1)`` differs from ``channels``.
        """
        if x.dim() < 2:
            return x
        C = x.size(1)
        if self.channels != C:
            raise ValueError(f"ChannelWiseRMSNorm expected C={self.channels}, got {C}")

        # Squaring in float16 overflows to inf once |x| exceeds ~255
        # (float16 max is 65504), which would zero the whole output. Upcast
        # before squaring so the square itself is formed in float32. Other
        # dtypes keep the original path (bfloat16 shares float32's exponent
        # range, so its squares do not overflow here).
        x_stats = x.float() if x.dtype == torch.float16 else x
        ms = torch.mean(torch.square(x_stats), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms.to(dtype=x.dtype)
        if self.affine and self.weight is not None:
            # Broadcast the [C] parameters along dimension 1.
            shape = (1, -1) + (1,) * (x.dim() - 2)
            y = y * self.weight.view(shape).to(dtype=x.dtype)
            if self.bias is not None:
                y = y + self.bias.view(shape).to(dtype=x.dtype)
        return y
|
|
|
|
| def global_rms_norm(x: Tensor, eps: float = 1e-6) -> Tensor: |
| """Project each sample to unit RMS across all non-batch dimensions. |
| |
| This is equivalent to RMSNorm with ``normalized_shape=x.shape[1:]`` and no |
| affine parameters. Delegating to the native functional keeps the fast fused |
| CUDA path and the same opmath accumulation behavior as ``torch.nn.RMSNorm``. |
| """ |
|
|
| if x.dim() < 2: |
| return x |
| normalized_shape = tuple(int(dim) for dim in x.shape[1:]) |
| return F.rms_norm(x, normalized_shape, None, eps) |
|
|
|
|
class GlobalRMSNorm(nn.Module):
    """Module form of :func:`global_rms_norm`: unit RMS over all non-batch dims.

    Where :class:`ChannelWiseRMSNorm` normalizes over channels separately at
    each spatial position, this layer normalizes each sample's whole feature
    volume jointly, i.e. projects it onto a hypersphere. It has no learnable
    parameters.

    Args:
        eps: Stabilizer forwarded to :func:`global_rms_norm`; stored as the
            Python float ``_eps``.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self._eps: float = float(eps)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``global_rms_norm(x)`` using the stored epsilon."""
        normalized = global_rms_norm(x, eps=self._eps)
        return normalized
|
|
|
|
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel axis (dimension 1) of channel-second tensors.

    - Every position of the remaining dimensions is normalized independently
      across its ``C`` channel values.
    - ``weight`` and ``bias`` are the base-class parameters; ``forward`` passes
      ``(C,)`` as the normalized shape, so they are expected to have shape [C].
    - The constructor is inherited unchanged from ``torch.nn.LayerNorm``.

    Notes on precision
    - ``F.layer_norm`` calls the native LayerNorm kernel which accumulates in
      ``opmath`` dtype (float32 for float16/bfloat16), then casts back.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over dimension 1 and return it in the input layout.

        Args:
            x: Input of shape ``(B, C, *spatial)``. Inputs with fewer than
                three dimensions are handled by the base-class ``forward``.

        Returns:
            A contiguous tensor with the same shape as ``x``.
        """
        if x.dim() < 3:
            return super().forward(x)
        batch, channels, *spatial = x.shape
        # Move channels last and flatten so each row is one position's channels.
        rows = x.movedim(1, -1).contiguous().view(-1, channels)
        normed = F.layer_norm(rows, (channels,), self.weight, self.bias, self.eps)
        # Undo the flattening, then move channels back to dimension 1.
        restored = normed.view(batch, *spatial, channels).movedim(-1, 1)
        return restored.contiguous()
|
|
|
|
def row_norm(W: Tensor, eps: float = 1e-6) -> Tensor:
    """Row-normalise weight matrices along the last dimension.

    Precision and performance
    - Accumulates the squared sum in float32 via
      ``sum(..., dtype=torch.float32)`` instead of upcasting ``W`` first.
      The squares themselves are still formed in ``W``'s dtype.
    - Uses ``rsqrt`` and clamps the inverse norm via ``clamp_max(1/eps)`` to
      match ``clamp_min(eps)`` on the denominator.
    - For floating-point ``W`` the clamp is additionally capped at the largest
      finite value of ``W.dtype``. Without the cap, the default ``1/eps = 1e6``
      becomes ``inf`` when cast to float16 (max 65504), so all-zero rows
      produced ``0 * inf = NaN`` and very small rows produced ``inf``.
    - Scales in the input dtype for throughput; callers relying on exact
      float32 scaling should cast explicitly.

    Args:
        W: Tensor whose last dimension holds the rows to normalise.
        eps: Lower bound on the effective row norm; must be non-zero because
            ``1 / eps`` is evaluated.

    Returns:
        A tensor with the shape and dtype of ``W``. All-zero rows stay zero.
    """
    ss = torch.sum(torch.square(W), dim=-1, keepdim=True, dtype=torch.float32)
    inv_limit = 1.0 / float(eps)
    if W.dtype.is_floating_point:
        # Keep the clamp representable after the cast back to W's dtype.
        inv_limit = min(inv_limit, torch.finfo(W.dtype).max)
    inv = torch.rsqrt(ss).clamp_max(inv_limit)
    return W * inv.to(dtype=W.dtype)
|
|