| | """Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py""" |
| |
|
| | import math |
| |
|
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | import triton |
| | import triton.language as tl |
| |
|
| | from torch.distributed._tensor import Partial, Replicate, Shard |
| | from torch.distributed._tensor.experimental import local_map |
| | from torch._utils import _get_available_device_type, _get_device_module |
| |
|
| |
|
| | def get_device_info(): |
| | device_type = _get_available_device_type() |
| |
|
| | if device_type is None: |
| | device_type = "cuda" |
| |
|
| | device_module = _get_device_module(device_type) |
| | return device_type, device_module |
| |
|
# Resolved once at import time; `device_module` is used later to query
# device properties (multiprocessor count) in the fused RMSNorm backward pass.
device_type, device_module = get_device_info()
| |
|
| | def build_norm(norm_type: str, dim: int, eps: float = 1e-6): |
| | """ |
| | Builds the specified normalization layer based on the norm_type. |
| | |
| | Args: |
| | norm_type (str): The type of normalization layer to build. |
| | Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm |
| | dim (int): The dimension of the normalization layer. |
| | eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. |
| | |
| | Returns: |
| | The built normalization layer. |
| | |
| | Raises: |
| | NotImplementedError: If an unknown norm_type is provided. |
| | """ |
| | norm_type = norm_type.lower() |
| |
|
| | if norm_type == "layernorm": |
| | return nn.LayerNorm(dim, eps=eps, bias=False) |
| | elif norm_type == "np_layernorm": |
| | return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) |
| | elif norm_type == "rmsnorm": |
| | return RMSNorm(dim, eps=eps) |
| | elif norm_type == "fused_rmsnorm": |
| | return FusedRMSNorm(dim, eps=eps) |
| | else: |
| | raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") |
| |
|
| |
|
class FusedRMSNorm(nn.Module):
    """RMS normalization backed by the fused Triton forward/backward kernels."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # Stored as an attribute so forward() resolves the kernel entry
        # point through the instance at call time.
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the fused Triton RMSNorm kernel over the last dim of ``x``."""
        return self.fused_rms_norm_fn(x, self.weight, eps=self.eps)

    def reset_parameters(self):
        """Reinitialize the learnable scale to all ones."""
        nn.init.ones_(self.weight)
| |
|
| |
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable scale.

    Normalizes the last dimension of the input by its RMS value and then
    scales elementwise by ``self.weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Initialize the RMSNorm layer.

        Args:
            dim (int): Size of the last (normalized) input dimension.
            eps (float, optional): Small constant added under the square root
                for numerical stability. Default is 1e-6.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x`` scaled by ``1 / sqrt(mean(x**2) + eps)``.
        """
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        """Normalize ``x`` and apply the learnable scale.

        The normalization is computed in float32 and cast back to the
        input dtype before scaling.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: RMS-normalized, scaled tensor.
        """
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed

    def reset_parameters(self):
        """Reinitialize the learnable scale to all ones."""
        torch.nn.init.ones_(self.weight)
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
# Forward RMSNorm kernel: one Triton program normalizes one row of X.
# Autotuning only varies warp count; configs are keyed on the row width N.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,          # pointer to the input, logically (M, N)
    stride_x,   # row stride of X
    Y,          # pointer to the output, logically (M, N)
    stride_y,   # row stride of Y
    W,          # pointer to the (N,) weight vector
    Rstd,       # pointer to an (M,) float32 buffer; receives 1/rms per row
    eps,        # epsilon added to the mean of squares for stability
    M,          # number of rows
    N,          # number of columns (feature dimension)
    block_N: tl.constexpr,  # compile-time block width; must cover N
):
    # Each program handles exactly one row.
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load one full row of x and the weights in float32; out-of-range
    # lanes (cols >= N) are padded with zeros.
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Mean of squares over the true N elements; zero padding does not
    # perturb the sum, and the division uses N (not block_N).
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Save the reciprocal RMS so the backward kernel can reuse it.
    tl.store(Rstd + row, rstd)

    # y = (x * rstd) * w, i.e. RMS-normalize then scale elementwise.
    x_hat = x * rstd
    y = x_hat * w

    tl.store(Y + row * stride_y + cols, y, mask=mask)
| |
|
| |
|
# Backward RMSNorm kernel: each program processes a contiguous slab of
# `rows_per_program` rows, writing dX per row and accumulating a partial
# weight gradient that the caller reduces across programs (see _dw.sum(0)
# in TritonFusedRMSNorm.backward).
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,          # pointer to the forward input, logically (M, N)
    stride_x,   # row stride of X
    W,          # pointer to the (N,) weight vector
    DY,         # pointer to the upstream gradient, logically (M, N)
    stride_dy,  # row stride of DY
    DX,         # pointer to the input-gradient output, logically (M, N)
    stride_dx,  # row stride of DX
    Rstd,       # pointer to the (M,) per-row 1/rms saved by the forward pass
    DW,         # pointer to (num_programs, N) partial weight-gradient buffer
    eps,        # unused here; rstd already folds in the epsilon
    M,          # number of rows
    N,          # number of columns (feature dimension)
    rows_per_program,       # rows assigned to each program
    block_N: tl.constexpr,  # compile-time block width; must cover N
):
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Weights are row-invariant: load once, reuse for every row in the slab.
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Partial dW accumulator for this program's slab of rows.
    dw = tl.zeros((block_N,), dtype=tl.float32)

    # Clamp so the last program does not run past row M.
    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # RMSNorm backward:
        #   x_hat = x * rstd
        #   dW   += dy * x_hat                      (accumulated over rows)
        #   dx    = (w*dy - x_hat * mean(x_hat * w*dy)) * rstd
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # One partial-dW row per program; host code sums over axis 0.
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
| |
|
| |
|
class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd wrapper around the fused Triton RMSNorm kernels.

    The ``local_map`` decorators declare DTensor placements so this op can
    run on sharded inputs: activations sharded on dim 1 with a replicated
    weight, and the weight gradient produced as a Partial to be reduced.
    NOTE(review): placement semantics are read off the decorator arguments;
    confirm against the parallelism setup that calls this op.
    """

    @partial(
        local_map,
        out_placements=[Shard(1)],
        in_placements=(None, [Shard(1)], [Replicate()], None),
    )
    @staticmethod
    def forward(ctx, x, weight, eps):
        # Remember the original shape so y (and later dx) can be restored.
        x_shape_start = x.shape

        # Flatten leading dims to a 2D (M, N) problem; the kernels assume
        # unit stride along the last dimension.
        x = x.view(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        # Per-row reciprocal RMS, filled in by the kernel and reused in backward.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Cap the block width so one row's tile stays within 64 KiB.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        # The forward kernel processes a whole row in a single block, so
        # N must fit in one block.
        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        # One program per row.
        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        # Save the (possibly contiguized) flattened x plus weight and rstd.
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @partial(
        local_map,
        out_placements=([Shard(1)], [Partial()], None),
        in_placements=(None, [Shard(1)]),
    )
    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten the upstream gradient to match the saved 2D x and ensure
        # unit stride on the last dimension.
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)

        # Launch one program per SM; each accumulates a partial weight
        # gradient row into _dw, which is reduced below.
        sm_count = device_module.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        # Same 64 KiB block-width cap as the forward pass.
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        # Rows are split as evenly as possible across the SM-count programs.
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        # Reduce the per-program partial gradients into the final dW.
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        # Gradients for (x, weight); eps takes no gradient.
        return dx, dw, None
| |
|
| |
|
| | |
def fused_rms_norm_fn(x, weight, eps=1e-6):
    """Functional entry point for the fused Triton RMSNorm autograd op."""
    return TritonFusedRMSNorm.apply(x, weight, eps)