File size: 8,695 Bytes
1b703d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 | from __future__ import annotations
from collections.abc import Sequence
import torch
from torch import Tensor, nn
from torch.nn import functional as F
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "ChannelWiseRMSNorm",
    "GlobalRMSNorm",
    "GroupNormF32",
    "LayerNorm",
    "LayerNorm2d",
    "RMSNorm",
    "global_rms_norm",
    "row_norm",
]
# Half-precision floating dtypes; ``_cast_to_float32`` upcasts tensors of
# these dtypes to float32 and leaves every other dtype untouched.
_HALF_PRECISION_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)
def _cast_to_float32(x: Tensor) -> tuple[Tensor, torch.dtype]:
    """Upcast a half-precision tensor to float32 for computation.

    Returns a pair ``(tensor_for_compute, original_dtype)``. Only float16 and
    bfloat16 inputs are converted; any other tensor is returned as the very
    same object, so callers can restore the dtype afterwards.
    """
    original_dtype = x.dtype
    compute_tensor = x.float() if original_dtype in _HALF_PRECISION_DTYPES else x
    return compute_tensor, original_dtype
def _restore_dtype(x: Tensor, dtype: torch.dtype) -> Tensor:
return x if x.dtype == dtype else x.to(dtype)
class RMSNorm(nn.Module):
"""Thin wrapper around ``torch.nn.RMSNorm`` that preserves our API.
- Keeps an ``_eps`` attribute used by tests.
- Maps ``affine`` -> ``elementwise_affine``.
- Delegates all compute to the native implementation.
Notes on precision
- PyTorch ≥ 2.8 computes RMSNorm reductions in ``opmath`` dtype
(float32 for float16/bfloat16) internally, then restores the input dtype.
"""
def __init__(self, dim: int, eps: float = 1e-6, affine: bool = True) -> None:
super().__init__()
self._eps: float = float(eps)
self._impl: nn.RMSNorm = nn.RMSNorm(
dim, eps=self._eps, elementwise_affine=affine
)
self._dim: int = int(dim)
@property
def weight(self) -> Tensor | None: # expose for tests/compat
return self._impl.weight
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
"""Apply RMSNorm while avoiding dtype-mismatch warnings under AMP.
When inputs are bfloat16/float16 under autocast and the stored affine
weight is float32 (common when model weights remain FP32), PyTorch emits
a warning about mismatched dtypes and disables the fused path.
We pass a view of the weight cast to the input dtype into the functional
RMSNorm to enable the fused implementation without changing the
parameter storage dtype (which remains FP32 for stability).
"""
# Prefer functional to control the weight dtype for the kernel
w: Tensor | None = self._impl.weight
w_cast = w.to(dtype=x.dtype) if w is not None else None
# Bias is not present in RMSNorm; functional takes (input, shape, weight, eps)
return F.rms_norm(x, (self._dim,), w_cast, self._eps)
class LayerNorm(nn.LayerNorm):
"""Thin wrapper over ``torch.nn.LayerNorm`` with an ``_eps`` attribute.
Notes on precision
- Native LayerNorm kernels accumulate statistics in ``opmath`` dtype
(float32 for float16/bfloat16) before casting results back.
"""
def __init__(
self,
normalized_shape: int | Sequence[int],
eps: float = 1e-6,
elementwise_affine: bool = True,
) -> None:
shape: int | list[int]
match normalized_shape:
case int() as dim:
shape = dim
case _:
shape = [int(v) for v in normalized_shape]
super().__init__(shape, eps=eps, elementwise_affine=elementwise_affine)
self._eps: float = float(eps)
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
# Delegate to native LayerNorm
return super().forward(x)
# GroupNormF32: native GroupNorm wrapper; the kernel accumulates statistics in float32 (opmath) for half-precision inputs.
class GroupNormF32(nn.GroupNorm):
    """``torch.nn.GroupNorm`` subclass that also records ``eps`` as ``_eps``.

    Notes on precision
    - Native GroupNorm uses ``opmath`` accumulation (float32 for
      float16/bfloat16) for its statistics and fused scale/bias math, then
      casts results back to the input dtype.
    - The "F32" in the name refers to that internal accumulation only: this
      wrapper performs no explicit cast and defers entirely to the native op.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-6,
        affine: bool = True,
    ) -> None:
        super().__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=affine,
        )
        self._eps: float = float(eps)
class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm for channel-first tensors (``N, C, *spatial``).

    - Normalizes across dim 1 (channels) independently at every other index,
      without reshaping or permuting the input.
    - The sum of squares is computed in at least float32 (float64 inputs stay
      in float64), with the input cast *before* squaring. As a result,
      half-precision inputs cannot overflow. The final scaling and affine ops
      run in the input dtype for throughput.
    - Supports optional per-channel affine ``weight`` and ``bias`` (shape ``[C]``).

    Args:
        channels: expected size of dim 1 of the input.
        eps: added to the mean square before ``rsqrt``.
        affine: if True, create learnable ``weight`` (ones) and ``bias`` (zeros).
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        self.affine: bool = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Normalize ``x`` over dim 1; inputs with fewer than 2 dims pass through.

        Raises:
            ValueError: if ``x.size(1)`` differs from ``channels``.
        """
        if x.dim() < 2:
            return x
        C = x.size(1)
        if self.channels != C:
            raise ValueError(f"ChannelWiseRMSNorm expected C={self.channels}, got {C}")
        # ``torch.square`` on a float16 input overflows to inf for |x| > ~256
        # *before* any float32 reduction. That would make rsqrt return 0 and
        # silently zero the whole position. ``vector_norm(dtype=...)`` casts the
        # input to the accumulation dtype before squaring, and autograd saves
        # only the original input, not an upcast copy.
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        norm = torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True, dtype=acc_dtype)
        ms = norm.square() / C  # mean of squares over channels, in acc_dtype
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms.to(dtype=x.dtype)
        if self.affine and self.weight is not None:
            # Broadcast [C] parameters over dim 1: shape (1, C, 1, ..., 1).
            shape = (1, -1) + (1,) * (x.dim() - 2)
            y = y * self.weight.view(shape).to(dtype=x.dtype)
            if self.bias is not None:
                y = y + self.bias.view(shape).to(dtype=x.dtype)
        return y
def global_rms_norm(x: Tensor, eps: float = 1e-6) -> Tensor:
"""Project each sample to unit RMS across all non-batch dimensions.
This is equivalent to RMSNorm with ``normalized_shape=x.shape[1:]`` and no
affine parameters. Delegating to the native functional keeps the fast fused
CUDA path and the same opmath accumulation behavior as ``torch.nn.RMSNorm``.
"""
if x.dim() < 2:
return x
normalized_shape = tuple(int(dim) for dim in x.shape[1:])
return F.rms_norm(x, normalized_shape, None, eps)
class GlobalRMSNorm(nn.Module):
    """Parameter-free RMSNorm over every dimension except the batch dimension.

    Where :class:`ChannelWiseRMSNorm` normalizes over channels separately at
    each spatial position, this module normalizes each sample's *whole*
    feature volume at once, projecting it onto a hypersphere.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self._eps: float = float(eps)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Return ``global_rms_norm(x)`` using this module's ``eps``."""
        eps = self._eps
        return global_rms_norm(x, eps=eps)
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim (dim 1) of channel-first tensors.

    - For inputs with 3+ dims, each position ``(b, *spatial)`` is normalized
      over its channels only; channels are moved last, normalized with
      ``F.layer_norm``, and moved back.
    - Inputs with fewer than 3 dims use the base ``nn.LayerNorm`` forward.
    - Weight and bias follow the base class semantics (shape ``[C]``).
    - The returned tensor is always contiguous.

    Notes on precision
    - ``F.layer_norm`` calls the native LayerNorm kernel, which accumulates in
      ``opmath`` dtype (float32 for float16/bfloat16), then casts back.
    """

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        if x.dim() < 3:
            return super().forward(x)
        num_channels = x.shape[1]
        channels_last = x.movedim(1, -1)  # (B, *spatial, C)
        normed = F.layer_norm(
            channels_last, (num_channels,), self.weight, self.bias, self.eps
        )
        return normed.movedim(-1, 1).contiguous()  # back to (B, C, *spatial)
def row_norm(W: Tensor, eps: float = 1e-6) -> Tensor:
    """Row-normalise ``W`` to unit L2 norm along its last dimension.

    Computes ``W / max(||W||_2, eps)`` row-wise and returns it in ``W``'s dtype.

    Precision and behaviour
    - The norm is computed by ``torch.linalg.vector_norm`` with ``dtype`` set
      to at least float32 (float64 stays float64). The input is therefore cast
      before squaring, so float16 rows with entries above ~256 do not overflow
      to inf, which would silently zero the row.
    - The denominator is clamped with ``clamp_min(eps)`` before taking the
      reciprocal. All-zero rows therefore map to zero with finite gradients;
      an ``rsqrt(sum_sq).clamp_max(1/eps)`` formulation would backpropagate
      ``0 * inf = NaN`` through ``rsqrt`` for such rows.
    - Scaling is done in the input dtype for throughput; callers relying on
      exact float32 scaling should cast explicitly.

    Args:
        W: tensor whose last dimension is normalised (e.g. a weight matrix).
        eps: lower bound on the norm used as the denominator.
    """
    acc_dtype = torch.promote_types(W.dtype, torch.float32)
    norm = torch.linalg.vector_norm(W, ord=2, dim=-1, keepdim=True, dtype=acc_dtype)
    inv = norm.clamp_min(float(eps)).reciprocal()  # in acc_dtype
    return W * inv.to(dtype=W.dtype)
|