import math
import operator
import torch
import triton
import triton.language as tl
from .utils import calculate_settings
from .utils import compare_version
from .utils import ensure_contiguous
# Triton >= 3.0 moved `rsqrt` into the libdevice extras; older releases expose
# it under `triton.language.math`. Resolve the correct import path once, at
# module load time.
if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt
@triton.jit
def _layer_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_cols)
    Y_row_stride,  # stride of each row in output
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    X_row_stride,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    W_row_stride,  # stride of each row in weights
    B_ptr,  # pointer to bias, shape (n_cols,)
    B_row_stride,  # stride of each row in bias
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    Mean_row_stride,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    RSTD_row_stride,  # stride of each row in rstd
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    LayerNorm forward: one program normalizes one row of X.

    Each program loads a full row (BLOCK_SIZE >= n_cols, masked), computes the
    row mean and (biased, /n_cols) variance, then writes
    ``Y = (X - mean) * rsqrt(var + eps) * W + B``. The per-row mean and rstd
    are stored so the backward kernel can reuse them.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Advance the row pointers to this program's row.
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    Mean_ptr += row_idx * Mean_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride
    # Out-of-range lanes load 0 so they contribute nothing to the reductions.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
    mean = tl.sum(X_row, axis=0) / n_cols
    # Re-mask after centering: (0 - mean) in padded lanes would otherwise
    # corrupt the variance sum.
    Xmm = tl.where(mask, X_row - mean, 0)
    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
    rstd = rsqrt(var + eps)
    # Cache the row statistics for the backward pass.
    tl.store(Mean_ptr, mean)
    tl.store(RSTD_ptr, rstd)
    Y_row = Xmm * rstd * W_row + B_row
    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    DW_ptr,  # pointer to weights grad partials, shape (num_programs, n_cols)
    DB_ptr,  # pointer to bias grad partials, shape (num_programs, n_cols)
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    stride_dx,  # stride of each row in input grad
    stride_dw,  # stride of each row in weights grad
    stride_db,  # stride of each row in bias grad
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    LayerNorm backward: each program processes a contiguous chunk of up to
    ``rows_per_program`` rows. It writes DX per row and accumulates one
    partial (DW, DB) row per program; the host reduces the partials with a
    ``sum(dim=0)`` afterwards.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
    """
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Last program may own a short tail of rows.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # fp32 accumulators for the weight/bias gradient partials.
    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # Advance all per-row pointers to this program's first row.
    X_ptr += row_start * stride_x
    Mean_ptr += row_start
    RSTD_ptr += row_start
    DX_ptr += row_start * stride_dx
    DY_ptr += row_start * stride_dy
    for _ in range(row_start, row_end):
        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
        # Reuse the statistics cached by the forward kernel.
        mean = tl.load(Mean_ptr)
        rstd = tl.load(RSTD_ptr)
        x_hat = (x - mean) * rstd
        wdy = w * dy
        # c1 = mean(x_hat * wdy), c2 = mean(wdy): the two row reductions of
        # the analytic LayerNorm input gradient.
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
        dw_row += dy * x_hat
        db_row += dy
        # Step every row pointer to the next row.
        X_ptr += stride_x
        Mean_ptr += 1
        RSTD_ptr += 1
        DX_ptr += stride_dx
        DY_ptr += stride_dy
    # One partial row per program; host-side sum(dim=0) finishes the reduction.
    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
def layer_norm_forward(X, W, B, eps):
    """Compute LayerNorm over the last dimension of ``X`` with a Triton kernel.

    Args:
        X: input tensor of shape (..., dim).
        W: weight (gamma) tensor of shape (dim,).
        B: bias (beta) tensor of shape (dim,).
        eps: small constant added to the variance for numerical stability.

    Returns:
        Tuple ``(Y, X, Mean, RSTD, BLOCK_SIZE, num_warps)`` where ``Y`` has the
        original shape of the input, ``X`` is the flattened (n_rows, dim) view,
        and ``Mean``/``RSTD`` hold the per-row statistics the backward pass
        reuses.

    Raises:
        ValueError: if the input feature size does not match ``W.shape[0]``.
    """
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    # Fail fast: validate shapes BEFORE computing launch settings and
    # allocating the three output tensors (the original checked afterwards,
    # wasting work on the error path).
    if X.shape[1] != W.shape[0]:
        raise ValueError(
            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
            f"must match weight size (W.shape[0]={W.shape[0]})"
        )
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # Per-row statistics are cached for the backward pass.
    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
    # XPU-specific optimization: request the large register file.
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args["grf_mode"] = "large"
    # One program per row.
    _layer_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        B,
        B.stride(0),
        Mean,
        Mean.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,  # XPU-specific optimization
    )
    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
    """Compute the LayerNorm gradients with a Triton kernel.

    Args:
        dY: gradient w.r.t. the output, shape (..., dim).
        X: flattened (n_rows, dim) input saved by the forward pass.
        W: weight (gamma) tensor of shape (dim,).
        B: bias tensor; unused here but kept for interface symmetry with
           the forward pass.
        Mean: per-row means saved by the forward pass, shape (n_rows,).
        RSTD: per-row reciprocal stds saved by the forward pass, shape (n_rows,).

    Returns:
        Tuple ``(DX, DW, DB)``: gradients w.r.t. input (original dY shape),
        weight, and bias.

    Raises:
        RuntimeError: if the feature dimension exceeds the maximum block size.
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape
    # Fail fast: compute launch settings and validate the feature size BEFORE
    # allocating any gradient buffers (the original allocated first).
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )
    # Launch one program per SM (CUDA) / EU (XPU); each accumulates a partial
    # DW/DB row that is reduced on the host below.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)
    # Map the torch dtype to its triton counterpart; anything else falls back
    # to float32 (same fallback as before, expressed as a lookup instead of a
    # nested ternary chain).
    triton_dtype = {
        torch.float32: tl.float32,
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
    }.get(X.dtype, tl.float32)
    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
    _layer_norm_backward_kernel[grid](
        X,
        W,
        Mean,
        RSTD,
        DX,
        _DW,
        _DB,
        dY,
        X.stride(0),
        DX.stride(0),
        _DW.stride(0),
        _DB.stride(0),
        dY.stride(0),
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
        **kernel_args,  # XPU-specific optimization
    )
    # Reduce the per-program partials into the final gradients.
    DW = _DW.sum(dim=0).to(W.dtype)
    DB = _DB.sum(dim=0).to(W.dtype)
    DX = DX.view(*shape)
    return DX, DW, DB
class LigerLayerNormFunction(torch.autograd.Function):
    """Autograd bridge for the Triton LayerNorm forward/backward kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps):
        """Run the fused forward pass and stash tensors needed for backward."""
        # layer_norm_forward returns the flattened X view plus the cached
        # per-row Mean/RSTD statistics; BLOCK_SIZE/num_warps are discarded.
        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
        # NOTE(review): B is saved and forwarded to layer_norm_backward for
        # signature symmetry, though the backward math does not appear to read it.
        ctx.save_for_backward(X, W, B, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """Return (dX, dW, dB, None) — None matches the non-tensor `eps` input."""
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
        return DX, DW, DB, None