| | |
| | |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import triton |
| | import triton.language as tl |
| |
|
| | from fla.utils import input_guard |
| |
|
# Candidate time-block sizes (BT) swept by the autotuned blocked kernels below.
BT_LIST = [8, 16, 32, 64, 128]
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    """L2-normalize one length-D row per program: y = x / sqrt(sum(x*x) + eps).

    Launched with a 1-D grid of T programs (one per row); assumes rows are
    contiguous with stride D.
    """
    i_t = tl.program_id(0)
    # Advance the base pointers to this program's row.
    x += i_t * D
    y += i_t * D

    cols = tl.arange(0, BD)
    # BD is D rounded up to a power of two; mask off the padding columns.
    mask = cols < D
    # Accumulate in fp32 regardless of the input dtype.
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)

    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_bwd_kernel1(
    x,
    dy,
    dx,
    eps,
    D,
    BD: tl.constexpr,
):
    """Backward of row-wise L2 norm, one length-D row per program.

    Computes dx = dy * rstd - (dy . x) / (var + eps) * rstd * x,
    where var = sum(x*x) and rstd = 1/sqrt(var + eps).
    """
    i_t = tl.program_id(0)
    # Advance all three pointers to this program's row.
    x += i_t * D
    dx += i_t * D
    dy += i_t * D

    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    # Recompute the squared norm from x (the forward pass does not save it).
    b_var = tl.sum(b_x * b_x)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32)
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x) * (1 / (b_var+eps)) * b_rstd * b_x
    tl.store(dx + cols, b_dx, mask=mask)
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """L2-normalize a (BT, BD) tile of rows per program; used by the wrapper
    when D is small enough (<= 512) that processing several rows at once pays off.

    NB is not read in the body; it participates only in the autotune key so
    configs are re-selected per problem size.
    """
    i_t = tl.program_id(0)
    # Block pointer over rows [i_t*BT, i_t*BT + BT) of the (T, D) matrix.
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # Per-row squared norm; axis=1 reduces over the feature dimension.
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_bwd_kernel(
    x,
    dy,
    dx,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """Backward of L2 norm over a (BT, BD) tile of rows per program.

    Same math as l2norm_bwd_kernel1, vectorized over BT rows:
    dx = dy * rstd - (dy . x) / (var + eps) * rstd * x, per row.
    NB is not read in the body; it only keys the autotuner.
    """
    i_t = tl.program_id(0)
    # Block pointers over rows [i_t*BT, i_t*BT + BT) of the (T, D) matrices.
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
    # Per-row squared norm, kept as a column for broadcasting.
    b_var = tl.sum(b_x * b_x, axis=1)[:, None]
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x, axis=1)[:, None] / (b_var+eps) * b_rstd * b_x
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
| |
|
| |
|
def l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
):
    """L2-normalize ``x`` along its last dimension using Triton kernels.

    Args:
        x: input tensor of any shape; normalization runs over the last dim.
        eps: added to the squared norm before the square root.
        output_dtype: dtype of the result; defaults to ``x``'s dtype.

    Returns:
        Normalized tensor with the same shape as ``x``.
    """
    shape_in = x.shape
    x = x.view(-1, x.shape[-1])
    y = torch.empty_like(x) if output_dtype is None else torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape

    # One block must fit the whole feature dim; cap it at 64KB of elements.
    max_block = 65536 // x.element_size()
    BD = min(max_block, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if D <= 512:
        # Small feature dim: each program handles a BT-row tile.
        NB = triton.cdiv(T, 2048)
        grid = lambda meta: (triton.cdiv(T, meta['BT']),)
        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Large feature dim: one program per row.
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
        )
    return y.view(shape_in)
| |
|
| |
|
def l2norm_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-6
):
    """Backward pass of L2 normalization over the last dimension.

    Args:
        x: the forward-pass input (pre-normalization), any shape.
        dy: upstream gradient, same shape as ``x``.
        eps: numerical-stability constant. Must match the value used in the
            forward pass; the default is 1e-6 to agree with ``l2norm_fwd``
            (it was previously 1e-5, which silently disagreed with every
            other default in this module).

    Returns:
        dx, the gradient w.r.t. ``x``, with ``x``'s original shape.
    """
    x_shape_og = x.shape
    x = x.view(-1, dy.shape[-1])
    dy = dy.view(-1, dy.shape[-1])
    assert dy.shape == x.shape

    dx = torch.empty_like(x)
    T, D = x.shape[0], x.shape[-1]

    # One block must fit the whole feature dim; cap it at 64KB of elements.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    if D <= 512:
        # Small feature dim: each program handles a BT-row tile.
        NB = triton.cdiv(T, 2048)
        def grid(meta): return (triton.cdiv(T, meta['BT']), )
        l2norm_bwd_kernel[grid](
            x,
            dy,
            dx,
            eps=eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Large feature dim: one program per row.
        l2norm_bwd_kernel1[(T,)](
            x,
            dy,
            dx,
            eps=eps,
            D=D,
            BD=BD,
        )
    return dx.view(x_shape_og)
| |
|
| |
|
class L2NormFunction(torch.autograd.Function):
    """Autograd bridge between the Triton L2-norm wrappers and PyTorch."""

    @staticmethod
    @input_guard
    def forward(ctx, x, eps=1e-6, output_dtype=None):
        # Stash what backward needs: the pre-normalization input and eps.
        ctx.save_for_backward(x)
        ctx.eps = eps
        ctx.x_dtype = x.dtype
        return l2norm_fwd(x, eps, output_dtype)

    @staticmethod
    @input_guard
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        # No gradients for eps / output_dtype.
        return l2norm_bwd(x, dy, ctx.eps), None, None
| |
|
| |
|
def l2norm(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """L2-normalize ``x`` along its last dimension (differentiable entry point)."""
    return L2NormFunction.apply(x, eps, output_dtype)
| |
|
| |
|
# Public alias for l2norm.
l2_norm = l2norm
| |
|
| |
|
class L2Norm(nn.Module):
    """Module wrapper around :func:`l2norm`: normalizes inputs to unit L2 norm
    along the last dimension.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        output_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        # eps: added to the squared norm for numerical stability.
        self.eps = eps
        # output_dtype: result dtype; None keeps the input's dtype.
        self.output_dtype = output_dtype

    def extra_repr(self) -> str:
        # Surfaces the configuration in print(module) / repr(module).
        return f"eps={self.eps}, output_dtype={self.output_dtype}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return l2norm(x, self.eps, self.output_dtype)
| |
|