| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | import torch |
| | import torch.nn.functional as F |
| | import triton |
| | import triton.language as tl |
| |
|
| |
|
| | |
@triton.jit
def srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):
    """Forward kernel for simple RMS norm (no learnable affine weight).

    One program handles one row of the (M, N) input. Computes, in fp32,
    rstd = 1 / sqrt(mean(x^2) + eps), writes y = x * rstd into Y and the
    per-row rstd into V (consumed by the backward kernel).
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Load the row, promoting to fp32 for a numerically stable reduction.
    # Masked lanes read as 0.0 and therefore do not contribute to the sum
    # below — the original's extra `tl.where(mask, x, 0.0)` was redundant.
    x = tl.load(X + row * stride + cols, mask=mask, other=0.0).to(tl.float32)

    # mean(x^2) over the true feature dim N (not the padded BLOCK_SIZE_N).
    x_var = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(x_var + eps)

    # Normalize, and stash the per-row rstd for the backward pass.
    y = x * rstd
    tl.store(V + row, rstd)

    tl.store(Y + row * stride + cols, y, mask=mask)
| |
|
| |
|
| | |
| | |
@triton.jit
def srms_norm_bwd_dx_fused(
    DX, DY,
    X, V,
    stride, N,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Backward kernel for simple RMS norm: computes dX, one row per program.

    With xhat = x * rstd and y = xhat (no affine weight), the gradient is:
        dx = (dy - xhat * mean(xhat * dy)) * rstd
    where rstd is the per-row value saved by the forward kernel in V.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Row pointers for the saved input and the incoming gradient.
    x_ptrs = X + row * stride + cols
    dy_ptrs = DY + row * stride + cols

    # Masked lanes load as 0, so they drop out of the reduction below.
    x = tl.load(x_ptrs, mask=mask, other=0)
    dy = tl.load(dy_ptrs, mask=mask, other=0)
    rstd = tl.load(V + row)  # 1/sqrt(mean(x^2) + eps) from the forward pass

    # Zero padded lanes explicitly before the reduction.
    xhat = tl.where(mask, x * rstd, 0.)
    wdy = tl.where(mask, dy, 0.)

    # Projection term: mean over the true feature dim N.
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    dx = (wdy - (xhat * mean1)) * rstd

    # Write the result (a redundant recomputation of `mask` was removed here;
    # it had not changed since it was first computed above).
    tl.store(DX + row * stride + cols, dx, mask=mask)
| |
|
| |
|
class _SrmsNorm(torch.autograd.Function):
    """Autograd wrapper around the fused Triton simple-RMS-norm kernels.

    Forward normalizes each row of the last dimension by
    1 / sqrt(mean(x^2) + eps). There is no learnable affine weight, so
    backward only produces a gradient for the input (None for eps).
    """

    @staticmethod
    def forward(ctx, x, eps):
        if x.dtype == torch.float16:
            # fp16 has limited precision near zero; clamp eps so that
            # rstd = 1/sqrt(var + eps) stays finite after the cast.
            eps = max(eps, 1.6e-5)

        y = torch.empty_like(x)

        # Collapse leading dims: the kernel operates on an (M, N) row-major
        # view where N is the feature (last) dimension.
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Per-row reciprocal std, written by the kernel and saved for backward.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # One program per row: the whole feature dim must fit in one block
        # (64KB of elements at this dtype's width).
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError(
                "This layer norm doesn't support feature dim >= 64KB.")

        # The kernels index the feature dim with unit stride.
        if not x_arg.is_contiguous() or not y.is_contiguous():
            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # Heuristic: roughly one warp per 256 elements, clamped to [1, 16].
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        srms_norm_fw[(M,)](
            x_arg, y, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        # Save the contiguous (M, N) view rather than the raw input: the
        # backward kernel assumes unit stride along the feature dim, which
        # the original `x` is not guaranteed to have.
        ctx.save_for_backward(x_arg, rstd)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    def backward(ctx, dy):
        # x was saved as a contiguous (M, N) view in forward; the reshape
        # below is then a no-op kept for clarity.
        x, rstd = ctx.saved_tensors
        x = x.reshape(-1, x.size(-1))
        M, N = x.size()

        # NOTE(review): an unused GROUP_SIZE_M heuristic was removed here —
        # it was computed but never passed to the kernel launch.

        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # Same warp heuristic as forward; reuse the value computed there.
        num_warps = ctx.num_warps

        srms_norm_bwd_dx_fused[(M,)](
            dx, dy, x,
            rstd,
            x.stride(0),
            N,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=num_warps,
        )

        dx = dx.reshape_as(dy)
        # Gradients: (dX, d_eps) — eps is a plain float, so None.
        return dx, None, None
| |
|
| |
|
class SimpleRMSNorm(torch.nn.Module):
    """RMS normalization without a learnable scale.

    Normalizes the last dimension of the input by 1 / sqrt(mean(x^2) + eps)
    via the fused Triton autograd function.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # `dim` is stored but not read by this module's forward path.
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        """Apply simple RMS norm over the last dimension of ``x``."""
        return _SrmsNorm.apply(x, self.eps)
| |
|