import math
import operator

import torch
import triton
import triton.language as tl

from .utils import calculate_settings
from .utils import compare_version
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # Triton >= 3.0 exposes libdevice here and dispatches per backend
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # fall back to the CUDA-specific libdevice location
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt


@triton.jit
def _layer_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    B_ptr,
    B_row_stride,
    Mean_ptr,
    Mean_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
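    # Per-row math implemented below (readability sketch, not executed):
    #   mean = sum(x) / n_cols
    #   var  = sum((x - mean) ** 2) / n_cols
    #   rstd = 1 / sqrt(var + eps)
    #   y    = (x - mean) * rstd * w + b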
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # advance the pointers to this program's row
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    Mean_ptr += row_idx * Mean_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)

    # per-row mean and variance; masked lanes are zeroed so they do not contribute
    mean = tl.sum(X_row, axis=0) / n_cols
    Xmm = tl.where(mask, X_row - mean, 0)
    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
    rstd = rsqrt(var + eps)

    # cache the row statistics for the backward pass
    tl.store(Mean_ptr, mean)
    tl.store(RSTD_ptr, rstd)

    Y_row = Xmm * rstd * W_row + B_row

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,
    W_ptr,
    Mean_ptr,
    RSTD_ptr,
    DX_ptr,
    DW_ptr,
    DB_ptr,
    DY_ptr,
    stride_x,
    stride_dx,
    stride_dw,
    stride_db,
    stride_dy,
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
    """
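    # Gradient identities the row loop below implements (readability sketch):
    #   x_hat = (x - mean) * rstd
    #   wdy   = w * dy
    #   c1    = mean(x_hat * wdy),  c2 = mean(wdy)
    #   dx    = (wdy - (x_hat * c1 + c2)) * rstd
    #   dw   += dy * x_hat,  db += dy   (per-program partials, reduced on the host)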
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # per-program partial sums for dW and dB
    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    X_ptr += row_start * stride_x
    Mean_ptr += row_start
    RSTD_ptr += row_start
    DX_ptr += row_start * stride_dx
    DY_ptr += row_start * stride_dy

    for _ in range(row_start, row_end):
        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
        mean = tl.load(Mean_ptr)
        rstd = tl.load(RSTD_ptr)

        x_hat = (x - mean) * rstd
        wdy = w * dy
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)

        dw_row += dy * x_hat
        db_row += dy

        # step to the next row handled by this program
        X_ptr += stride_x
        Mean_ptr += 1
        RSTD_ptr += 1
        DX_ptr += stride_dx
        DY_ptr += stride_dy

    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)


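# Host-side launch note for the forward pass below: one Triton program per
# flattened row, with BLOCK_SIZE sized by calculate_settings to cover the whole
# feature dimension, so each row is normalized within a single block.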
def layer_norm_forward(X, W, B, eps):
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    if X.shape[1] != W.shape[0]:
        raise ValueError(
            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
            f"must match weight size (W.shape[0]={W.shape[0]})"
        )

    # XPU-specific launch option
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"

    _layer_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        B,
        B.stride(0),
        Mean,
        Mean.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )
    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps


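# Host-side launch note for the backward pass below: the grid has one program
# per streaming multiprocessor (CUDA) or per reported EU count (XPU); each
# program walks rows_per_program rows and writes per-program partial dW/dB
# rows that are summed on the host afterwards.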
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # one program per SM (CUDA) or EU group (XPU); each handles a slice of rows
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count

    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # per-program partial gradients, reduced over dim 0 after the launch
    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)
    # map the torch dtype to the matching Triton dtype (defaults to float32)
    triton_dtype = (
        tl.float32
        if X.dtype == torch.float32
        else tl.bfloat16
        if X.dtype == torch.bfloat16
        else tl.float16
        if X.dtype == torch.float16
        else tl.float32
    )

    # XPU-specific launch options
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

    _layer_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW,
        _DB,
        dY,
        X.stride(0),
        DX.stride(0),
        _DW.stride(0),
        _DB.stride(0),
        dY.stride(0),
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
        **kernel_args,
    )

    DW = _DW.sum(dim=0).to(W.dtype)
    DB = _DB.sum(dim=0).to(W.dtype)

    DX = DX.view(*shape)
    return DX, DW, DB


class LigerLayerNormFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps):
        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
        ctx.save_for_backward(X, W, B, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
        return DX, DW, DB, None
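

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): compare the
    # Triton forward against torch.nn.functional.layer_norm and run one backward
    # pass. Assumes a CUDA device is available (adjust the device string for XPU)
    # and that this file is executed as a module (python -m ...), since it relies
    # on relative imports.
    device = "cuda"
    hidden = 1024
    X = torch.randn(8, hidden, device=device, requires_grad=True)
    W = torch.randn(hidden, device=device, requires_grad=True)
    B = torch.randn(hidden, device=device, requires_grad=True)

    Y = LigerLayerNormFunction.apply(X, W, B, 1e-6)
    Y_ref = torch.nn.functional.layer_norm(X, (hidden,), W, B, 1e-6)
    print("forward max abs diff:", (Y - Y_ref).abs().max().item())

    Y.sum().backward()
    print("dX/dW/dB shapes:", X.grad.shape, W.grad.shape, B.grad.shape)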