# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import torch.nn as nn
import triton
import triton.language as tl

from fla.utils import input_guard

# Candidate time-block sizes explored by the autotuner for the block kernels.
BT_LIST = [8, 16, 32, 64, 128]


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    """One program L2-normalizes one row of length D: y = x / sqrt(sum(x^2) + eps)."""
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    # Sum of squares over the row; eps keeps the rsqrt finite for all-zero rows.
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_bwd_kernel1(
    x,
    dy,
    dx,
    eps,
    D,
    BD: tl.constexpr,
):
    """One program computes the L2-norm backward pass for one row of length D.

    dx = dy * rstd - sum(dy * x) * rstd^3 * x, where rstd = 1/sqrt(sum(x^2) + eps).
    """
    i_t = tl.program_id(0)
    x += i_t * D
    dx += i_t * D
    dy += i_t * D
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32)
    # (1 / (var + eps)) * rstd == rstd^3; written this way to reuse b_rstd.
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x) * (1 / (b_var + eps)) * b_rstd * b_x
    tl.store(dx + cols, b_dx, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """Blocked variant: one program L2-normalizes BT rows at once (used for small D)."""
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # Per-row sum of squares; broadcast the row norm back over the feature dim.
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_bwd_kernel(
    x,
    dy,
    dx,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """Blocked variant of the backward pass: one program handles BT rows at once."""
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)[:, None]
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # dx = dy * rstd - <dy, x> * rstd^3 * x, computed per row.
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x, axis=1)[:, None] / (b_var + eps) * b_rstd * b_x
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))


def l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """L2-normalize `x` along its last dimension.

    Args:
        x: Input tensor of arbitrary leading shape; normalized over the last dim.
        eps: Small constant added to the squared norm for numerical stability.
        output_dtype: Optional dtype for the output; defaults to `x.dtype`.

    Returns:
        Tensor of the same shape as `x` with each last-dim vector scaled to unit norm.

    Raises:
        RuntimeError: if a single feature row exceeds 64KB (cannot fit one block).
    """
    x_shape_og = x.shape
    # Flatten all leading dims so each kernel program handles one feature row.
    x = x.view(-1, x.shape[-1])
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # Less than 64KB per feature row: the whole row fits in one block.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    if D <= 512:
        # Small feature dim: batch BT rows per program (BT picked by autotune);
        # NB only feeds the autotune key so configs re-tune as T scales.
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta['BT']), )
        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Large feature dim: one program per row.
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
        )
    return y.view(x_shape_og)


def l2norm_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-6
) -> torch.Tensor:
    """Backward pass of `l2norm_fwd`: gradient of the last-dim L2 normalization.

    Args:
        x: The forward input tensor (same shape as in the forward call).
        dy: Upstream gradient w.r.t. the forward output; same shape as `x`.
        eps: Stability constant; must match the value used in the forward pass.
            NOTE: default fixed from 1e-5 to 1e-6 for consistency with
            `l2norm_fwd` / `L2NormFunction` (the autograd path always passes
            `ctx.eps` explicitly, so autograd behavior is unchanged).

    Returns:
        Gradient w.r.t. `x`, with the same shape as `x`.

    Raises:
        RuntimeError: if a single feature row exceeds 64KB.
    """
    x_shape_og = x.shape
    x = x.view(-1, dy.shape[-1])
    dy = dy.view(-1, dy.shape[-1])
    assert dy.shape == x.shape
    dx = torch.empty_like(x)
    T, D = x.shape[0], x.shape[-1]
    # Less than 64KB per feature row: the whole row fits in one block.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    if D <= 512:
        # Small feature dim: batch BT rows per program (mirrors l2norm_fwd).
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta['BT']), )
        l2norm_bwd_kernel[grid](
            x,
            dy,
            dx,
            eps=eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Large feature dim: one program per row.
        l2norm_bwd_kernel1[(T,)](
            x,
            dy,
            dx,
            eps=eps,
            D=D,
            BD=BD,
        )
    return dx.view(x_shape_og)


class L2NormFunction(torch.autograd.Function):
    """Autograd wrapper binding `l2norm_fwd` and `l2norm_bwd` together."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        eps=1e-6,
        output_dtype=None
    ):
        y = l2norm_fwd(x, eps, output_dtype)
        ctx.eps = eps
        ctx.x_dtype = x.dtype
        # The backward pass recomputes the norm from x, so only x is saved.
        ctx.save_for_backward(x)
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        dx = l2norm_bwd(x, dy, ctx.eps)
        # None gradients for the non-tensor args (eps, output_dtype).
        return dx, None, None


def l2norm(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """Differentiable L2 normalization of `x` over its last dimension."""
    return L2NormFunction.apply(x, eps, output_dtype)


# Backwards-compatible alias.
l2_norm = l2norm


class L2Norm(nn.Module):
    """`nn.Module` wrapper around `l2norm` with fixed `eps` and output dtype."""

    def __init__(
        self,
        eps: float = 1e-6,
        output_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.eps = eps
        self.output_dtype = output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return l2norm(x, self.eps, self.output_dtype)