# Uploaded by msj19 via the upload-large-folder tool (commit e73a905, verified).
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
import torch.nn as nn
import triton
import triton.language as tl
from fla.utils import input_guard
# Candidate row-tile sizes (BT) offered to the autotuner of the tiled kernels:
# the powers of two from 2**3 = 8 up to 2**7 = 128.
BT_LIST = [2 ** k for k in range(3, 8)]
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    """Row-wise L2 normalization, one program per row.

    Computes ``y = x / sqrt(sum(x ** 2) + eps)`` for the row selected by
    ``program_id(0)``. The launcher uses this variant for wide rows (D > 512).
    Each row holds ``D`` contiguous elements. It is loaded as ``BD`` lanes
    (``BD >= D``), and lanes past ``D`` are masked and read as 0.
    """
    # Promote the row index to int64 before scaling by D. With int32
    # arithmetic, ``i_t * D`` overflows once the tensor has >= 2**31 elements.
    i_t = tl.program_id(0).to(tl.int64)
    x += i_t * D
    y += i_t * D
    cols = tl.arange(0, BD)
    mask = cols < D
    # Accumulate in fp32 regardless of the input dtype.
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    # Sum of squares of the row. eps goes inside the square root.
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y.to(y.dtype.element_ty), mask=mask)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def l2norm_bwd_kernel1(
    x,
    dy,
    dx,
    eps,
    D,
    BD: tl.constexpr,
):
    """Backward of the row-wise L2 normalization, one program per row.

    With ``r = 1 / sqrt(sum(x ** 2) + eps)`` the forward is ``y = x * r``, so
    ``dx = dy * r - sum(dy * x) * r**3 * x``, where ``r**3 = r / (sum(x**2) + eps)``.
    The per-row norm is recomputed from ``x`` here rather than read from memory.
    """
    # int64 row offset: ``i_t * D`` in int32 overflows for >= 2**31 elements.
    i_t = tl.program_id(0).to(tl.int64)
    x += i_t * D
    dx += i_t * D
    dy += i_t * D
    cols = tl.arange(0, BD)
    mask = cols < D
    # Masked lanes load as 0 and therefore do not affect either row sum.
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32)
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x) * (1 / (b_var+eps)) * b_rstd * b_x
    tl.store(dx + cols, b_dx.to(dx.dtype.element_ty), mask=mask)
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """Row-wise L2 normalization of a row-major ``(T, D)`` matrix, BT rows per program.

    Computes ``y = x / sqrt(sum(x ** 2, axis=1) + eps)``.

    ``T`` and ``NB`` are plain runtime arguments. As ``tl.constexpr`` they
    would make Triton compile a new specialization for every distinct row
    count. ``NB`` is not read in the body; it only feeds the autotuning key.
    """
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    # Out-of-bounds padding is undefined unless requested explicitly. Columns
    # past D (when BD > D) must read as 0, or they corrupt the sum of squares.
    b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    # Rows past T and columns past D are dropped by the boundary check.
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D', 'NB']
)
@triton.jit
def l2norm_bwd_kernel(
    x,
    dy,
    dx,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    """Backward of the row-wise L2 normalization over ``(T, D)``, BT rows per program.

    With ``r = 1 / sqrt(sum(x ** 2, axis=1) + eps)`` per row, computes
    ``dx = dy * r - sum(dy * x, axis=1) * r**3 * x``, where ``r**3 = r / (sum(x**2) + eps)``.

    ``T`` and ``NB`` are runtime arguments so that a new row count does not
    trigger a recompile. ``NB`` only feeds the autotuning key.
    """
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    # Zero padding is required. Undefined values in columns past D would leak
    # into both row reductions below.
    b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    b_dy = tl.load(p_dy, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)[:, None]
    b_rstd = 1 / tl.sqrt(b_var + eps)
    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x, axis=1)[:, None] / (b_var+eps) * b_rstd * b_x
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
def l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
):
    """L2-normalize ``x`` along its last dimension with a Triton kernel.

    Computes ``y = x / sqrt(sum(x ** 2, dim=-1) + eps)`` row by row.

    Args:
        x: Input tensor. Every slice along the last dimension is one row.
        eps: Added to each row's sum of squares inside the square root.
        output_dtype: dtype of the result. ``None`` keeps ``x.dtype``.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        RuntimeError: If a single row is larger than 64KB.
    """
    if x.numel() == 0:
        # Nothing to normalize. This also avoids ``view(-1, 0)`` (ambiguous)
        # and launching kernels on zero-sized rows or grids.
        # dtype=None keeps x.dtype.
        return torch.empty_like(x, dtype=output_dtype)
    x_shape_og = x.shape
    # The kernels address element (row, col) as ``row * D + col``, so the
    # flattened input must be densely packed. ``view`` alone accepts some
    # strided inputs (e.g. a column slice of a wider tensor), and the kernel
    # would then silently read the wrong elements.
    x = x.reshape(-1, x.shape[-1]).contiguous()
    y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape
    # A whole row is loaded at once (BD lanes), so rows are capped at 64KB.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    if D <= 512:
        # Narrow rows: each program handles BT rows (BT picked by the autotuner).
        # NB (number of 2048-row chunks) is only used as an autotuning key.
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta['BT']),)

        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Wide rows: one program per row.
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
        )
    return y.view(x_shape_og)
def l2norm_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-5
):
    """Gradient of the row-wise L2 normalization with respect to ``x``.

    With ``r = 1 / sqrt(sum(x ** 2, dim=-1) + eps)`` per row, computes
    ``dx = dy * r - sum(dy * x, dim=-1) * r**3 * x``.

    Args:
        x: The forward-pass input.
        dy: Gradient w.r.t. the forward output. It must flatten to the same
            ``(rows, D)`` shape as ``x``.
        eps: Must equal the ``eps`` used in the forward pass.

    Returns:
        ``dx`` with the same shape and dtype as ``x``.

    Raises:
        RuntimeError: If a single row is larger than 64KB.
    """
    # NOTE(review): this default eps (1e-5) differs from l2norm_fwd's default
    # (1e-6). L2NormFunction always passes eps explicitly; direct callers should too.
    x_shape_og = x.shape
    if x.numel() == 0:
        # Nothing to differentiate; skip kernel launches on an empty grid.
        assert dy.numel() == 0, "dy must have the same number of elements as x"
        return torch.empty_like(x)
    # Flatten each tensor by its OWN last dim. Flattening x by dy's last dim
    # would let mismatched shapes (e.g. (4, 6) vs (8, 3)) pass the check below.
    # The kernels also need densely packed rows, hence .contiguous().
    x = x.reshape(-1, x.shape[-1]).contiguous()
    dy = dy.reshape(-1, dy.shape[-1]).contiguous()
    assert dy.shape == x.shape, f"dy {tuple(dy.shape)} does not match x {tuple(x.shape)}"
    dx = torch.empty_like(x)
    T, D = x.shape
    # A whole row is loaded at once (BD lanes), so rows are capped at 64KB.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    # Choose the kernel variant by row width.
    if D <= 512:
        # Narrow rows: BT rows per program; NB is only an autotuning key.
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta['BT']),)

        l2norm_bwd_kernel[grid](
            x,
            dy,
            dx,
            eps=eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
        )
    else:
        # Wide rows: one program per row.
        l2norm_bwd_kernel1[(T,)](
            x,
            dy,
            dx,
            eps=eps,
            D=D,
            BD=BD,
        )
    return dx.view(x_shape_og)
class L2NormFunction(torch.autograd.Function):
    """Autograd bridge for the Triton L2-norm kernels.

    ``forward`` keeps only ``x`` (plus ``eps``). ``backward`` passes them to
    ``l2norm_bwd``, which recomputes the per-row norm from ``x``.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        eps=1e-6,
        output_dtype=None
    ):
        # Record what backward needs, then produce the normalized output.
        ctx.eps = eps
        ctx.x_dtype = x.dtype  # not read by backward below
        ctx.save_for_backward(x)
        return l2norm_fwd(x, eps, output_dtype)

    @staticmethod
    @input_guard
    def backward(ctx, dy):
        (saved_x,) = ctx.saved_tensors
        grad_x = l2norm_bwd(saved_x, dy, ctx.eps)
        # eps and output_dtype are non-tensor inputs: no gradient for them.
        return grad_x, None, None
def l2norm(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """Differentiable L2 normalization of ``x`` over its last dimension.

    Args:
        x: Input tensor.
        eps: Added to each row's sum of squares inside the square root.
        output_dtype: dtype of the result. ``None`` keeps ``x.dtype``.

    Returns:
        The normalized tensor, with gradients routed through ``L2NormFunction``.
    """
    normalized = L2NormFunction.apply(x, eps, output_dtype)
    return normalized


# Public alias with an underscore spelling.
l2_norm = l2norm
class L2Norm(nn.Module):
    """``nn.Module`` wrapper around ``l2norm``.

    Normalizes its input along the last dimension, using the ``eps`` and
    ``output_dtype`` given at construction time.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        output_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.eps, self.output_dtype = eps, output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        eps, out_dtype = self.eps, self.output_dtype
        return l2norm(x, eps, out_dtype)