# Kernels
# danieldk (HF Staff)
# Build uploaded using `kernels`.
# e5b00e0 (verified)
import math
import operator
import torch
import triton
import triton.language as tl
from .utils import calculate_settings
from .utils import compare_version
from .utils import ensure_contiguous
if compare_version("triton", operator.ge, "3.0.0"):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _layer_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_cols)
    Y_row_stride,  # stride of each row in output
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    X_row_stride,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    W_row_stride,  # stride of each row in weights
    B_ptr,  # pointer to bias, shape (n_cols,)
    B_row_stride,  # stride of each row in bias
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    Mean_row_stride,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    RSTD_row_stride,  # stride of each row in rstd
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Normalize one row of X per program and apply the affine transform.

    For the row selected by ``tl.program_id(0)`` this computes
    ``Y = (X - mean) * rstd * W + B`` with ``rstd = rsqrt(var + eps)`` and the
    biased variance (divided by ``n_cols``). The row's mean and rstd are also
    written out so the backward pass can reuse them.

    All arithmetic is carried out in float32 and cast back to each
    destination buffer's element type only when storing. Note that this also
    means float64 inputs are reduced in float32.

    ``W_row_stride`` and ``B_row_stride`` are accepted but not used: weights
    and bias are read as dense vectors of length ``n_cols``.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    # Widen the row index so `row_idx * row_stride` cannot wrap around in
    # 32-bit arithmetic when the tensor holds more than 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed n_cols; padded lanes are masked everywhere below.
    mask = col_offsets < n_cols

    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    Mean_ptr += row_idx * Mean_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Upcast once so the reductions below do not accumulate in fp16/bf16.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0).to(tl.float32)

    mean = tl.sum(X_row, axis=0) / n_cols
    # Padded lanes were loaded as 0, so `X_row - mean` would be `-mean` there;
    # force them back to 0 to keep them out of the variance.
    Xmm = tl.where(mask, X_row - mean, 0.0)
    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
    rstd = rsqrt(var + eps)

    tl.store(Mean_ptr, mean.to(Mean_ptr.dtype.element_ty))
    tl.store(RSTD_ptr, rstd.to(RSTD_ptr.dtype.element_ty))

    Y_row = Xmm * rstd * W_row + B_row
    tl.store(Y_ptr + col_offsets, Y_row.to(Y_ptr.dtype.element_ty), mask=mask)
@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    DW_ptr,  # pointer to weights grad, shape (n_cols,)
    DB_ptr,  # pointer to bias grad, shape (n_cols,)
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    stride_dx,  # stride of each row in input grad
    stride_dw,  # stride of each row in weights grad
    stride_db,  # stride of each row in bias grad
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    Compute LayerNorm gradients for a contiguous block of rows.

    Each program handles rows ``[row_start, row_end)`` where ``row_start`` is
    ``program_id(0) * rows_per_program`` and ``row_end`` is clamped to
    ``n_rows``. For every row it writes ``dX`` (cast to ``dtype``) and
    accumulates that row's contribution to the weight and bias gradients in
    float32. After the loop the program writes its partial sums to row
    ``program_id(0)`` of the ``DW_ptr`` / ``DB_ptr`` buffers; the caller is
    expected to reduce those rows.

    Operands are upcast to float32 before any arithmetic so the per-row
    reductions are not carried out in fp16/bf16.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
    """
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed n_cols; padded lanes are masked on load and store.
    mask = cols < n_cols

    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Widen before multiplying by a row stride so the initial element offset
    # cannot wrap around in 32-bit arithmetic on very large tensors.
    row_offset = row_start.to(tl.int64)
    X_ptr += row_offset * stride_x
    Mean_ptr += row_start
    RSTD_ptr += row_start
    DX_ptr += row_offset * stride_dx
    DY_ptr += row_offset * stride_dy

    # The weights are identical for every row, so load them once.
    w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    for _ in range(row_start, row_end):
        x = tl.load(X_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        mean = tl.load(Mean_ptr).to(tl.float32)
        rstd = tl.load(RSTD_ptr).to(tl.float32)

        x_hat = (x - mean) * rstd
        # `w` and `dy` are 0 in padded lanes, so `wdy` (and every product
        # with it) vanishes there even though `x_hat` does not.
        wdy = w * dy
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)

        dw_row += dy * x_hat
        db_row += dy

        # Advance every per-row pointer to the next row.
        X_ptr += stride_x
        Mean_ptr += 1
        RSTD_ptr += 1
        DX_ptr += stride_dx
        DY_ptr += stride_dy

    # Cast to the destination buffers' own element type so the partial sums
    # are not rounded any further than the buffers themselves require.
    tl.store(
        DW_ptr + row_block_id * stride_dw + cols,
        dw_row.to(DW_ptr.dtype.element_ty),
        mask=mask,
    )
    tl.store(
        DB_ptr + row_block_id * stride_db + cols,
        db_row.to(DB_ptr.dtype.element_ty),
        mask=mask,
    )
def layer_norm_forward(X, W, B, eps):
    """
    Launch the LayerNorm forward kernel over the last dimension of ``X``.

    Args:
        X: input tensor; it is viewed as ``(-1, X.shape[-1])``, so it must be
            viewable that way (e.g. contiguous).
        W: 1-D weight tensor whose length equals ``X.shape[-1]``.
        B: 1-D bias tensor whose length equals ``X.shape[-1]``.
        eps: value added to the variance before the reciprocal square root.

    Returns:
        A tuple ``(Y, X_2d, Mean, RSTD, BLOCK_SIZE, num_warps)`` where ``Y``
        has the original shape of ``X``, ``X_2d`` is the 2-D view of the
        input, ``Mean`` / ``RSTD`` hold one value per row, and the last two
        items are the launch settings returned by ``calculate_settings``.

    Raises:
        ValueError: if the length of ``W`` or ``B`` differs from the feature
            size of ``X``.
    """
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape

    # Validate before allocating anything or picking launch settings, so a
    # bad call fails fast and never reaches the kernel.
    if X.shape[1] != W.shape[0]:
        raise ValueError(
            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
            f"must match weight size (W.shape[0]={W.shape[0]})"
        )
    # The kernel reads n_cols bias elements; a shorter bias would be read
    # out of bounds.
    if B.shape[0] != n_cols:
        raise ValueError(
            f"Incompatible dimensions: input feature size (X.shape[1]={n_cols}) "
            f"must match bias size (B.shape[0]={B.shape[0]})"
        )

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    # One program per row.
    _layer_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        B,
        B.stride(0),
        Mean,
        Mean.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,  # XPU-specific optimization
    )
    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
    """
    Launch the LayerNorm backward kernel and reduce the parameter gradients.

    Args:
        dY: gradient of the output; viewed as ``(-1, dY.shape[-1])``.
        X: input as used by the kernel; its row stride is ``X.stride(0)``.
        W: weight tensor.
        B: bias tensor (not read here; kept for a symmetric signature).
        Mean: per-row mean saved by the forward pass.
        RSTD: per-row reciprocal standard deviation saved by the forward pass.

    Returns:
        ``(DX, DW, DB)``: ``DX`` has the original shape of ``dY`` and the
        dtype of ``X``; ``DW`` and ``DB`` have ``n_cols`` elements and the
        dtype of ``W``.

    Raises:
        RuntimeError: if the feature dimension exceeds the block size
            returned by ``calculate_settings``.
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per compute unit; fall back to a single program on devices
    # that are neither CUDA nor XPU.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count

    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # Each program writes one row of partial sums. Keep them in float32 so
    # they are not rounded to a low-precision dtype before being reduced.
    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    # dtype the kernel casts dX to; anything unlisted falls back to float32.
    torch_to_triton = {
        torch.float32: tl.float32,
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
    }
    triton_dtype = torch_to_triton.get(X.dtype, tl.float32)

    # Use the warp count chosen for this block size; previously it was
    # computed but never handed to the launch.
    kernel_args = {"num_warps": num_warps}
    # XPU-specific optimization (overrides num_warps)
    if X.device.type == "xpu":
        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

    _layer_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW,
        _DB,
        dY,
        X.stride(0),
        DX.stride(0),
        _DW.stride(0),
        _DB.stride(0),
        dY.stride(0),
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
        **kernel_args,  # XPU-specific optimization
    )

    # Reduce the per-program partial sums and return them in the weight dtype.
    DW = _DW.sum(dim=0).to(W.dtype)
    DB = _DB.sum(dim=0).to(W.dtype)

    DX = DX.view(*shape)
    return DX, DW, DB
class LigerLayerNormFunction(torch.autograd.Function):
    """Autograd wrapper that ties the LayerNorm forward and backward launchers together."""

    @staticmethod
    @ensure_contiguous  # presumably makes tensor arguments contiguous; see .utils
    def forward(ctx, X, W, B, eps):
        """Run the forward launcher, stash what backward needs, and return Y."""
        (
            output,
            flat_input,
            row_mean,
            row_rstd,
            _block_size,
            _num_warps,
        ) = layer_norm_forward(X, W, B, eps)
        # Save the 2-D input returned by the launcher, not the caller's
        # original tensor.
        ctx.save_for_backward(flat_input, W, B, row_mean, row_rstd)
        return output

    @staticmethod
    @ensure_contiguous  # presumably makes tensor arguments contiguous; see .utils
    def backward(ctx, dY):
        """Return gradients for (X, W, B, eps); eps receives no gradient."""
        flat_input, weight, bias, row_mean, row_rstd = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = layer_norm_backward(
            dY, flat_input, weight, bias, row_mean, row_rstd
        )
        return grad_input, grad_weight, grad_bias, None