File size: 8,695 Bytes
1b703d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import Tensor, nn
from torch.nn import functional as F

# Public API of this module: ``from <module> import *`` exports exactly these
# names; the ``_``-prefixed helpers below are internal.
__all__ = [
    "ChannelWiseRMSNorm",
    "GlobalRMSNorm",
    "GroupNormF32",
    "LayerNorm",
    "LayerNorm2d",
    "RMSNorm",
    "global_rms_norm",
    "row_norm",
]


# Dtypes that ``_cast_to_float32`` upcasts to float32 for computation.
_HALF_PRECISION_DTYPES: tuple[torch.dtype, ...] = (torch.float16, torch.bfloat16)


def _cast_to_float32(x: Tensor) -> tuple[Tensor, torch.dtype]:
    """Upcast half-precision input for computation.

    Returns ``(compute_tensor, original_dtype)``. Tensors whose dtype is in
    ``_HALF_PRECISION_DTYPES`` are converted with ``Tensor.float()``; any other
    dtype is passed through untouched. Callers can undo the cast with the
    returned original dtype.
    """
    original_dtype = x.dtype
    needs_upcast = original_dtype in _HALF_PRECISION_DTYPES
    return (x.float() if needs_upcast else x), original_dtype


def _restore_dtype(x: Tensor, dtype: torch.dtype) -> Tensor:
    return x if x.dtype == dtype else x.to(dtype)


class RMSNorm(nn.Module):
    """Thin wrapper around ``torch.nn.RMSNorm`` that preserves our API.

    - Keeps an ``_eps`` attribute used by tests.
    - Maps ``affine`` -> ``elementwise_affine``.
    - Delegates all compute to the native implementation.

    Notes on precision
    - PyTorch ≥ 2.8 computes RMSNorm reductions in ``opmath`` dtype
      (float32 for float16/bfloat16) internally, then restores the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self._eps: float = float(eps)
        self._impl: nn.RMSNorm = nn.RMSNorm(
            dim, eps=self._eps, elementwise_affine=affine
        )
        self._dim: int = int(dim)

    @property
    def weight(self) -> Tensor | None:  # expose for tests/compat
        return self._impl.weight

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Apply RMSNorm while avoiding dtype-mismatch warnings under AMP.

        When inputs are bfloat16/float16 under autocast and the stored affine
        weight is float32 (common when model weights remain FP32), PyTorch emits
        a warning about mismatched dtypes and disables the fused path.

        We pass a view of the weight cast to the input dtype into the functional
        RMSNorm to enable the fused implementation without changing the
        parameter storage dtype (which remains FP32 for stability).
        """
        # Prefer functional to control the weight dtype for the kernel
        w: Tensor | None = self._impl.weight
        w_cast = w.to(dtype=x.dtype) if w is not None else None
        # Bias is not present in RMSNorm; functional takes (input, shape, weight, eps)
        return F.rms_norm(x, (self._dim,), w_cast, self._eps)


class LayerNorm(nn.LayerNorm):
    """``torch.nn.LayerNorm`` with a 1e-6 default eps and a cached ``_eps``.

    ``normalized_shape`` may be a single int or any sequence of ints. Sequence
    entries are coerced with ``int()`` before reaching the base class.

    Notes on precision
    - Native LayerNorm kernels accumulate statistics in ``opmath`` dtype
      (float32 for float16/bfloat16) before casting results back.
    """

    def __init__(
        self,
        normalized_shape: int | Sequence[int],
        eps: float = 1e-6,
        elementwise_affine: bool = True,
    ) -> None:
        shape: int | list[int] = (
            normalized_shape
            if isinstance(normalized_shape, int)
            else list(map(int, normalized_shape))
        )
        super().__init__(shape, eps=eps, elementwise_affine=elementwise_affine)
        self._eps: float = float(eps)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        # No extra logic: the native implementation does all the work.
        return super().forward(x)


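# GroupNormF32 below relies on the native GroupNorm kernel's opmath (float32)
# accumulation for half-precision inputs; it performs no explicit cast itself.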


class GroupNormF32(nn.GroupNorm):
    """``torch.nn.GroupNorm`` with a 1e-6 default eps and a cached ``_eps``.

    Notes on precision
    - Native GroupNorm uses ``opmath`` accumulation (float32 for
      float16/bfloat16) for statistics and fused scale/bias math; results
      are cast back to the input dtype.
    - Despite the "F32" in the name, no explicit cast happens here; all
      compute is delegated to the native implementation.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-6,
        affine: bool = True,
    ) -> None:
        super().__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=affine,
        )
        # Mirrors ``self.eps`` as a plain float for callers that read ``_eps``.
        self._eps: float = float(eps)


class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm for channel-first tensors such as NCHW.

    - Normalizes across dim 1 (channels) separately at every other position,
      without reshaping.
    - The square-and-mean reduction runs in at least float32
      (``torch.promote_types(x.dtype, torch.float32)``). float16/bfloat16 are
      upcast *before* squaring, so float16 values with |x| > ~256 no longer
      overflow to ``inf``; float64 inputs keep float64 precision.
    - The final scaling and the optional per-channel affine ``weight`` /
      ``bias`` are applied in the input dtype for throughput.
    - Inputs with fewer than 2 dims are returned unchanged.

    Args:
        channels: expected size of dim 1; a mismatch raises ``ValueError``.
        eps: added to the mean square before ``rsqrt``.
        affine: when True, learnable ``weight`` (ones) and ``bias`` (zeros)
            of shape ``[channels]`` are created; otherwise both are ``None``.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        self.affine: bool = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        if x.dim() < 2:
            return x
        C = x.size(1)
        if self.channels != C:
            raise ValueError(f"ChannelWiseRMSNorm expected C={self.channels}, got {C}")
        # Upcast before squaring: squaring in float16 overflows for |x| > ~256,
        # and the resulting inf would make rsqrt return 0 and zero the output.
        # For float32 inputs ``.to`` is a no-op (no copy).
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        ms = x.to(acc_dtype).square().mean(dim=1, keepdim=True)
        inv_rms = torch.rsqrt(ms + self._eps)  # acc_dtype
        y = x * inv_rms.to(dtype=x.dtype)
        if self.affine and self.weight is not None:
            # Broadcast per-channel params over dim 1: (1, C, 1, ..., 1).
            shape = (1, -1) + (1,) * (x.dim() - 2)
            y = y * self.weight.view(shape).to(dtype=x.dtype)
            if self.bias is not None:
                y = y + self.bias.view(shape).to(dtype=x.dtype)
        return y


def global_rms_norm(x: Tensor, eps: float = 1e-6) -> Tensor:
    """Project each sample to unit RMS across all non-batch dimensions.

    This is equivalent to RMSNorm with ``normalized_shape=x.shape[1:]`` and no
    affine parameters. Delegating to the native functional keeps the fast fused
    CUDA path and the same opmath accumulation behavior as ``torch.nn.RMSNorm``.
    """

    if x.dim() < 2:
        return x
    normalized_shape = tuple(int(dim) for dim in x.shape[1:])
    return F.rms_norm(x, normalized_shape, None, eps)


class GlobalRMSNorm(nn.Module):
    """Parameter-free RMSNorm over all non-batch dims (hypersphere projection).

    Where :class:`ChannelWiseRMSNorm` normalizes over channels separately at
    each spatial position, this module treats each sample's entire feature
    volume (``x.shape[1:]``) as one vector and rescales it to unit RMS. It has
    no learnable parameters; only ``eps`` is stored (as ``_eps``).
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self._eps: float = float(eps)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        eps = self._eps
        return global_rms_norm(x, eps=eps)


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim (dim 1) of channel-first tensors.

    - For inputs with 3+ dims, channels are moved to the last position,
      normalized with native ``F.layer_norm`` over ``(C,)`` per location,
      moved back to dim 1, and returned contiguous.
    - Inputs with fewer than 3 dims use the base ``nn.LayerNorm.forward``.
    - ``weight``, ``bias`` and ``eps`` follow the base class semantics
      (shape ``[C]`` when constructed with ``C``).

    Notes on precision
    - ``F.layer_norm`` calls the native LayerNorm kernel which accumulates in
      ``opmath`` dtype (float32 for float16/bfloat16), then casts back.
    """

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        if x.dim() >= 3:
            channels = x.size(1)
            # (B, C, *spatial) -> (B, *spatial, C) so C is the normalized dim.
            channels_last = x.movedim(1, -1)
            normed = F.layer_norm(
                channels_last, (channels,), self.weight, self.bias, self.eps
            )
            return normed.movedim(-1, 1).contiguous()
        return super().forward(x)


def row_norm(W: Tensor, eps: float = 1e-6) -> Tensor:
    """Row-normalise ``W`` to unit L2 norm along its last dimension.

    Each row ``W[..., :]`` is divided by ``max(||row||_2, eps)``. Rows whose
    norm is below ``eps``, including all-zero rows, are scaled by ``1 / eps``
    instead of dividing by zero.

    Precision
    - Squares, sums and the final scaling run in
      ``torch.promote_types(W.dtype, torch.float32)``: float32 for
      float16/bfloat16, float64 for float64. Only the result is cast back to
      ``W.dtype``. Computing these steps in float16 overflowed for entries
      above ~256 (rows became all zeros). It also turned the ``1/eps`` cap
      (1e6 by default) into ``inf``, so all-zero rows produced NaN.
    - For float32 input no extra copy is made.

    Args:
        W: tensor to normalise (intended for weight matrices); any shape.
        eps: lower bound applied to each row norm.

    Returns:
        Tensor with the same shape and dtype as ``W``.
    """
    acc_dtype = torch.promote_types(W.dtype, torch.float32)
    w = W.to(acc_dtype)  # no-op for float32
    ss = w.square().sum(dim=-1, keepdim=True)
    # rsqrt(ss) capped at 1/eps == 1 / clamp_min(sqrt(ss), eps)
    inv = torch.rsqrt(ss).clamp_max(1.0 / float(eps))
    # Scale in acc_dtype so the (possibly huge) cap never hits float16 range.
    return (w * inv).to(W.dtype)