Kernels:
Trusted publisher
Uploaded using `kernel-builder`.
Browse files
- build/torch-cuda/__init__.py +2 -1
- build/torch-cuda/_ops.py +1 -1
- build/torch-cuda/cross_entropy.py +163 -65
- build/torch-cuda/dyt.py +121 -182
- build/torch-cuda/fused_linear_cross_entropy.py +152 -35
- build/torch-cuda/geglu.py +7 -5
- build/torch-cuda/group_norm.py +22 -16
- build/torch-cuda/jsd.py +1 -1
- build/torch-cuda/kl_div.py +9 -12
- build/torch-cuda/layer_norm.py +139 -84
- build/torch-cuda/layers.py +457 -33
- build/torch-cuda/metadata.json +1 -1
- build/torch-cuda/qwen2vl_mrope.py +1 -1
- build/torch-cuda/rms_norm.py +390 -101
- build/torch-cuda/rope.py +2 -2
- build/torch-cuda/swiglu.py +75 -15
- build/torch-cuda/tvd.py +18 -7
- build/torch-cuda/utils.py +42 -1
build/torch-cuda/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
from . import layers
|
|
|
|
| 2 |
|
| 3 |
-
__all__ = ["layers"]
|
|
|
|
| 1 |
from . import layers
|
| 2 |
+
from .layers import CrossEntropyOutput, LigerForCausalLMLoss
|
| 3 |
|
| 4 |
+
__all__ = ["layers", "LigerForCausalLMLoss", "CrossEntropyOutput"]
|
build/torch-cuda/_ops.py
CHANGED
|
@@ -22,7 +22,7 @@ def get_backend() -> str:
|
|
| 22 |
|
| 23 |
def _find_ops_name() -> str:
|
| 24 |
kernel_name = "liger_kernels"
|
| 25 |
-
unique_id = "
|
| 26 |
backend = get_backend()
|
| 27 |
return f"_{kernel_name}_{backend}_{unique_id}"
|
| 28 |
|
|
|
|
| 22 |
|
| 23 |
def _find_ops_name() -> str:
|
| 24 |
kernel_name = "liger_kernels"
|
| 25 |
+
unique_id = "08b4d53"
|
| 26 |
backend = get_backend()
|
| 27 |
return f"_{kernel_name}_{backend}_{unique_id}"
|
| 28 |
|
build/torch-cuda/cross_entropy.py
CHANGED
|
@@ -10,8 +10,9 @@ from .utils import compare_version
|
|
| 10 |
from .utils import element_mul_kernel
|
| 11 |
from .utils import is_hip
|
| 12 |
from .utils import infer_device
|
|
|
|
| 13 |
|
| 14 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 15 |
try:
|
| 16 |
# typical import path with dispatch available
|
| 17 |
from triton.language.extra.libdevice import tanh
|
|
@@ -32,6 +33,10 @@ def liger_cross_entropy_kernel(
|
|
| 32 |
loss_ptr,
|
| 33 |
z_loss_ptr,
|
| 34 |
loss_stride,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
n_cols,
|
| 36 |
n_non_ignore,
|
| 37 |
sum_non_ignore_weight,
|
|
@@ -42,9 +47,12 @@ def liger_cross_entropy_kernel(
|
|
| 42 |
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
|
| 43 |
softcap,
|
| 44 |
RETURN_Z_LOSS: tl.constexpr,
|
|
|
|
|
|
|
| 45 |
BLOCK_SIZE: tl.constexpr,
|
| 46 |
HAS_WEIGHT: tl.constexpr,
|
| 47 |
HAS_SOFTCAPPING: tl.constexpr,
|
|
|
|
| 48 |
):
|
| 49 |
"""
|
| 50 |
This kernel computes both cross entropy loss and the gradient of the input.
|
|
@@ -59,6 +67,8 @@ def liger_cross_entropy_kernel(
|
|
| 59 |
loss_ptr: Pointer to tensor to store the loss.
|
| 60 |
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
| 61 |
loss_stride (int): The stride of the loss tensor.
|
|
|
|
|
|
|
| 62 |
n_cols (int): The number of columns in the input tensor.
|
| 63 |
n_non_ignore (float): The number of non-ignored elements in the batch.
|
| 64 |
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
|
@@ -68,10 +78,12 @@ def liger_cross_entropy_kernel(
|
|
| 68 |
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, which is added to the loss to stabilize training.
|
| 69 |
reduction (str): The string for the reduction to apply
|
| 70 |
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
| 71 |
-
RETURN_Z_LOSS (int): The boolean value to decide whether
|
|
|
|
| 72 |
BLOCK_SIZE (int): The block size for Triton operations.
|
| 73 |
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
| 74 |
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
|
|
|
| 75 |
"""
|
| 76 |
|
| 77 |
# https://github.com/triton-lang/triton/issues/1058
|
|
@@ -90,11 +102,22 @@ def liger_cross_entropy_kernel(
|
|
| 90 |
for i in range(0, n_cols, BLOCK_SIZE):
|
| 91 |
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
| 92 |
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
return
|
| 94 |
|
| 95 |
loss_ptr += program_id * loss_stride
|
| 96 |
if RETURN_Z_LOSS:
|
| 97 |
z_loss_ptr += program_id * loss_stride
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
if HAS_WEIGHT:
|
| 100 |
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
|
@@ -105,6 +128,7 @@ def liger_cross_entropy_kernel(
|
|
| 105 |
# 3. [Online softmax] first pass: find max + sum
|
| 106 |
m = float("-inf") # m is the max value. use the notation from the paper
|
| 107 |
d = 0.0 # d is the sum. use the notation from the paper
|
|
|
|
| 108 |
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
|
| 109 |
if HAS_SOFTCAPPING:
|
| 110 |
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
|
@@ -125,6 +149,19 @@ def liger_cross_entropy_kernel(
|
|
| 125 |
if HAS_SOFTCAPPING:
|
| 126 |
X_block = softcap * tanh(X_block / softcap)
|
| 127 |
block_max = tl.max(X_block)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
if label_smoothing > 0:
|
| 129 |
# scale X beforehand to avoid overflow
|
| 130 |
if HAS_WEIGHT:
|
|
@@ -155,58 +192,58 @@ def liger_cross_entropy_kernel(
|
|
| 155 |
# For 'sum' reduction, no normalization is applied:
|
| 156 |
# dx_y = softmax(x_y) - 1
|
| 157 |
# dx_i = softmax(x_i), for i ≠ y
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
|
| 211 |
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
| 212 |
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
|
@@ -254,12 +291,24 @@ def liger_cross_entropy_kernel(
|
|
| 254 |
tl.store(loss_ptr, loss)
|
| 255 |
if RETURN_Z_LOSS:
|
| 256 |
tl.store(z_loss_ptr, z_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
|
| 259 |
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
| 260 |
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
| 261 |
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
|
| 265 |
def cross_entropy_forward(
|
|
@@ -272,8 +321,16 @@ def cross_entropy_forward(
|
|
| 272 |
reduction,
|
| 273 |
softcap,
|
| 274 |
return_z_loss,
|
|
|
|
|
|
|
| 275 |
):
|
| 276 |
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
BT, V = _input.shape
|
| 279 |
n_rows = BT
|
|
@@ -283,6 +340,12 @@ def cross_entropy_forward(
|
|
| 283 |
# unreduced loss
|
| 284 |
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
| 285 |
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
target_mask = target != ignore_index
|
| 288 |
n_non_ignore = target_mask.sum().item()
|
|
@@ -319,6 +382,14 @@ def cross_entropy_forward(
|
|
| 319 |
loss_ptr=loss_1d,
|
| 320 |
z_loss_ptr=z_loss_1d,
|
| 321 |
loss_stride=loss_1d.stride(-1), # always 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
n_cols=V,
|
| 323 |
n_non_ignore=n_non_ignore,
|
| 324 |
sum_non_ignore_weight=sum_non_ignore_weight,
|
|
@@ -329,9 +400,12 @@ def cross_entropy_forward(
|
|
| 329 |
reduction=reduction,
|
| 330 |
softcap=softcap,
|
| 331 |
RETURN_Z_LOSS=return_z_loss,
|
|
|
|
|
|
|
| 332 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 333 |
HAS_WEIGHT=True if weight is not None else False,
|
| 334 |
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
|
|
| 335 |
# TODO: 32 seems to give the best performance
|
| 336 |
# Performance is quite sensitive to num_warps
|
| 337 |
num_warps=32 if not is_hip() else 16,
|
|
@@ -340,11 +414,16 @@ def cross_entropy_forward(
|
|
| 340 |
if reduction == "none":
|
| 341 |
loss = loss_1d
|
| 342 |
z_loss = z_loss_1d if return_z_loss else None
|
|
|
|
| 343 |
else:
|
| 344 |
loss = torch.sum(loss_1d)
|
| 345 |
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
|
|
|
|
|
|
|
| 348 |
|
| 349 |
|
| 350 |
def cross_entropy_backward(_input, grad_output):
|
|
@@ -392,6 +471,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
| 392 |
reduction: str = "mean",
|
| 393 |
softcap: Optional[float] = None,
|
| 394 |
return_z_loss: bool = False,
|
|
|
|
|
|
|
| 395 |
):
|
| 396 |
"""
|
| 397 |
The forward pass of the Liger Cross Entropy loss.
|
|
@@ -406,12 +487,16 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
| 406 |
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
| 407 |
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
|
| 408 |
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
| 409 |
-
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
|
|
|
|
|
|
|
| 410 |
|
| 411 |
Returns:
|
| 412 |
-
tuple: A tuple with the
|
| 413 |
"""
|
| 414 |
-
|
|
|
|
|
|
|
| 415 |
_input,
|
| 416 |
target,
|
| 417 |
weight,
|
|
@@ -421,29 +506,40 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
| 421 |
reduction,
|
| 422 |
softcap,
|
| 423 |
return_z_loss,
|
|
|
|
|
|
|
| 424 |
)
|
| 425 |
# TODO: investigation
|
| 426 |
# If we don't detach the _input tensor, the memory will double
|
| 427 |
# Not sure why, but it seems there is a point in time when both the grad and the value exist, in different locations
|
| 428 |
-
|
|
|
|
| 429 |
ctx.return_z_loss = return_z_loss
|
|
|
|
|
|
|
| 430 |
|
| 431 |
-
return loss, z_loss
|
| 432 |
|
| 433 |
@staticmethod
|
| 434 |
-
def backward(ctx, grad_output,
|
| 435 |
"""
|
| 436 |
The backward pass of the Liger Cross Entropy loss.
|
| 437 |
|
| 438 |
Parameters:
|
| 439 |
ctx : The context object with saved tensors.
|
| 440 |
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
|
| 441 |
-
grad_output2 (
|
|
|
|
|
|
|
| 442 |
Returns:
|
| 443 |
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
|
| 444 |
"""
|
| 445 |
if ctx.return_z_loss:
|
| 446 |
-
del
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
(_input,) = ctx.saved_tensors
|
| 449 |
_input = cross_entropy_backward(_input, grad_output)
|
|
@@ -457,4 +553,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
| 457 |
None,
|
| 458 |
None,
|
| 459 |
None,
|
| 460 |
-
|
|
|
|
|
|
|
|
|
| 10 |
from .utils import element_mul_kernel
|
| 11 |
from .utils import is_hip
|
| 12 |
from .utils import infer_device
|
| 13 |
+
from .utils import is_npu_available
|
| 14 |
|
| 15 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 16 |
try:
|
| 17 |
# typical import path with dispatch available
|
| 18 |
from triton.language.extra.libdevice import tanh
|
|
|
|
| 33 |
loss_ptr,
|
| 34 |
z_loss_ptr,
|
| 35 |
loss_stride,
|
| 36 |
+
token_accuracy_ptr,
|
| 37 |
+
token_accuracy_stride,
|
| 38 |
+
predicted_tokens_ptr,
|
| 39 |
+
predicted_tokens_stride,
|
| 40 |
n_cols,
|
| 41 |
n_non_ignore,
|
| 42 |
sum_non_ignore_weight,
|
|
|
|
| 47 |
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
|
| 48 |
softcap,
|
| 49 |
RETURN_Z_LOSS: tl.constexpr,
|
| 50 |
+
RETURN_TOKEN_ACCURACY: tl.constexpr,
|
| 51 |
+
RETURN_PREDICTED_TOKENS: tl.constexpr,
|
| 52 |
BLOCK_SIZE: tl.constexpr,
|
| 53 |
HAS_WEIGHT: tl.constexpr,
|
| 54 |
HAS_SOFTCAPPING: tl.constexpr,
|
| 55 |
+
HAS_GRADIENTS: tl.constexpr,
|
| 56 |
):
|
| 57 |
"""
|
| 58 |
This kernel computes both cross entropy loss and the gradient of the input.
|
|
|
|
| 67 |
loss_ptr: Pointer to tensor to store the loss.
|
| 68 |
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
| 69 |
loss_stride (int): The stride of the loss tensor.
|
| 70 |
+
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
|
| 71 |
+
token_accuracy_stride (int): The stride of the token accuracy tensor.
|
| 72 |
n_cols (int): The number of columns in the input tensor.
|
| 73 |
n_non_ignore (float): The number of non-ignored elements in the batch.
|
| 74 |
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
|
|
|
| 78 |
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, which is added to the loss to stabilize training.
|
| 79 |
reduction (str): The string for the reduction to apply
|
| 80 |
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
| 81 |
+
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
|
| 82 |
+
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
|
| 83 |
BLOCK_SIZE (int): The block size for Triton operations.
|
| 84 |
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
| 85 |
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
| 86 |
+
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
|
| 87 |
"""
|
| 88 |
|
| 89 |
# https://github.com/triton-lang/triton/issues/1058
|
|
|
|
| 102 |
for i in range(0, n_cols, BLOCK_SIZE):
|
| 103 |
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
| 104 |
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
|
| 105 |
+
# For ignored tokens, set token accuracy to 0
|
| 106 |
+
if RETURN_TOKEN_ACCURACY:
|
| 107 |
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
| 108 |
+
tl.store(token_accuracy_ptr, 0.0)
|
| 109 |
+
if RETURN_PREDICTED_TOKENS:
|
| 110 |
+
predicted_tokens_ptr += program_id * predicted_tokens_stride
|
| 111 |
+
tl.store(predicted_tokens_ptr, -1)
|
| 112 |
return
|
| 113 |
|
| 114 |
loss_ptr += program_id * loss_stride
|
| 115 |
if RETURN_Z_LOSS:
|
| 116 |
z_loss_ptr += program_id * loss_stride
|
| 117 |
+
if RETURN_TOKEN_ACCURACY:
|
| 118 |
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
| 119 |
+
if RETURN_PREDICTED_TOKENS:
|
| 120 |
+
predicted_tokens_ptr += program_id * predicted_tokens_stride
|
| 121 |
|
| 122 |
if HAS_WEIGHT:
|
| 123 |
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
|
|
|
| 128 |
# 3. [Online softmax] first pass: find max + sum
|
| 129 |
m = float("-inf") # m is the max value. use the notation from the paper
|
| 130 |
d = 0.0 # d is the sum. use the notation from the paper
|
| 131 |
+
argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation
|
| 132 |
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
|
| 133 |
if HAS_SOFTCAPPING:
|
| 134 |
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
|
|
|
| 149 |
if HAS_SOFTCAPPING:
|
| 150 |
X_block = softcap * tanh(X_block / softcap)
|
| 151 |
block_max = tl.max(X_block)
|
| 152 |
+
|
| 153 |
+
# Track argmax for accuracy / predicted tokens computation
|
| 154 |
+
if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS:
|
| 155 |
+
# Find the index of the maximum value in this block
|
| 156 |
+
is_max_mask = X_block == block_max
|
| 157 |
+
# Replace non-max positions with n_cols, which is larger than any valid column index
|
| 158 |
+
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
|
| 159 |
+
# Get the first (smallest) index where max occurs
|
| 160 |
+
current_block_argmax_idx = tl.min(masked_offsets)
|
| 161 |
+
|
| 162 |
+
is_new_max = block_max > m
|
| 163 |
+
argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
|
| 164 |
+
|
| 165 |
if label_smoothing > 0:
|
| 166 |
# scale X beforehand to avoid overflow
|
| 167 |
if HAS_WEIGHT:
|
|
|
|
| 192 |
# For 'sum' reduction, no normalization is applied:
|
| 193 |
# dx_y = softmax(x_y) - 1
|
| 194 |
# dx_i = softmax(x_i), for i ≠ y
|
| 195 |
+
if HAS_GRADIENTS:
|
| 196 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
| 197 |
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
| 198 |
+
X_block = tl.load(
|
| 199 |
+
X_ptr + X_offsets,
|
| 200 |
+
mask=X_offsets < n_cols,
|
| 201 |
+
other=float("-inf"),
|
| 202 |
+
# Ensure float32 precision for softmax calculation
|
| 203 |
+
).cast(tl.float32)
|
| 204 |
+
if HAS_SOFTCAPPING:
|
| 205 |
+
intermediate = tanh(X_block / softcap)
|
| 206 |
+
X_block = softcap * intermediate
|
| 207 |
+
|
| 208 |
+
if not HAS_WEIGHT:
|
| 209 |
+
# softmax(x_i)
|
| 210 |
+
X_block = tl.exp(X_block - m) / d
|
| 211 |
+
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
|
| 212 |
+
X_block += 2 * lse_square_scale * lse * X_block
|
| 213 |
+
# smoothing term
|
| 214 |
+
X_block += -eps
|
| 215 |
+
# special handle dx_y
|
| 216 |
+
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
|
| 217 |
+
# reduction scale
|
| 218 |
+
if reduction == "mean":
|
| 219 |
+
X_block = X_block / n_non_ignore
|
| 220 |
+
else:
|
| 221 |
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
| 222 |
+
softmax_X = tl.exp(X_block - m) / d
|
| 223 |
+
# derivative of original_loss
|
| 224 |
+
dloss_ori = (1 - label_smoothing) * softmax_X
|
| 225 |
+
# specially handle dx_y
|
| 226 |
+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
|
| 227 |
+
dloss_ori = dloss_ori * weight_y
|
| 228 |
+
# derivative of smooth_loss
|
| 229 |
+
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
|
| 230 |
+
# derivative of z-loss
|
| 231 |
+
dz_loss = 2 * lse_square_scale * lse * softmax_X
|
| 232 |
+
# reduction scale
|
| 233 |
+
if reduction == "mean":
|
| 234 |
+
dloss_ori = dloss_ori / sum_non_ignore_weight
|
| 235 |
+
dloss_smooth = dloss_smooth / sum_non_ignore_weight
|
| 236 |
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
| 237 |
+
dz_loss = dz_loss / n_non_ignore
|
| 238 |
+
# derivative of total_loss
|
| 239 |
+
X_block = dloss_ori + dloss_smooth + dz_loss
|
| 240 |
+
|
| 241 |
+
# chain rule softcapping
|
| 242 |
+
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
|
| 243 |
+
if HAS_SOFTCAPPING:
|
| 244 |
+
X_block = X_block * (1 - intermediate * intermediate)
|
| 245 |
+
|
| 246 |
+
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
|
| 247 |
|
| 248 |
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
| 249 |
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
|
|
|
| 291 |
tl.store(loss_ptr, loss)
|
| 292 |
if RETURN_Z_LOSS:
|
| 293 |
tl.store(z_loss_ptr, z_loss)
|
| 294 |
+
if RETURN_TOKEN_ACCURACY:
|
| 295 |
+
# Store 1.0 if prediction is correct, 0.0 otherwise
|
| 296 |
+
is_correct = 1.0 if argmax_idx == y else 0.0
|
| 297 |
+
tl.store(token_accuracy_ptr, is_correct)
|
| 298 |
+
if RETURN_PREDICTED_TOKENS:
|
| 299 |
+
tl.store(predicted_tokens_ptr, argmax_idx)
|
| 300 |
|
| 301 |
|
| 302 |
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
| 303 |
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
| 304 |
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
| 305 |
+
# the best size we found by manually tuning on xpu and npu.
|
| 306 |
+
if infer_device() == "xpu":
|
| 307 |
+
MAX_FUSED_SIZE = 4096
|
| 308 |
+
elif infer_device() == "npu":
|
| 309 |
+
MAX_FUSED_SIZE = 2048
|
| 310 |
+
else:
|
| 311 |
+
MAX_FUSED_SIZE = 65536 // 2
|
| 312 |
|
| 313 |
|
| 314 |
def cross_entropy_forward(
|
|
|
|
| 321 |
reduction,
|
| 322 |
softcap,
|
| 323 |
return_z_loss,
|
| 324 |
+
return_token_accuracy=False,
|
| 325 |
+
return_predicted_tokens=False,
|
| 326 |
):
|
| 327 |
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
| 328 |
+
assert isinstance(return_token_accuracy, bool), (
|
| 329 |
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
| 330 |
+
)
|
| 331 |
+
assert isinstance(return_predicted_tokens, bool), (
|
| 332 |
+
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
|
| 333 |
+
)
|
| 334 |
|
| 335 |
BT, V = _input.shape
|
| 336 |
n_rows = BT
|
|
|
|
| 340 |
# unreduced loss
|
| 341 |
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
| 342 |
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
| 343 |
+
token_accuracy_1d = (
|
| 344 |
+
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
|
| 345 |
+
)
|
| 346 |
+
predicted_tokens_1d = (
|
| 347 |
+
torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None
|
| 348 |
+
)
|
| 349 |
|
| 350 |
target_mask = target != ignore_index
|
| 351 |
n_non_ignore = target_mask.sum().item()
|
|
|
|
| 382 |
loss_ptr=loss_1d,
|
| 383 |
z_loss_ptr=z_loss_1d,
|
| 384 |
loss_stride=loss_1d.stride(-1), # always 1
|
| 385 |
+
token_accuracy_ptr=token_accuracy_1d,
|
| 386 |
+
token_accuracy_stride=token_accuracy_1d.stride(-1)
|
| 387 |
+
if return_token_accuracy
|
| 388 |
+
else 0, # always 1 if accuracy is enabled
|
| 389 |
+
predicted_tokens_ptr=predicted_tokens_1d,
|
| 390 |
+
predicted_tokens_stride=predicted_tokens_1d.stride(-1)
|
| 391 |
+
if return_predicted_tokens
|
| 392 |
+
else 0, # always 1 if predicted tokens is enabled
|
| 393 |
n_cols=V,
|
| 394 |
n_non_ignore=n_non_ignore,
|
| 395 |
sum_non_ignore_weight=sum_non_ignore_weight,
|
|
|
|
| 400 |
reduction=reduction,
|
| 401 |
softcap=softcap,
|
| 402 |
RETURN_Z_LOSS=return_z_loss,
|
| 403 |
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
| 404 |
+
RETURN_PREDICTED_TOKENS=return_predicted_tokens,
|
| 405 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 406 |
HAS_WEIGHT=True if weight is not None else False,
|
| 407 |
HAS_SOFTCAPPING=True if softcap is not None else False,
|
| 408 |
+
HAS_GRADIENTS=_input.requires_grad,
|
| 409 |
# TODO: 32 seems to give the best performance
|
| 410 |
# Performance is quite sensitive to num_warps
|
| 411 |
num_warps=32 if not is_hip() else 16,
|
|
|
|
| 414 |
if reduction == "none":
|
| 415 |
loss = loss_1d
|
| 416 |
z_loss = z_loss_1d if return_z_loss else None
|
| 417 |
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
| 418 |
else:
|
| 419 |
loss = torch.sum(loss_1d)
|
| 420 |
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
| 421 |
+
# For accuracy, we compute the mean across all non-ignored tokens
|
| 422 |
+
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
|
| 423 |
|
| 424 |
+
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
|
| 425 |
+
|
| 426 |
+
return loss, z_loss, token_accuracy, predicted_tokens, _input
|
| 427 |
|
| 428 |
|
| 429 |
def cross_entropy_backward(_input, grad_output):
|
|
|
|
| 471 |
reduction: str = "mean",
|
| 472 |
softcap: Optional[float] = None,
|
| 473 |
return_z_loss: bool = False,
|
| 474 |
+
return_token_accuracy: bool = False,
|
| 475 |
+
return_predicted_tokens: bool = False,
|
| 476 |
):
|
| 477 |
"""
|
| 478 |
The forward pass of the Liger Cross Entropy loss.
|
|
|
|
| 487 |
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
| 488 |
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
|
| 489 |
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
| 490 |
+
return_z_loss (bool): When `return_z_loss` is `True`, z_loss is computed and returned in the second slot of the output tuple; otherwise that slot is `None`. Default: `False`
|
| 491 |
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
| 492 |
+
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
|
| 493 |
|
| 494 |
Returns:
|
| 495 |
+
tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
|
| 496 |
"""
|
| 497 |
+
input_requires_grad = _input.requires_grad
|
| 498 |
+
|
| 499 |
+
loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward(
|
| 500 |
_input,
|
| 501 |
target,
|
| 502 |
weight,
|
|
|
|
| 506 |
reduction,
|
| 507 |
softcap,
|
| 508 |
return_z_loss,
|
| 509 |
+
return_token_accuracy,
|
| 510 |
+
return_predicted_tokens,
|
| 511 |
)
|
| 512 |
# TODO: investigation
|
| 513 |
# If we don't detach the _input tensor, the memory will double
|
| 514 |
# Not sure why, but it seems there is a point in time when both the grad and the value exist, in different locations
|
| 515 |
+
if input_requires_grad:
|
| 516 |
+
ctx.save_for_backward(_input.detach())
|
| 517 |
ctx.return_z_loss = return_z_loss
|
| 518 |
+
ctx.return_token_accuracy = return_token_accuracy
|
| 519 |
+
ctx.return_predicted_tokens = return_predicted_tokens
|
| 520 |
|
| 521 |
+
return loss, z_loss, token_accuracy, predicted_tokens
|
| 522 |
|
| 523 |
@staticmethod
|
| 524 |
+
def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
|
| 525 |
"""
|
| 526 |
The backward pass of the Liger Cross Entropy loss.
|
| 527 |
|
| 528 |
Parameters:
|
| 529 |
ctx : The context object with saved tensors.
|
| 530 |
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
|
| 531 |
+
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
|
| 532 |
+
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
|
| 533 |
+
grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics).
|
| 534 |
Returns:
|
| 535 |
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
|
| 536 |
"""
|
| 537 |
if ctx.return_z_loss:
|
| 538 |
+
del grad_output2 # z_loss is only for logging
|
| 539 |
+
if ctx.return_token_accuracy:
|
| 540 |
+
del grad_output3 # token_accuracy is only for metrics
|
| 541 |
+
if ctx.return_predicted_tokens:
|
| 542 |
+
del grad_output4 # predicted_tokens is only for metrics
|
| 543 |
|
| 544 |
(_input,) = ctx.saved_tensors
|
| 545 |
_input = cross_entropy_backward(_input, grad_output)
|
|
|
|
| 553 |
None,
|
| 554 |
None,
|
| 555 |
None,
|
| 556 |
+
None,
|
| 557 |
+
None,
|
| 558 |
+
)
|
build/torch-cuda/dyt.py
CHANGED
|
@@ -4,12 +4,13 @@ import torch
|
|
| 4 |
import triton
|
| 5 |
import triton.language as tl
|
| 6 |
|
| 7 |
-
from .utils import calculate_settings
|
| 8 |
from .utils import compare_version
|
| 9 |
from .utils import ensure_contiguous
|
|
|
|
| 10 |
from .utils import infer_device
|
|
|
|
| 11 |
|
| 12 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 13 |
try:
|
| 14 |
# typical import path with dispatch available
|
| 15 |
from triton.language.extra.libdevice import tanh
|
|
@@ -20,187 +21,131 @@ else:
|
|
| 20 |
from triton.language.math import tanh
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
@triton.jit
|
| 24 |
-
def _dyt_fwd_kernel(
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
@triton.jit
|
| 61 |
def _dyt_bwd_kernel(
|
| 62 |
-
|
| 63 |
-
x_row_stride,
|
| 64 |
-
dy_ptr,
|
| 65 |
-
dy_row_stride,
|
| 66 |
-
dx_ptr,
|
| 67 |
-
dx_row_stride,
|
| 68 |
-
alpha_ptr,
|
| 69 |
-
dalpha_ptr,
|
| 70 |
-
gamma_ptr,
|
| 71 |
-
dgamma_ptr,
|
| 72 |
-
dgamma_row_stride,
|
| 73 |
-
n_cols,
|
| 74 |
-
n_rows,
|
| 75 |
-
ROWS_PER_PROGRAM: tl.constexpr,
|
| 76 |
-
BLOCK_SIZE: tl.constexpr,
|
| 77 |
):
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
dalpha = 0.0
|
| 106 |
-
dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
| 107 |
-
|
| 108 |
-
x_ptr += row_start * x_row_stride
|
| 109 |
-
dx_ptr += row_start * dx_row_stride
|
| 110 |
-
dy_ptr += row_start * dy_row_stride
|
| 111 |
-
alpha = tl.load(alpha_ptr)
|
| 112 |
-
gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
|
| 113 |
-
|
| 114 |
-
for _ in tl.range(row_start, row_end):
|
| 115 |
-
dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
|
| 116 |
-
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
|
| 117 |
-
tanh_ax = tanh((alpha * x).cast(tl.float32))
|
| 118 |
-
sech2_ax = 1 - tanh_ax * tanh_ax
|
| 119 |
-
|
| 120 |
-
dx = dy * gamma * sech2_ax * alpha
|
| 121 |
-
dalpha += tl.sum(dy * gamma * sech2_ax * x)
|
| 122 |
-
dgamma += dy * tanh_ax
|
| 123 |
-
tl.store(dx_ptr + offsets, dx, mask=mask)
|
| 124 |
-
|
| 125 |
-
dy_ptr += dy_row_stride
|
| 126 |
-
x_ptr += x_row_stride
|
| 127 |
-
dx_ptr += dx_row_stride
|
| 128 |
-
|
| 129 |
-
tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
|
| 130 |
-
tl.store(dalpha_ptr + pid, dalpha)
|
| 131 |
-
|
| 132 |
-
pass
|
| 133 |
|
| 134 |
|
| 135 |
def liger_dyt_fwd(x, alpha, gamma, beta):
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
y = torch.empty_like(x)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
BLOCK_SIZE=BLOCK_SIZE,
|
| 152 |
-
num_warps=num_warps,
|
| 153 |
)
|
| 154 |
-
return y.view(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def liger_dyt_bwd(dy, x, alpha, gamma):
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
| 165 |
-
sm_count = 1
|
| 166 |
device = infer_device()
|
| 167 |
if device == "cuda":
|
| 168 |
-
|
| 169 |
elif device == "xpu":
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
dx = torch.empty_like(
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
dy_ptr=dy,
|
| 186 |
-
dy_row_stride=dy.stride(0),
|
| 187 |
-
dx_ptr=dx,
|
| 188 |
-
dx_row_stride=dx.stride(0),
|
| 189 |
-
alpha_ptr=alpha,
|
| 190 |
-
dalpha_ptr=_dalpha,
|
| 191 |
-
gamma_ptr=gamma,
|
| 192 |
-
dgamma_ptr=_dgamma,
|
| 193 |
-
dgamma_row_stride=_dgamma.stride(0),
|
| 194 |
-
n_cols=n_cols,
|
| 195 |
-
n_rows=n_rows,
|
| 196 |
-
ROWS_PER_PROGRAM=rows_per_program,
|
| 197 |
-
BLOCK_SIZE=BLOCK_SIZE,
|
| 198 |
-
num_warps=num_warps,
|
| 199 |
-
)
|
| 200 |
-
dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
|
| 201 |
-
dgamma = _dgamma.sum(dim=0).to(dtype)
|
| 202 |
-
dbeta = dy.sum(dim=0).to(dtype)
|
| 203 |
-
return dx.view(*shape), dalpha, dgamma, dbeta
|
| 204 |
|
| 205 |
|
| 206 |
class LigerDyTFunction(torch.autograd.Function):
|
|
@@ -208,18 +153,12 @@ class LigerDyTFunction(torch.autograd.Function):
|
|
| 208 |
@ensure_contiguous
|
| 209 |
def forward(ctx, x, alpha, gamma, beta):
|
| 210 |
y = liger_dyt_fwd(x, alpha, gamma, beta)
|
| 211 |
-
ctx.save_for_backward(x, alpha, gamma)
|
| 212 |
return y
|
| 213 |
|
| 214 |
@staticmethod
|
| 215 |
@ensure_contiguous
|
| 216 |
-
def backward(ctx,
|
| 217 |
-
x, alpha, gamma = ctx.saved_tensors
|
| 218 |
-
dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
|
| 219 |
-
|
| 220 |
-
x,
|
| 221 |
-
alpha,
|
| 222 |
-
gamma,
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
return (dx, dalpha, dgamma, dbeta)
|
|
|
|
| 4 |
import triton
|
| 5 |
import triton.language as tl
|
| 6 |
|
|
|
|
| 7 |
from .utils import compare_version
|
| 8 |
from .utils import ensure_contiguous
|
| 9 |
+
from .utils import get_npu_core_count
|
| 10 |
from .utils import infer_device
|
| 11 |
+
from .utils import is_npu_available
|
| 12 |
|
| 13 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 14 |
try:
|
| 15 |
# typical import path with dispatch available
|
| 16 |
from triton.language.extra.libdevice import tanh
|
|
|
|
| 21 |
from triton.language.math import tanh
|
| 22 |
|
| 23 |
|
| 24 |
+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
        for block_n in [1024, 2048, 4096]
        for stages in [1, 2]
        for warps in [4, 8, 16]
    ],
    key=["N"],
)
@triton.jit
def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    """Forward DyT kernel: Y = tanh(Alpha * X) * Gamma (+ Beta when HAVE_BETA).

    Grid axis 0 selects a BLOCK_N-wide column tile, grid axis 1 selects the row.
    Rows are addressed as ``row * N``, i.e. X and Y are read/written as densely
    packed rows of N elements. All arithmetic is carried out in float32.
    """
    # 64-bit indices so that row * N cannot overflow a 32-bit offset.
    row = tl.cast(tl.program_id(1), tl.int64)
    cols = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = cols < N
    row_start = row * N

    # Alpha is read as a single element; Gamma is indexed per column.
    scale = tl.load(Alpha).to(tl.float32)
    weight = tl.load(Gamma + cols, mask=in_bounds, other=0.0).to(tl.float32)
    x = tl.load(X + row_start + cols, mask=in_bounds, other=0.0).to(tl.float32)

    out = tanh(scale * x) * weight
    if HAVE_BETA:
        out += tl.load(Beta + cols, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(Y + row_start + cols, out, mask=in_bounds)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
        for block_n in [1024, 2048, 4096]
        for stages in [1, 2]
        for warps in [4, 8, 16]
    ],
    key=["N"],
    # The DA slot written by a program is derived from program_id(0), and the
    # number of column tiles depends on BLOCK_N. Autotune trials do not clear
    # outputs between runs, so slots written by an earlier trial with a smaller
    # BLOCK_N would otherwise survive and be counted when the host sums DA.
    # Zeroing DA before every trial keeps each config's result isolated.
    reset_to_zero=["DA"],
)
@triton.jit
def _dyt_bwd_kernel(
    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Backward DyT kernel producing DX plus per-program partial sums.

    Grid axis 0 selects a BLOCK_N-wide column tile. Grid axis 1 is a row
    worker: worker ``p`` handles rows p, p + num_programs(1), ... below M.
    DX is written per element; DG/DB (one row of N per worker) and DA (one
    scalar per worker and column tile) hold partial sums that the caller is
    expected to reduce.
    """
    col_tile = tl.program_id(0)
    cols = tl.cast(col_tile, tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
    valid = cols < N
    first_row = tl.cast(tl.program_id(1), tl.int64)
    row_step = tl.num_programs(1)

    alpha = tl.load(Alpha).to(tl.float32)
    da = 0.0
    gamma = tl.load(Gamma + cols, mask=valid, other=0.0).to(tl.float32)
    dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAVE_BETA:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for row in range(first_row, M, row_step):
        offsets = row * N + cols
        x = tl.load(X + offsets, mask=valid, other=0.0).to(tl.float32)
        dy = tl.load(DY + offsets, mask=valid, other=0.0).to(tl.float32)
        t = tanh(alpha * x)
        if HAVE_BETA:
            db += dy
        dg += dy * t
        # Gradient flowing into (alpha * x): d tanh(u)/du = 1 - tanh(u)^2.
        grad_pre = (1 - t * t) * dy * gamma
        da += tl.sum(x * grad_pre, 0)
        tl.store(DX + offsets, alpha * grad_pre, mask=valid)

    partial_offsets = first_row * N + cols
    tl.store(DG + partial_offsets, dg, mask=valid)
    if HAVE_BETA:
        tl.store(DB + partial_offsets, db, mask=valid)
    # DA rows have a fixed stride of cdiv(N, 512) regardless of BLOCK_N.
    tl.store(DA + first_row * tl.cdiv(N, 512) + col_tile, da)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def liger_dyt_fwd(x, alpha, gamma, beta):
    """Launch the DyT forward kernel and return the result in ``x``'s shape.

    Args:
        x: contiguous input tensor; it is viewed as 2-D by collapsing every
            dimension except the last.
        alpha: tensor handed to the kernel as ``Alpha``.
        gamma: tensor handed to the kernel as ``Gamma``.
        beta: tensor handed to the kernel as ``Beta``, or ``None`` to skip
            the additive term.

    Returns:
        A tensor created with ``torch.empty_like`` on the 2-D view of ``x``,
        filled by the kernel and viewed back to the original shape.
    """
    assert x.is_contiguous()
    orig_shape = x.shape
    x_2d = x.view(-1, orig_shape[-1])
    n_rows, n_cols = x_2d.shape
    out = torch.empty_like(x_2d)

    def grid(meta):
        # One program per (column tile, row); BLOCK_N comes from autotuning.
        return (triton.cdiv(n_cols, meta["BLOCK_N"]), n_rows)

    _dyt_fwd_kernel[grid](x_2d, out, alpha, gamma, beta, beta is not None, n_cols)
    return out.view(orig_shape)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def liger_dyt_bwd(dy, x, alpha, gamma, beta):
    """Launch the DyT backward kernel and reduce its per-program partial sums.

    Args:
        dy: contiguous upstream gradient; ``dx`` is allocated with
            ``torch.empty_like(dy)``.
        x: forward input; viewed as 2-D by collapsing every dimension except
            the last.
        alpha: tensor handed to the kernel as ``Alpha``.
        gamma: tensor handed to the kernel as ``Gamma``.
        beta: forward ``beta`` or ``None``; only its presence is used here, to
            decide whether a beta gradient is computed.

    Returns:
        Tuple ``(dx, da, dg, db)``: ``dx`` viewed to ``x``'s original shape,
        ``da`` of shape ``(1,)`` in ``x.dtype``, ``dg`` in ``gamma.dtype``, and
        ``db`` in ``x.dtype`` or ``None`` when ``beta`` is ``None``.
    """
    assert dy.is_contiguous()
    input_shape = x.shape
    x = x.view(-1, input_shape[-1])
    M, N = x.shape
    HAVE_BETA = beta is not None

    # Number of row workers (grid axis 1). Start from a safe default of one
    # worker so that a device type not listed below no longer fails with an
    # unbound NUM_SMS; a single worker simply walks every row itself.
    NUM_SMS = 1
    device = infer_device()
    if device == "cuda":
        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    elif device == "npu":
        NUM_SMS = get_npu_core_count()

    # Partial-sum buffers, one row per worker. `da` must be zero-initialised
    # because its width, cdiv(N, 512), is an upper bound on the number of
    # column tiles for any autotuned BLOCK_N, so some slots are never written
    # but are still included in the final sum.
    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
    dx = torch.empty_like(dy)

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N)

    # Reduce the float32 partials across workers, then cast down.
    if HAVE_BETA:
        db = db.sum(0).to(x.dtype)
    dg = dg.sum(0).to(gamma.dtype)
    da = da.sum().to(x.dtype).unsqueeze(0)
    return dx.view(input_shape), da, dg, db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
class LigerDyTFunction(torch.autograd.Function):
|
|
|
|
| 153 |
@ensure_contiguous
|
| 154 |
def forward(ctx, x, alpha, gamma, beta):
|
| 155 |
y = liger_dyt_fwd(x, alpha, gamma, beta)
|
| 156 |
+
ctx.save_for_backward(x, alpha, gamma, beta)
|
| 157 |
return y
|
| 158 |
|
| 159 |
@staticmethod
@ensure_contiguous
def backward(ctx, dy):
    """Return gradients for ``(x, alpha, gamma, beta)`` given upstream ``dy``.

    The tensors stored by ``forward`` are read back from ``ctx`` and passed
    to ``liger_dyt_bwd``, whose 4-tuple result is returned unchanged.
    """
    saved_x, saved_alpha, saved_gamma, saved_beta = ctx.saved_tensors
    grads = liger_dyt_bwd(dy, saved_x, saved_alpha, saved_gamma, saved_beta)
    grad_x, grad_alpha, grad_gamma, grad_beta = grads
    return grad_x, grad_alpha, grad_gamma, grad_beta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch-cuda/fused_linear_cross_entropy.py
CHANGED
|
@@ -6,11 +6,12 @@ from .utils import amp_custom_bwd
|
|
| 6 |
from .utils import amp_custom_fwd
|
| 7 |
from .utils import element_mul_kernel
|
| 8 |
from .utils import is_hip
|
|
|
|
| 9 |
|
| 10 |
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
| 11 |
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
| 12 |
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
| 13 |
-
MAX_FUSED_SIZE = 65536 // 2
|
| 14 |
|
| 15 |
|
| 16 |
def fused_linear_cross_entropy_forward(
|
|
@@ -25,10 +26,22 @@ def fused_linear_cross_entropy_forward(
|
|
| 25 |
reduction="mean",
|
| 26 |
softcap=None,
|
| 27 |
return_z_loss=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
):
|
| 29 |
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
device = _input.device
|
| 31 |
|
|
|
|
|
|
|
| 32 |
# inputs have shape: BT x H
|
| 33 |
# materialized activations will have shape: BT x V
|
| 34 |
# the increase in memory = BT x V
|
|
@@ -44,12 +57,24 @@ def fused_linear_cross_entropy_forward(
|
|
| 44 |
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
|
| 45 |
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
|
| 46 |
|
| 47 |
-
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
|
| 48 |
grad_input = torch.zeros_like(_input, device=device)
|
| 49 |
-
|
| 50 |
-
# we use fp32 for loss accumulator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
| 52 |
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
|
| 55 |
target_mask = target != ignore_index
|
|
@@ -82,9 +107,41 @@ def fused_linear_cross_entropy_forward(
|
|
| 82 |
|
| 83 |
n_rows = logits_chunk.shape[0]
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# unreduced loss
|
| 86 |
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
| 87 |
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# ensure _input and target are contiguous
|
| 90 |
logits_chunk = logits_chunk.contiguous()
|
|
@@ -100,6 +157,14 @@ def fused_linear_cross_entropy_forward(
|
|
| 100 |
loss_ptr=loss_1d_slice,
|
| 101 |
z_loss_ptr=z_loss_1d_slice,
|
| 102 |
loss_stride=loss_1d_slice.stride(-1), # always 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
n_cols=V,
|
| 104 |
n_non_ignore=total_n_non_ignore,
|
| 105 |
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
|
|
@@ -110,35 +175,46 @@ def fused_linear_cross_entropy_forward(
|
|
| 110 |
reduction=reduction,
|
| 111 |
softcap=softcap,
|
| 112 |
RETURN_Z_LOSS=return_z_loss,
|
|
|
|
|
|
|
| 113 |
HAS_WEIGHT=True if ce_weight is not None else False,
|
| 114 |
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
|
|
| 115 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 116 |
num_warps=32 if not is_hip() else 16,
|
| 117 |
)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
loss_1d[start_idx:end_idx] = loss_1d_slice
|
| 120 |
if return_z_loss:
|
| 121 |
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
grad_logits_chunk = logits_chunk # chunk_size x V
|
| 123 |
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
-
if
|
| 127 |
-
|
| 128 |
-
input=grad_weight,
|
| 129 |
-
mat1=logits_chunk.t().to(
|
| 130 |
-
_input_chunk.dtype
|
| 131 |
-
), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
|
| 132 |
-
mat2=_input_chunk,
|
| 133 |
-
out=grad_weight,
|
| 134 |
-
alpha=1.0,
|
| 135 |
-
beta=1.0,
|
| 136 |
-
)
|
| 137 |
|
| 138 |
-
if
|
|
|
|
|
|
|
|
|
|
| 139 |
torch.add(
|
| 140 |
input=grad_bias,
|
| 141 |
-
other=
|
| 142 |
out=grad_bias,
|
| 143 |
alpha=1.0,
|
| 144 |
)
|
|
@@ -148,10 +224,24 @@ def fused_linear_cross_entropy_forward(
|
|
| 148 |
# loss = loss_1d
|
| 149 |
# z_loss = z_loss_1d if return_z_loss else None
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
else:
|
| 152 |
loss = torch.sum(loss_1d)
|
| 153 |
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
|
| 157 |
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
|
@@ -217,6 +307,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
| 217 |
reduction="mean",
|
| 218 |
softcap=None,
|
| 219 |
return_z_loss: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
):
|
| 221 |
"""
|
| 222 |
Fusing the last linear layer with cross-entropy loss
|
|
@@ -235,35 +329,54 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
| 235 |
ignore_index: the index to ignore in the target
|
| 236 |
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
| 237 |
reduction: reduction to apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
"""
|
| 239 |
|
| 240 |
-
loss, z_loss, grad_input, grad_weight, grad_bias =
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
)
|
| 253 |
# downcast to dtype and store for backward
|
| 254 |
ctx.save_for_backward(
|
| 255 |
grad_input.detach(),
|
| 256 |
grad_weight.detach() if grad_weight is not None else None,
|
| 257 |
-
grad_bias.detach() if
|
| 258 |
)
|
| 259 |
ctx.return_z_loss = return_z_loss
|
| 260 |
-
|
|
|
|
|
|
|
| 261 |
|
| 262 |
@staticmethod
|
| 263 |
@amp_custom_bwd
|
| 264 |
-
def backward(ctx, grad_output, grad_output2):
|
| 265 |
if ctx.return_z_loss:
|
| 266 |
del grad_output2 # z_loss is only for logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
| 268 |
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
| 269 |
grad_output, grad_input, grad_weight, grad_bias
|
|
@@ -280,4 +393,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
| 280 |
None,
|
| 281 |
None,
|
| 282 |
None,
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from .utils import amp_custom_fwd
|
| 7 |
from .utils import element_mul_kernel
|
| 8 |
from .utils import is_hip
|
| 9 |
+
from .utils import infer_device
|
| 10 |
|
| 11 |
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
| 12 |
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
| 13 |
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
| 14 |
+
MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
|
| 15 |
|
| 16 |
|
| 17 |
def fused_linear_cross_entropy_forward(
|
|
|
|
| 26 |
reduction="mean",
|
| 27 |
softcap=None,
|
| 28 |
return_z_loss=False,
|
| 29 |
+
accum_dtype=None,
|
| 30 |
+
use_token_scaling=False,
|
| 31 |
+
return_token_accuracy=False,
|
| 32 |
+
return_predicted_tokens=False,
|
| 33 |
):
|
| 34 |
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
| 35 |
+
assert isinstance(return_token_accuracy, bool), (
|
| 36 |
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
| 37 |
+
)
|
| 38 |
+
assert isinstance(return_predicted_tokens, bool), (
|
| 39 |
+
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
|
| 40 |
+
)
|
| 41 |
device = _input.device
|
| 42 |
|
| 43 |
+
input_requires_grad = _input.requires_grad
|
| 44 |
+
|
| 45 |
# inputs have shape: BT x H
|
| 46 |
# materialized activations will have shape: BT x V
|
| 47 |
# the increase in memory = BT x V
|
|
|
|
| 57 |
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
|
| 58 |
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
|
| 59 |
|
|
|
|
| 60 |
grad_input = torch.zeros_like(_input, device=device)
|
| 61 |
+
|
| 62 |
+
# we use fp32 for loss and gradients accumulator
|
| 63 |
+
if input_requires_grad:
|
| 64 |
+
if accum_dtype is None:
|
| 65 |
+
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
|
| 66 |
+
grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
|
| 67 |
+
else:
|
| 68 |
+
grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
|
| 69 |
+
grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
|
| 70 |
+
else:
|
| 71 |
+
grad_weight = None
|
| 72 |
+
grad_bias = None
|
| 73 |
+
|
| 74 |
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
| 75 |
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
| 76 |
+
token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
|
| 77 |
+
predicted_tokens_1d = torch.full((BT,), -1, dtype=torch.int64, device=device) if return_predicted_tokens else None
|
| 78 |
|
| 79 |
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
|
| 80 |
target_mask = target != ignore_index
|
|
|
|
| 107 |
|
| 108 |
n_rows = logits_chunk.shape[0]
|
| 109 |
|
| 110 |
+
# Compute predicted probabilities for token scaling if needed
|
| 111 |
+
if use_token_scaling:
|
| 112 |
+
# Compute softmax probabilities for scaling
|
| 113 |
+
# We need to compute this before the cross entropy kernel modifies logits_chunk
|
| 114 |
+
logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
|
| 115 |
+
if softcap is not None:
|
| 116 |
+
logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
|
| 117 |
+
|
| 118 |
+
# Compute softmax to get predicted probabilities
|
| 119 |
+
probs = torch.softmax(logits_for_softmax, dim=-1)
|
| 120 |
+
|
| 121 |
+
# Get predicted probabilities for token scaling, handling ignored targets
|
| 122 |
+
valid_target_mask = target_chunk != ignore_index
|
| 123 |
+
valid_targets = target_chunk[valid_target_mask]
|
| 124 |
+
|
| 125 |
+
if len(valid_targets) > 0:
|
| 126 |
+
# Gather probabilities only for valid targets
|
| 127 |
+
valid_probs = probs[valid_target_mask]
|
| 128 |
+
pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
|
| 129 |
+
|
| 130 |
+
# Create full tensor with zeros for ignored targets
|
| 131 |
+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
|
| 132 |
+
pred_probs[valid_target_mask] = pred_probs_valid
|
| 133 |
+
else:
|
| 134 |
+
# All targets are ignored
|
| 135 |
+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
|
| 136 |
+
|
| 137 |
+
# Store the scaling factors
|
| 138 |
+
scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
|
| 139 |
+
|
| 140 |
# unreduced loss
|
| 141 |
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
| 142 |
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
| 143 |
+
token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
|
| 144 |
+
predicted_tokens_1d_slice = predicted_tokens_1d[start_idx:end_idx] if return_predicted_tokens else None
|
| 145 |
|
| 146 |
# ensure _input and target are contiguous
|
| 147 |
logits_chunk = logits_chunk.contiguous()
|
|
|
|
| 157 |
loss_ptr=loss_1d_slice,
|
| 158 |
z_loss_ptr=z_loss_1d_slice,
|
| 159 |
loss_stride=loss_1d_slice.stride(-1), # always 1
|
| 160 |
+
token_accuracy_ptr=token_accuracy_1d_slice,
|
| 161 |
+
token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
|
| 162 |
+
if return_token_accuracy
|
| 163 |
+
else 0, # always 1 if accuracy is enabled
|
| 164 |
+
predicted_tokens_ptr=predicted_tokens_1d_slice,
|
| 165 |
+
predicted_tokens_stride=predicted_tokens_1d_slice.stride(-1)
|
| 166 |
+
if return_predicted_tokens
|
| 167 |
+
else 0, # always 1 if predicted tokens is enabled
|
| 168 |
n_cols=V,
|
| 169 |
n_non_ignore=total_n_non_ignore,
|
| 170 |
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
|
|
|
|
| 175 |
reduction=reduction,
|
| 176 |
softcap=softcap,
|
| 177 |
RETURN_Z_LOSS=return_z_loss,
|
| 178 |
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
| 179 |
+
RETURN_PREDICTED_TOKENS=return_predicted_tokens,
|
| 180 |
HAS_WEIGHT=True if ce_weight is not None else False,
|
| 181 |
HAS_SOFTCAPPING=True if softcap is not None else False,
|
| 182 |
+
HAS_GRADIENTS=input_requires_grad,
|
| 183 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 184 |
num_warps=32 if not is_hip() else 16,
|
| 185 |
)
|
| 186 |
|
| 187 |
+
# Apply token scaling if requested
|
| 188 |
+
if use_token_scaling:
|
| 189 |
+
loss_1d_slice = loss_1d_slice * scaling_factors
|
| 190 |
+
if return_z_loss:
|
| 191 |
+
z_loss_1d_slice = z_loss_1d_slice * scaling_factors
|
| 192 |
+
|
| 193 |
loss_1d[start_idx:end_idx] = loss_1d_slice
|
| 194 |
if return_z_loss:
|
| 195 |
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
| 196 |
+
if return_token_accuracy:
|
| 197 |
+
token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
|
| 198 |
+
if return_predicted_tokens:
|
| 199 |
+
predicted_tokens_1d[start_idx:end_idx] = predicted_tokens_1d_slice
|
| 200 |
grad_logits_chunk = logits_chunk # chunk_size x V
|
| 201 |
|
| 202 |
+
# Apply token scaling to gradients if requested
|
| 203 |
+
if use_token_scaling:
|
| 204 |
+
# Expand scaling factors to match gradient dimensions
|
| 205 |
+
scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
|
| 206 |
+
grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
|
| 207 |
|
| 208 |
+
if input_requires_grad:
|
| 209 |
+
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
if grad_weight is not None and input_requires_grad:
|
| 212 |
+
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
|
| 213 |
+
|
| 214 |
+
if bias is not None and input_requires_grad:
|
| 215 |
torch.add(
|
| 216 |
input=grad_bias,
|
| 217 |
+
other=grad_logits_chunk.sum(dim=0),
|
| 218 |
out=grad_bias,
|
| 219 |
alpha=1.0,
|
| 220 |
)
|
|
|
|
| 224 |
# loss = loss_1d
|
| 225 |
# z_loss = z_loss_1d if return_z_loss else None
|
| 226 |
|
| 227 |
+
if reduction == "none":
|
| 228 |
+
# Return per-token losses
|
| 229 |
+
loss = loss_1d
|
| 230 |
+
z_loss = z_loss_1d if return_z_loss else None
|
| 231 |
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
| 232 |
else:
|
| 233 |
loss = torch.sum(loss_1d)
|
| 234 |
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
| 235 |
+
# For accuracy, we compute the mean across all non-ignored tokens
|
| 236 |
+
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
|
| 237 |
+
|
| 238 |
+
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
|
| 239 |
+
|
| 240 |
+
# Cast back to original dtype
|
| 241 |
+
grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
|
| 242 |
+
grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
|
| 243 |
+
|
| 244 |
+
return loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias
|
| 245 |
|
| 246 |
|
| 247 |
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
|
|
|
| 307 |
reduction="mean",
|
| 308 |
softcap=None,
|
| 309 |
return_z_loss: bool = False,
|
| 310 |
+
accum_dtype=None,
|
| 311 |
+
use_token_scaling: bool = False,
|
| 312 |
+
return_token_accuracy: bool = False,
|
| 313 |
+
return_predicted_tokens: bool = False,
|
| 314 |
):
|
| 315 |
"""
|
| 316 |
Fusing the last linear layer with cross-entropy loss
|
|
|
|
| 329 |
ignore_index: the index to ignore in the target
|
| 330 |
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
| 331 |
reduction: reduction to apply
|
| 332 |
+
accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
|
| 333 |
+
Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
|
| 334 |
+
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
|
| 335 |
+
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
|
| 336 |
+
Default: False.
|
| 337 |
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
| 338 |
+
return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
|
| 339 |
"""
|
| 340 |
|
| 341 |
+
loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias = (
|
| 342 |
+
fused_linear_cross_entropy_forward(
|
| 343 |
+
_input=_input,
|
| 344 |
+
weight=weight,
|
| 345 |
+
target=target,
|
| 346 |
+
bias=bias,
|
| 347 |
+
ce_weight=ce_weight,
|
| 348 |
+
ignore_index=ignore_index,
|
| 349 |
+
lse_square_scale=lse_square_scale,
|
| 350 |
+
label_smoothing=label_smoothing,
|
| 351 |
+
reduction=reduction,
|
| 352 |
+
softcap=softcap,
|
| 353 |
+
return_z_loss=return_z_loss,
|
| 354 |
+
accum_dtype=accum_dtype,
|
| 355 |
+
use_token_scaling=use_token_scaling,
|
| 356 |
+
return_token_accuracy=return_token_accuracy,
|
| 357 |
+
return_predicted_tokens=return_predicted_tokens,
|
| 358 |
+
)
|
| 359 |
)
|
| 360 |
# downcast to dtype and store for backward
|
| 361 |
ctx.save_for_backward(
|
| 362 |
grad_input.detach(),
|
| 363 |
grad_weight.detach() if grad_weight is not None else None,
|
| 364 |
+
grad_bias.detach() if grad_bias is not None else None,
|
| 365 |
)
|
| 366 |
ctx.return_z_loss = return_z_loss
|
| 367 |
+
ctx.return_token_accuracy = return_token_accuracy
|
| 368 |
+
ctx.return_predicted_tokens = return_predicted_tokens
|
| 369 |
+
return loss, z_loss, token_accuracy, predicted_tokens
|
| 370 |
|
| 371 |
@staticmethod
|
| 372 |
@amp_custom_bwd
|
| 373 |
+
def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
|
| 374 |
if ctx.return_z_loss:
|
| 375 |
del grad_output2 # z_loss is only for logging
|
| 376 |
+
if ctx.return_token_accuracy:
|
| 377 |
+
del grad_output3 # token_accuracy is only for metrics
|
| 378 |
+
if ctx.return_predicted_tokens:
|
| 379 |
+
del grad_output4 # predicted_tokens is only for metrics
|
| 380 |
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
| 381 |
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
| 382 |
grad_output, grad_input, grad_weight, grad_bias
|
|
|
|
| 393 |
None,
|
| 394 |
None,
|
| 395 |
None,
|
| 396 |
+
None,
|
| 397 |
+
None, # use_token_scaling
|
| 398 |
+
None, # return_token_accuracy
|
| 399 |
+
None, # return_predicted_tokens
|
| 400 |
+
)
|
build/torch-cuda/geglu.py
CHANGED
|
@@ -7,8 +7,9 @@ import triton.language as tl
|
|
| 7 |
from .utils import calculate_settings
|
| 8 |
from .utils import compare_version
|
| 9 |
from .utils import ensure_contiguous
|
|
|
|
| 10 |
|
| 11 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 12 |
try:
|
| 13 |
# typical import path with dispatch available
|
| 14 |
from triton.language.extra.libdevice import tanh
|
|
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
|
|
| 40 |
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
| 41 |
tanh_result = tanh(tanh_arg)
|
| 42 |
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
| 43 |
-
c_row = geglu_a * b_row
|
| 44 |
tl.store(c + col_offsets, c_row, mask=mask)
|
| 45 |
|
| 46 |
|
|
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
| 66 |
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
| 67 |
tanh_result = tanh(tanh_arg)
|
| 68 |
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
|
|
| 69 |
|
| 70 |
-
db_row = dc_row * geglu_a
|
| 71 |
|
| 72 |
# Gradient w.r.t. a can be computed with:
|
| 73 |
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
|
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
| 78 |
da_row = dc_row * b_row * (term1 + term2)
|
| 79 |
|
| 80 |
tl.store(a + col_offsets, da_row, mask=mask)
|
| 81 |
-
tl.store(b + col_offsets, db_row, mask=mask)
|
| 82 |
|
| 83 |
|
| 84 |
def geglu_forward(a, b):
|
|
@@ -138,4 +140,4 @@ class LigerGELUMulFunction(torch.autograd.Function):
|
|
| 138 |
def backward(ctx, dc):
|
| 139 |
a, b = ctx.saved_tensors
|
| 140 |
a, b = geglu_backward(a, b, dc)
|
| 141 |
-
return a, b
|
|
|
|
| 7 |
from .utils import calculate_settings
|
| 8 |
from .utils import compare_version
|
| 9 |
from .utils import ensure_contiguous
|
| 10 |
+
from .utils import is_npu_available
|
| 11 |
|
| 12 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 13 |
try:
|
| 14 |
# typical import path with dispatch available
|
| 15 |
from triton.language.extra.libdevice import tanh
|
|
|
|
| 41 |
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
| 42 |
tanh_result = tanh(tanh_arg)
|
| 43 |
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
| 44 |
+
c_row = geglu_a.cast(b_row.dtype) * b_row
|
| 45 |
tl.store(c + col_offsets, c_row, mask=mask)
|
| 46 |
|
| 47 |
|
|
|
|
| 67 |
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
| 68 |
tanh_result = tanh(tanh_arg)
|
| 69 |
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
| 70 |
+
geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
|
| 71 |
|
| 72 |
+
db_row = dc_row.cast(tl.float32) * geglu_a
|
| 73 |
|
| 74 |
# Gradient w.r.t. a can be computed with:
|
| 75 |
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
|
|
|
| 80 |
da_row = dc_row * b_row * (term1 + term2)
|
| 81 |
|
| 82 |
tl.store(a + col_offsets, da_row, mask=mask)
|
| 83 |
+
tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
|
| 84 |
|
| 85 |
|
| 86 |
def geglu_forward(a, b):
|
|
|
|
| 140 |
def backward(ctx, dc):
|
| 141 |
a, b = ctx.saved_tensors
|
| 142 |
a, b = geglu_backward(a, b, dc)
|
| 143 |
+
return a, b
|
build/torch-cuda/group_norm.py
CHANGED
|
@@ -6,8 +6,10 @@ import triton.language as tl
|
|
| 6 |
|
| 7 |
from .utils import compare_version
|
| 8 |
from .utils import ensure_contiguous
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 11 |
try:
|
| 12 |
# typical import path with dispatch available
|
| 13 |
from triton.language.extra.libdevice import rsqrt
|
|
@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
|
|
| 17 |
else:
|
| 18 |
from triton.language.math import rsqrt
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
@triton.jit
|
|
@@ -72,20 +77,21 @@ def _group_norm_forward_kernel(
|
|
| 72 |
# 1/std
|
| 73 |
rstd = rsqrt(variance + eps)
|
| 74 |
|
| 75 |
-
# Normalize
|
|
|
|
|
|
|
| 76 |
hidden_size_per_channel = hidden_size // channels_per_group
|
| 77 |
-
for
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
Y_ptr += hidden_size_per_channel
|
| 89 |
|
| 90 |
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
| 91 |
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
|
@@ -302,4 +308,4 @@ class LigerGroupNormFunction(torch.autograd.Function):
|
|
| 302 |
def backward(ctx, dY):
|
| 303 |
X, W, B, Mean, RSTD = ctx.saved_tensors
|
| 304 |
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
|
| 305 |
-
return DX, DW, DB, None, None, None
|
|
|
|
| 6 |
|
| 7 |
from .utils import compare_version
|
| 8 |
from .utils import ensure_contiguous
|
| 9 |
+
from .utils import infer_device
|
| 10 |
+
from .utils import is_npu_available
|
| 11 |
|
| 12 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 13 |
try:
|
| 14 |
# typical import path with dispatch available
|
| 15 |
from triton.language.extra.libdevice import rsqrt
|
|
|
|
| 19 |
else:
|
| 20 |
from triton.language.math import rsqrt
|
| 21 |
|
| 22 |
+
if infer_device() == "npu":
|
| 23 |
+
MAX_FUSED_SIZE = 16384 # 8192
|
| 24 |
+
else:
|
| 25 |
+
MAX_FUSED_SIZE = 65536
|
| 26 |
|
| 27 |
|
| 28 |
@triton.jit
|
|
|
|
| 77 |
# 1/std
|
| 78 |
rstd = rsqrt(variance + eps)
|
| 79 |
|
| 80 |
+
# Normalize — flat loop over full hidden_size (not per-channel)
|
| 81 |
+
# This avoids the nested channel × per_channel_hidden loop where
|
| 82 |
+
# BLOCK_SIZE >> hidden_size_per_channel causes massive padding waste.
|
| 83 |
hidden_size_per_channel = hidden_size // channels_per_group
|
| 84 |
+
for i in tl.range(0, hidden_size, BLOCK_SIZE):
|
| 85 |
+
hidden_size_offsets = i + block_range
|
| 86 |
+
mask = hidden_size_offsets < hidden_size
|
| 87 |
+
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
| 88 |
+
# Determine which channel each element belongs to, then load W/B
|
| 89 |
+
local_channel = hidden_size_offsets // hidden_size_per_channel
|
| 90 |
+
global_channel = group_idx * channels_per_group + local_channel
|
| 91 |
+
W = tl.load(W_ptr + global_channel, mask=mask)
|
| 92 |
+
B = tl.load(B_ptr + global_channel, mask=mask)
|
| 93 |
+
Y = (X - m) * rstd * W + B
|
| 94 |
+
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
|
|
|
| 95 |
|
| 96 |
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
| 97 |
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
|
|
|
| 308 |
def backward(ctx, dY):
|
| 309 |
X, W, B, Mean, RSTD = ctx.saved_tensors
|
| 310 |
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
|
| 311 |
+
return DX, DW, DB, None, None, None
|
build/torch-cuda/jsd.py
CHANGED
|
@@ -198,4 +198,4 @@ class LigerJSDFunction(torch.autograd.Function):
|
|
| 198 |
None,
|
| 199 |
None,
|
| 200 |
None,
|
| 201 |
-
)
|
|
|
|
| 198 |
None,
|
| 199 |
None,
|
| 200 |
None,
|
| 201 |
+
)
|
build/torch-cuda/kl_div.py
CHANGED
|
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
|
|
| 21 |
return num_warps
|
| 22 |
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
| 27 |
|
|
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
|
|
| 116 |
|
| 117 |
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
| 118 |
BT, V = y_pred.shape
|
| 119 |
-
BLOCK_SIZE = (
|
| 120 |
-
min(8192, triton.next_power_of_2(V))
|
| 121 |
-
if infer_device() == "xpu"
|
| 122 |
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
| 123 |
-
)
|
| 124 |
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
| 125 |
|
| 126 |
grid = (BT,)
|
|
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
| 159 |
|
| 160 |
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
| 161 |
BT, V = target.shape
|
| 162 |
-
BLOCK_SIZE = (
|
| 163 |
-
min(8192, triton.next_power_of_2(V))
|
| 164 |
-
if infer_device() == "xpu"
|
| 165 |
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
| 166 |
-
)
|
| 167 |
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
| 168 |
|
| 169 |
grid = (BT,)
|
|
@@ -259,4 +256,4 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
| 259 |
None,
|
| 260 |
None,
|
| 261 |
None,
|
| 262 |
-
)
|
|
|
|
| 21 |
return num_warps
|
| 22 |
|
| 23 |
|
| 24 |
+
if infer_device() == "xpu":
|
| 25 |
+
MAX_FUSED_SIZE = 8192
|
| 26 |
+
elif infer_device() == "npu":
|
| 27 |
+
MAX_FUSED_SIZE = 8192
|
| 28 |
+
else:
|
| 29 |
+
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
| 30 |
|
| 31 |
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
| 32 |
|
|
|
|
| 121 |
|
| 122 |
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
| 123 |
BT, V = y_pred.shape
|
| 124 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
| 126 |
|
| 127 |
grid = (BT,)
|
|
|
|
| 160 |
|
| 161 |
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
| 162 |
BT, V = target.shape
|
| 163 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
| 165 |
|
| 166 |
grid = (BT,)
|
|
|
|
| 256 |
None,
|
| 257 |
None,
|
| 258 |
None,
|
| 259 |
+
)
|
build/torch-cuda/layer_norm.py
CHANGED
|
@@ -8,8 +8,11 @@ import triton.language as tl
|
|
| 8 |
from .utils import calculate_settings
|
| 9 |
from .utils import compare_version
|
| 10 |
from .utils import ensure_contiguous
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 13 |
try:
|
| 14 |
# typical import path with dispatch available
|
| 15 |
from triton.language.extra.libdevice import rsqrt
|
|
@@ -43,111 +46,151 @@ def _layer_norm_forward_kernel(
|
|
| 43 |
https://arxiv.org/abs/1607.06450
|
| 44 |
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
| 45 |
"""
|
| 46 |
-
row_idx = tl.program_id(0)
|
| 47 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 48 |
mask = col_offsets < n_cols
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
rstd = rsqrt(var + eps)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
tl.store(
|
|
|
|
| 66 |
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
@triton.jit
|
| 73 |
def _layer_norm_backward_kernel(
|
| 74 |
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
|
|
| 75 |
W_ptr, # pointer to weights, shape (n_cols,)
|
| 76 |
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
|
|
| 77 |
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
|
|
| 78 |
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
| 79 |
-
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
| 80 |
-
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
| 81 |
-
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
| 82 |
-
stride_x, # stride of each row in input
|
| 83 |
stride_dx, # stride of each row in input grad
|
|
|
|
| 84 |
stride_dw, # stride of each row in weights grad
|
|
|
|
| 85 |
stride_db, # stride of each row in bias grad
|
|
|
|
| 86 |
stride_dy, # stride of each row in output grad
|
| 87 |
n_rows,
|
| 88 |
n_cols,
|
| 89 |
rows_per_program: tl.constexpr,
|
| 90 |
BLOCK_SIZE: tl.constexpr,
|
| 91 |
-
dtype: tl.constexpr,
|
| 92 |
):
|
| 93 |
"""
|
| 94 |
References:
|
| 95 |
https://arxiv.org/abs/1607.06450
|
| 96 |
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
| 97 |
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
| 98 |
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
| 99 |
"""
|
| 100 |
-
row_block_id = tl.program_id(0)
|
| 101 |
row_start = row_block_id * rows_per_program
|
| 102 |
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
| 103 |
cols = tl.arange(0, BLOCK_SIZE)
|
| 104 |
mask = cols < n_cols
|
| 105 |
|
| 106 |
-
|
| 107 |
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
| 125 |
c2 = tl.sum(wdy, axis=0) / n_cols
|
| 126 |
-
dx = (wdy - (x_hat * c1 + c2)) *
|
| 127 |
-
tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
-
tl.store(DW_ptr + row_block_id * stride_dw + cols,
|
| 139 |
-
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row
|
| 140 |
|
| 141 |
|
| 142 |
def layer_norm_forward(X, W, B, eps):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
shape = X.shape
|
| 144 |
dim = shape[-1]
|
| 145 |
X = X.view(-1, dim)
|
| 146 |
n_rows, n_cols = X.shape
|
|
|
|
|
|
|
| 147 |
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
|
|
|
|
|
| 148 |
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
| 149 |
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
| 150 |
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
|
|
|
|
|
| 151 |
if X.shape[1] != W.shape[0]:
|
| 152 |
raise ValueError(
|
| 153 |
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
@@ -157,9 +200,11 @@ def layer_norm_forward(X, W, B, eps):
|
|
| 157 |
# XPU-specific optimization
|
| 158 |
kernel_args = {}
|
| 159 |
if X.device.type == "xpu":
|
| 160 |
-
kernel_args
|
| 161 |
|
| 162 |
-
|
|
|
|
|
|
|
| 163 |
Y,
|
| 164 |
Y.stride(0),
|
| 165 |
X,
|
|
@@ -176,12 +221,25 @@ def layer_norm_forward(X, W, B, eps):
|
|
| 176 |
eps,
|
| 177 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 178 |
num_warps=num_warps,
|
| 179 |
-
**kernel_args,
|
| 180 |
)
|
|
|
|
| 181 |
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
| 182 |
|
| 183 |
|
| 184 |
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
shape = dY.shape
|
| 186 |
dim = shape[-1]
|
| 187 |
dY = dY.view(-1, dim)
|
|
@@ -192,60 +250,57 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
| 192 |
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
| 193 |
elif X.device.type == "xpu":
|
| 194 |
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
_DW = torch.empty((sm_count, n_cols), dtype=
|
| 198 |
-
_DB = torch.empty((sm_count, n_cols), dtype=
|
| 199 |
|
|
|
|
| 200 |
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
| 201 |
if n_cols > BLOCK_SIZE:
|
| 202 |
-
raise RuntimeError(
|
| 203 |
-
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
| 204 |
-
)
|
| 205 |
-
|
| 206 |
rows_per_program = math.ceil(n_rows / sm_count)
|
| 207 |
grid = (sm_count,)
|
| 208 |
-
triton_dtype = (
|
| 209 |
-
tl.float32
|
| 210 |
-
if X.dtype == torch.float32
|
| 211 |
-
else tl.bfloat16
|
| 212 |
-
if X.dtype == torch.bfloat16
|
| 213 |
-
else tl.float16
|
| 214 |
-
if X.dtype == torch.float16
|
| 215 |
-
else tl.float32 # fallback to float32 for other types
|
| 216 |
-
)
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
# XPU-specific optimization
|
| 219 |
-
kernel_args = {}
|
| 220 |
if X.device.type == "xpu":
|
| 221 |
-
kernel_args.update({"
|
|
|
|
| 222 |
|
|
|
|
| 223 |
_layer_norm_backward_kernel[grid](
|
| 224 |
X,
|
|
|
|
| 225 |
W,
|
| 226 |
Mean,
|
|
|
|
| 227 |
RSTD,
|
|
|
|
| 228 |
DX,
|
| 229 |
-
_DW,
|
| 230 |
-
_DB,
|
| 231 |
-
dY,
|
| 232 |
-
X.stride(0),
|
| 233 |
DX.stride(0),
|
|
|
|
| 234 |
_DW.stride(0),
|
|
|
|
| 235 |
_DB.stride(0),
|
|
|
|
| 236 |
dY.stride(0),
|
| 237 |
n_rows,
|
| 238 |
n_cols,
|
| 239 |
-
rows_per_program,
|
| 240 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 241 |
-
|
| 242 |
-
**kernel_args, # XPU-specific optimization
|
| 243 |
)
|
| 244 |
|
|
|
|
| 245 |
DW = _DW.sum(dim=0).to(W.dtype)
|
| 246 |
-
DB = _DB.sum(dim=0).to(
|
| 247 |
|
| 248 |
-
DX = DX.view(*shape)
|
| 249 |
return DX, DW, DB
|
| 250 |
|
| 251 |
|
|
@@ -262,4 +317,4 @@ class LigerLayerNormFunction(torch.autograd.Function):
|
|
| 262 |
def backward(ctx, dY):
|
| 263 |
X, W, B, Mean, RSTD = ctx.saved_tensors
|
| 264 |
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
|
| 265 |
-
return DX, DW, DB, None
|
|
|
|
| 8 |
from .utils import calculate_settings
|
| 9 |
from .utils import compare_version
|
| 10 |
from .utils import ensure_contiguous
|
| 11 |
+
from .utils import get_npu_core_count
|
| 12 |
+
from .utils import set_large_grf_mode
|
| 13 |
+
from .utils import is_npu_available
|
| 14 |
|
| 15 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 16 |
try:
|
| 17 |
# typical import path with dispatch available
|
| 18 |
from triton.language.extra.libdevice import rsqrt
|
|
|
|
| 46 |
https://arxiv.org/abs/1607.06450
|
| 47 |
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
| 48 |
"""
|
| 49 |
+
row_idx = tl.program_id(0).to(tl.int64)
|
| 50 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 51 |
mask = col_offsets < n_cols
|
| 52 |
|
| 53 |
+
# Load weights and bias once, then up-cast them to fp32 for the affine step
|
| 54 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
| 55 |
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
|
| 56 |
+
W_f32 = W_row.to(tl.float32)
|
| 57 |
+
B_f32 = B_row.to(tl.float32)
|
| 58 |
+
|
| 59 |
+
# Calculate pointers for this row
|
| 60 |
+
row_X_ptr = X_ptr + row_idx * X_row_stride
|
| 61 |
+
row_Y_ptr = Y_ptr + row_idx * Y_row_stride
|
| 62 |
+
row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
|
| 63 |
+
row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
|
| 64 |
+
|
| 65 |
+
# Load input data and convert to fp32 for numerical stability
|
| 66 |
+
X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
|
| 67 |
+
X_f32 = X_row.to(tl.float32)
|
| 68 |
+
|
| 69 |
+
# Compute statistics in fp32 for numerical stability
|
| 70 |
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
| 71 |
+
X_centered = X_f32 - mean
|
| 72 |
+
# Apply mask to variance calculation to exclude contributions from masked elements
|
| 73 |
+
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
| 74 |
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
| 75 |
rstd = rsqrt(var + eps)
|
| 76 |
|
| 77 |
+
# Store statistics (convert back to original dtype only once)
|
| 78 |
+
tl.store(row_Mean_ptr, mean.to(X_row.dtype))
|
| 79 |
+
tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
|
| 80 |
|
| 81 |
+
# Fused normalization and affine transformation
|
| 82 |
+
# Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
|
| 83 |
+
Y_f32 = X_centered * rstd * W_f32 + B_f32
|
| 84 |
|
| 85 |
+
# Store output (single conversion back to original dtype)
|
| 86 |
+
tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
|
| 87 |
|
| 88 |
|
| 89 |
@triton.jit
|
| 90 |
def _layer_norm_backward_kernel(
|
| 91 |
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
| 92 |
+
stride_x, # stride of each row in input
|
| 93 |
W_ptr, # pointer to weights, shape (n_cols,)
|
| 94 |
Mean_ptr, # pointer to mean, shape (n_rows,)
|
| 95 |
+
stride_mean, # stride of each row in mean
|
| 96 |
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
| 97 |
+
stride_rstd, # stride of each row in rstd
|
| 98 |
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
stride_dx, # stride of each row in input grad
|
| 100 |
+
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
| 101 |
stride_dw, # stride of each row in weights grad
|
| 102 |
+
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
| 103 |
stride_db, # stride of each row in bias grad
|
| 104 |
+
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
| 105 |
stride_dy, # stride of each row in output grad
|
| 106 |
n_rows,
|
| 107 |
n_cols,
|
| 108 |
rows_per_program: tl.constexpr,
|
| 109 |
BLOCK_SIZE: tl.constexpr,
|
|
|
|
| 110 |
):
|
| 111 |
"""
|
| 112 |
References:
|
| 113 |
https://arxiv.org/abs/1607.06450
|
| 114 |
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
|
|
|
|
|
| 115 |
"""
|
| 116 |
+
row_block_id = tl.program_id(0).to(tl.int64)
|
| 117 |
row_start = row_block_id * rows_per_program
|
| 118 |
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
| 119 |
cols = tl.arange(0, BLOCK_SIZE)
|
| 120 |
mask = cols < n_cols
|
| 121 |
|
| 122 |
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
| 123 |
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
| 124 |
|
| 125 |
+
# Pre-load weights once (same optimization as forward pass)
|
| 126 |
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
| 127 |
+
w_f32 = w.to(tl.float32)
|
| 128 |
+
|
| 129 |
+
for row_idx in range(row_start, row_end):
|
| 130 |
+
# Calculate pointers for this specific row
|
| 131 |
+
row_X_ptr = X_ptr + row_idx * stride_x
|
| 132 |
+
row_DX_ptr = DX_ptr + row_idx * stride_dx
|
| 133 |
+
row_DY_ptr = DY_ptr + row_idx * stride_dy
|
| 134 |
+
row_Mean_ptr = Mean_ptr + row_idx * stride_mean
|
| 135 |
+
row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
|
| 136 |
+
|
| 137 |
+
# Load data for this row
|
| 138 |
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
| 139 |
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
| 140 |
+
mean = tl.load(row_Mean_ptr)
|
| 141 |
+
rstd = tl.load(row_RSTD_ptr)
|
| 142 |
+
|
| 143 |
+
# Convert to fp32 for numerical stability
|
| 144 |
+
x_f32 = x.to(tl.float32)
|
| 145 |
+
dy_f32 = dy.to(tl.float32)
|
| 146 |
+
mean_f32 = mean.to(tl.float32)
|
| 147 |
+
rstd_f32 = rstd.to(tl.float32)
|
| 148 |
+
|
| 149 |
+
# Compute backward pass for this row
|
| 150 |
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
| 151 |
+
wdy = w_f32 * dy_f32
|
| 152 |
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
| 153 |
c2 = tl.sum(wdy, axis=0) / n_cols
|
| 154 |
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
|
|
| 155 |
|
| 156 |
+
# Store input gradient
|
| 157 |
+
tl.store(row_DX_ptr + cols, dx, mask=mask)
|
| 158 |
|
| 159 |
+
# Accumulate weight and bias gradients for this thread block's assigned rows
|
| 160 |
+
dw = dy_f32 * x_hat
|
| 161 |
+
db = dy_f32
|
| 162 |
+
dW_row += dw
|
| 163 |
+
db_row += db
|
| 164 |
|
| 165 |
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
| 166 |
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
| 167 |
|
| 168 |
|
| 169 |
def layer_norm_forward(X, W, B, eps):
|
| 170 |
+
"""
|
| 171 |
+
Args:
|
| 172 |
+
X: Input tensor of shape (..., hidden_size)
|
| 173 |
+
W: Weight tensor of shape (hidden_size,)
|
| 174 |
+
B: Bias tensor of shape (hidden_size,)
|
| 175 |
+
eps: Small constant for numerical stability
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Tuple of (output, input, mean, rstd, block_size, num_warps)
|
| 179 |
+
"""
|
| 180 |
shape = X.shape
|
| 181 |
dim = shape[-1]
|
| 182 |
X = X.view(-1, dim)
|
| 183 |
n_rows, n_cols = X.shape
|
| 184 |
+
|
| 185 |
+
# Calculate optimal block size and warp configuration
|
| 186 |
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
| 187 |
+
|
| 188 |
+
# Allocate output tensors
|
| 189 |
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
| 190 |
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
| 191 |
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
| 192 |
+
|
| 193 |
+
# Validate input dimensions
|
| 194 |
if X.shape[1] != W.shape[0]:
|
| 195 |
raise ValueError(
|
| 196 |
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
|
|
| 200 |
# XPU-specific optimization
|
| 201 |
kernel_args = {}
|
| 202 |
if X.device.type == "xpu":
|
| 203 |
+
set_large_grf_mode(kernel_args)
|
| 204 |
|
| 205 |
+
# Launch kernel with one thread block per row for optimal performance
|
| 206 |
+
grid = (n_rows,)
|
| 207 |
+
_layer_norm_forward_kernel[grid](
|
| 208 |
Y,
|
| 209 |
Y.stride(0),
|
| 210 |
X,
|
|
|
|
| 221 |
eps,
|
| 222 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 223 |
num_warps=num_warps,
|
| 224 |
+
**kernel_args,
|
| 225 |
)
|
| 226 |
+
|
| 227 |
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
| 228 |
|
| 229 |
|
| 230 |
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
| 231 |
+
"""
|
| 232 |
+
Args:
|
| 233 |
+
dY: Gradient of output
|
| 234 |
+
X: Input tensor
|
| 235 |
+
W: Weight tensor
|
| 236 |
+
B: Bias tensor
|
| 237 |
+
Mean: Pre-computed mean
|
| 238 |
+
RSTD: Pre-computed reciprocal standard deviation
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Tuple of (input_grad, weight_grad, bias_grad)
|
| 242 |
+
"""
|
| 243 |
shape = dY.shape
|
| 244 |
dim = shape[-1]
|
| 245 |
dY = dY.view(-1, dim)
|
|
|
|
| 250 |
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
| 251 |
elif X.device.type == "xpu":
|
| 252 |
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
| 253 |
+
elif X.device.type == "npu":
|
| 254 |
+
sm_count = get_npu_core_count()
|
| 255 |
|
| 256 |
+
# Per-program partial dW/dB buffers are kept in fp32 for numerical stability; they are summed over dim 0 below.
|
| 257 |
+
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
| 258 |
+
_DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
| 259 |
|
| 260 |
+
# Calculate optimal block size and warp configuration
|
| 261 |
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
| 262 |
if n_cols > BLOCK_SIZE:
|
| 263 |
+
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
|
|
|
|
|
|
|
|
| 264 |
rows_per_program = math.ceil(n_rows / sm_count)
|
| 265 |
grid = (sm_count,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
# Allocate gradient tensors
|
| 268 |
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
| 269 |
+
|
| 270 |
+
kernel_args = {"num_warps": num_warps}
|
| 271 |
# XPU-specific optimization
|
|
|
|
| 272 |
if X.device.type == "xpu":
|
| 273 |
+
kernel_args.update({"num_warps": 32, "num_stages": 4})
|
| 274 |
+
set_large_grf_mode(kernel_args)
|
| 275 |
|
| 276 |
+
# Launch one program per SM/core (grid = sm_count); each program loops over up to `rows_per_program` rows
|
| 277 |
_layer_norm_backward_kernel[grid](
|
| 278 |
X,
|
| 279 |
+
X.stride(0),
|
| 280 |
W,
|
| 281 |
Mean,
|
| 282 |
+
Mean.stride(0),
|
| 283 |
RSTD,
|
| 284 |
+
RSTD.stride(0),
|
| 285 |
DX,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
DX.stride(0),
|
| 287 |
+
_DW,
|
| 288 |
_DW.stride(0),
|
| 289 |
+
_DB,
|
| 290 |
_DB.stride(0),
|
| 291 |
+
dY,
|
| 292 |
dY.stride(0),
|
| 293 |
n_rows,
|
| 294 |
n_cols,
|
| 295 |
+
rows_per_program=rows_per_program,
|
| 296 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 297 |
+
**kernel_args,
|
|
|
|
| 298 |
)
|
| 299 |
|
| 300 |
+
DX = DX.view(*shape)
|
| 301 |
DW = _DW.sum(dim=0).to(W.dtype)
|
| 302 |
+
DB = _DB.sum(dim=0).to(B.dtype)
|
| 303 |
|
|
|
|
| 304 |
return DX, DW, DB
|
| 305 |
|
| 306 |
|
|
|
|
| 317 |
def backward(ctx, dY):
|
| 318 |
X, W, B, Mean, RSTD = ctx.saved_tensors
|
| 319 |
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
|
| 320 |
+
return DX, DW, DB, None
|
build/torch-cuda/layers.py
CHANGED
|
@@ -1,39 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from .rms_norm import LigerRMSNormFunction
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
class LigerRMSNorm(
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
return LigerRMSNormFunction.apply(
|
| 31 |
-
hidden_states,
|
| 32 |
-
self.weight,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
self.variance_epsilon,
|
| 34 |
-
0,
|
| 35 |
-
"llama",
|
| 36 |
-
True
|
| 37 |
)
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from .cross_entropy import LigerCrossEntropyFunction
|
| 9 |
+
from .dyt import LigerDyTFunction
|
| 10 |
+
from .fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
| 11 |
+
from .geglu import LigerGELUMulFunction
|
| 12 |
+
from .group_norm import LigerGroupNormFunction
|
| 13 |
+
from .jsd import LigerJSDFunction
|
| 14 |
+
from .kl_div import LigerKLDivLossFunction
|
| 15 |
+
from .layer_norm import LigerLayerNormFunction
|
| 16 |
+
from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
| 17 |
from .rms_norm import LigerRMSNormFunction
|
| 18 |
+
from .rope import LigerRopeFunction
|
| 19 |
+
from .swiglu import LigerSiLUMulFunction
|
| 20 |
+
from .tvd import LigerTVDLossFunction
|
| 21 |
+
|
| 22 |
|
| 23 |
+
class LigerRMSNorm(nn.Module):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
hidden_size: int,
|
| 27 |
+
eps: float = 1e-6,
|
| 28 |
+
offset: float = 0.0,
|
| 29 |
+
casting_mode: str = "llama",
|
| 30 |
+
init_fn: str = "ones",
|
| 31 |
+
in_place: bool = True,
|
| 32 |
+
row_mode: Optional[bool] = None,
|
| 33 |
+
elementwise_affine: bool = True,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"
|
| 37 |
+
self.hidden_size = hidden_size
|
| 38 |
+
self.variance_epsilon = eps
|
| 39 |
+
self.offset = offset
|
| 40 |
+
self.casting_mode = casting_mode
|
| 41 |
+
self.in_place = in_place
|
| 42 |
+
self.row_mode = row_mode
|
| 43 |
+
self.elementwise_affine = elementwise_affine
|
| 44 |
+
if elementwise_affine:
|
| 45 |
+
init = torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
|
| 46 |
+
self.weight = nn.Parameter(init)
|
| 47 |
+
else:
|
| 48 |
+
self.register_parameter("weight", None)
|
| 49 |
+
|
| 50 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 51 |
return LigerRMSNormFunction.apply(
|
| 52 |
+
hidden_states,
|
| 53 |
+
self.weight,
|
| 54 |
+
self.variance_epsilon,
|
| 55 |
+
self.offset,
|
| 56 |
+
self.casting_mode,
|
| 57 |
+
self.in_place,
|
| 58 |
+
self.row_mode,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def extra_repr(self) -> str:
|
| 62 |
+
return (
|
| 63 |
+
f"{self.hidden_size}, eps={self.variance_epsilon}, offset={self.offset}, "
|
| 64 |
+
f"in_place={self.in_place}, row_mode={self.row_mode}"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LigerLayerNorm(nn.Module):
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
hidden_size: int,
|
| 72 |
+
eps: float = 1e-6,
|
| 73 |
+
bias: bool = False,
|
| 74 |
+
init_fn: str = "ones",
|
| 75 |
+
):
|
| 76 |
+
super().__init__()
|
| 77 |
+
assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"
|
| 78 |
+
self.hidden_size = hidden_size
|
| 79 |
+
self.variance_epsilon = eps
|
| 80 |
+
self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
|
| 81 |
+
self.bias = nn.Parameter(torch.randn(hidden_size) if bias else torch.zeros(hidden_size))
|
| 82 |
+
|
| 83 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
return LigerLayerNormFunction.apply(hidden_states, self.weight, self.bias, self.variance_epsilon)
|
| 85 |
+
|
| 86 |
+
def extra_repr(self) -> str:
|
| 87 |
+
return f"{self.hidden_size}, eps={self.variance_epsilon}"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class LigerGroupNorm(nn.Module):
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
num_channels: int,
|
| 94 |
+
num_groups: int,
|
| 95 |
+
eps: float = 1e-6,
|
| 96 |
+
bias: bool = False,
|
| 97 |
+
init_fn: str = "ones",
|
| 98 |
+
):
|
| 99 |
+
super().__init__()
|
| 100 |
+
assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"
|
| 101 |
+
assert num_channels % num_groups == 0, (
|
| 102 |
+
f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
|
| 103 |
+
)
|
| 104 |
+
self.num_channels = num_channels
|
| 105 |
+
self.num_groups = num_groups
|
| 106 |
+
self.variance_epsilon = eps
|
| 107 |
+
self.weight = nn.Parameter(torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels))
|
| 108 |
+
self.bias = nn.Parameter(torch.randn(num_channels) if bias else torch.zeros(num_channels))
|
| 109 |
+
|
| 110 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 111 |
+
assert hidden_states.dim() >= 3, f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
|
| 112 |
+
assert hidden_states.size(1) == self.num_channels, (
|
| 113 |
+
f"Input must have {self.num_channels} channels, got {hidden_states.size(1)}"
|
| 114 |
+
)
|
| 115 |
+
return LigerGroupNormFunction.apply(
|
| 116 |
+
hidden_states,
|
| 117 |
+
self.weight,
|
| 118 |
+
self.bias,
|
| 119 |
+
self.num_channels,
|
| 120 |
+
self.num_groups,
|
| 121 |
self.variance_epsilon,
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
+
|
| 124 |
+
def extra_repr(self) -> str:
|
| 125 |
+
return f"num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.variance_epsilon}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class LigerDyT(nn.Module):
    """DyT layer whose forward pass is delegated to ``LigerDyTFunction``.

    Parameters registered, in this order:
        alpha: one-element tensor initialised to ``init_alpha``.
        gamma: ``hidden_size`` ones.
        beta: ``hidden_size`` zeros, or ``None`` when ``beta=False``.
    """

    def __init__(self, hidden_size: int, beta: bool = True, init_alpha: float = 0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha

        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        # The shift term is optional; without it the attribute is a plain None.
        if beta:
            self.beta = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.beta = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``LigerDyTFunction`` to ``x`` with this module's parameters."""
        dyt_args = (x, self.alpha, self.gamma, self.beta)
        return LigerDyTFunction.apply(*dyt_args)

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        has_beta = self.beta is not None
        return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={has_beta}"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class LigerCrossEntropyLoss(nn.Module):
    """``nn.Module`` front-end for ``LigerCrossEntropyFunction``.

    The constructor validates and stores the loss configuration; ``forward``
    passes it to the autograd function and returns only the loss.

    Raises:
        AssertionError: If ``label_smoothing`` is outside ``[0, 1]``,
            ``reduction`` is not one of ``"mean"``/``"sum"``/``"none"``, or
            ``softcap`` is neither ``None`` nor strictly positive.
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        super().__init__()

        # Validate before storing anything on the module.
        assert 0.0 <= label_smoothing <= 1.0, f"label_smoothing must be in [0, 1], got {label_smoothing}"
        assert reduction in ("mean", "sum", "none"), f"reduction must be 'mean', 'sum', or 'none', got {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be > 0 or None, got {softcap}"

        self.weight = weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap

    def forward(self, _input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss for ``_input`` against ``target``.

        The autograd function returns four values; only the first (the loss)
        is returned to the caller.
        """
        ce_args = (
            _input,
            target,
            self.weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
            # NOTE(review): the three trailing flags are presumably the
            # "return extra output" switches -- confirm against
            # LigerCrossEntropyFunction.forward.
            False,
            False,
            False,
        )
        loss, _second, _third, _fourth = LigerCrossEntropyFunction.apply(*ce_args)
        return loss
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class LigerFusedLinearCrossEntropyLoss(nn.Module):
    """``nn.Module`` front-end for ``LigerFusedLinearCrossEntropyFunction``.

    The constructor validates and stores the loss configuration; ``forward``
    takes the linear weight, the input and the target, and returns only the
    loss.

    Raises:
        AssertionError: If ``label_smoothing`` is outside ``[0, 1]``,
            ``reduction`` is not one of ``"mean"``/``"sum"``/``"none"``, or
            ``softcap`` is neither ``None`` nor strictly positive.
    """

    def __init__(
        self,
        ce_weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        accum_dtype: Optional[torch.dtype] = None,
        use_token_scaling: bool = False,
    ):
        super().__init__()

        # Validate before storing anything on the module.
        assert 0.0 <= label_smoothing <= 1.0, f"label_smoothing must be in [0, 1], got {label_smoothing}"
        assert reduction in ("mean", "sum", "none"), f"reduction must be 'mean', 'sum', or 'none', got {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be > 0 or None, got {softcap}"

        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap
        self.accum_dtype = accum_dtype
        self.use_token_scaling = use_token_scaling

    def forward(
        self,
        lin_weight: torch.Tensor,
        _input: torch.Tensor,
        target: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the fused linear + cross-entropy loss.

        Note the argument order: this method takes ``lin_weight`` first, while
        the autograd function receives ``_input`` first.
        """
        fused_args = (
            _input,
            lin_weight,
            target,
            bias,
            self.ce_weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
            False,  # same slot as return_z_loss in liger_fused_linear_cross_entropy
            self.accum_dtype,
            self.use_token_scaling,
            False,  # same slot as return_token_accuracy
            False,  # same slot as return_predicted_tokens
        )
        loss, _z_loss, _token_accuracy, _predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(*fused_args)
        return loss
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class LigerJSD(nn.Module):
    """``nn.Module`` front-end for ``LigerJSDFunction``.

    Stores ``beta`` and ``ignore_index`` and passes them to the autograd
    function on every call.
    """

    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
        super().__init__()
        self.beta = beta
        self.ignore_index = ignore_index

    def forward(
        self,
        log_q: torch.Tensor,
        log_p: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Evaluate ``LigerJSDFunction`` on the two inputs.

        ``shift_labels`` is optional and forwarded as given.
        NOTE(review): ``log_q`` / ``log_p`` are presumably log-probabilities --
        confirm against ``LigerJSDFunction``.
        """
        jsd_args = (log_q, log_p, shift_labels, self.beta, self.ignore_index)
        return LigerJSDFunction.apply(*jsd_args)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class LigerKLDIVLoss(nn.KLDivLoss):
    """``nn.KLDivLoss`` subclass that evaluates ``LigerKLDivLossFunction``.

    ``eps`` is the first positional argument; every remaining positional or
    keyword argument goes to ``nn.KLDivLoss.__init__``.
    """

    def __init__(self, eps: float = 1e-10, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Evaluate the loss using the inherited ``reduction`` / ``log_target`` and ``eps``."""
        kl_args = (y_pred, y_true, self.reduction, self.log_target, self.eps)
        return LigerKLDivLossFunction.apply(*kl_args)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class LigerTVDLoss(nn.Module):
    """``nn.Module`` front-end for ``LigerTVDLossFunction``.

    Stores ``reduction`` and ``ignore_index`` and passes them to the autograd
    function on every call.
    """

    def __init__(self, reduction: str = "batchmean", ignore_index: int = -100):
        super().__init__()
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(
        self,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Evaluate ``LigerTVDLossFunction`` on ``p`` and ``q``.

        ``shift_labels`` is optional and forwarded as given.
        """
        tvd_args = (p, q, shift_labels, self.reduction, self.ignore_index)
        return LigerTVDLossFunction.apply(*tvd_args)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class LigerSwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block.

    ``config`` must provide ``hidden_size``, ``intermediate_size`` and
    ``hidden_act``; the latter has to be ``"silu"`` or ``"swish"``.

    Raises:
        ValueError: If ``config.hidden_act`` is any other activation name.
    """

    def __init__(self, config):
        super().__init__()
        activation = config.hidden_act
        if activation not in ("silu", "swish"):
            raise ValueError(f"Activation function {activation} not supported.")

        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Three bias-free projections, created in gate / up / down order.
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` twice, combine with ``LigerSiLUMulFunction``, project back."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up))
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class LigerGEGLUMLP(nn.Module):
    """GEGLU feed-forward block.

    ``config`` must provide ``hidden_size`` and ``intermediate_size``. The
    activation is the tanh approximation of GELU (matches Gemma 1/1.1/2).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Three bias-free projections, created in gate / up / down order.
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` twice, combine with ``LigerGELUMulFunction``, project back."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(LigerGELUMulFunction.apply(gate, up))
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@dataclass
class CrossEntropyOutput:
    """Result bundle returned by the cross-entropy helpers when extras are requested.

    Only ``loss`` is mandatory; each optional field stays ``None`` unless the
    producer fills it in.
    """

    # Always present.
    loss: torch.Tensor
    # Optional extras, populated on request by the producing function.
    z_loss: Optional[torch.Tensor] = None
    token_accuracy: Optional[torch.Tensor] = None
    predicted_tokens: Optional[torch.Tensor] = None
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def liger_fused_linear_cross_entropy(
    input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    ce_weight: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
    accum_dtype: Optional[torch.dtype] = None,
    use_token_scaling: bool = False,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
):
    """Functional entry point for ``LigerFusedLinearCrossEntropyFunction``.

    All arguments are forwarded positionally, in declaration order, to the
    autograd function.

    Returns:
        The loss alone when none of ``return_z_loss``, ``return_token_accuracy``
        or ``return_predicted_tokens`` is set; otherwise a
        :class:`CrossEntropyOutput` holding all four values returned by the
        autograd function.
    """
    fused_outputs = LigerFusedLinearCrossEntropyFunction.apply(
        input,
        weight,
        target,
        bias,
        ce_weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
        accum_dtype,
        use_token_scaling,
        return_token_accuracy,
        return_predicted_tokens,
    )
    loss, z_loss, token_accuracy, predicted_tokens = fused_outputs

    wants_extras = return_z_loss or return_token_accuracy or return_predicted_tokens
    if wants_extras:
        return CrossEntropyOutput(
            loss=loss,
            z_loss=z_loss,
            token_accuracy=token_accuracy,
            predicted_tokens=predicted_tokens,
        )
    return loss
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def LigerForCausalLMLoss(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    labels: torch.Tensor,
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    final_logit_softcapping: Optional[float] = None,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
    **kwargs,
):
    """Drop-in replacement for ``transformers.loss.ForCausalLMLoss`` that fuses the
    final ``lm_head`` projection with the cross-entropy loss.

    Args:
        hidden_states: Activations fed to the ``lm_head``; flattened to
            ``(-1, hidden_size)`` before the fused call.
        lm_head_weight: Weight of the ``lm_head`` projection.
        labels: Target ids. Used only when ``shift_labels`` is ``None``: they
            are padded with one ``ignore_index`` on the last dimension and
            shifted left by one position.
        hidden_size: Size of the last dimension of ``hidden_states``.
        num_items_in_batch: When given, the loss is computed with
            ``reduction="sum"`` and divided by this value; otherwise
            ``reduction="mean"`` is used. A tensor value is moved to the
            loss device before the division.
        ignore_index: Label value excluded from the loss.
        shift_labels: Already-shifted labels; skips the internal shift.
        final_logit_softcapping: Forwarded as ``softcap``.
        return_token_accuracy: Request token accuracy in the output.
        return_predicted_tokens: Request predicted tokens in the output.
        **kwargs: Extra keyword arguments; only those that name a parameter of
            ``liger_fused_linear_cross_entropy`` are forwarded, the rest are
            silently dropped.

    Returns:
        A scalar ``loss`` by default; a :class:`CrossEntropyOutput` when
        ``return_token_accuracy`` or ``return_predicted_tokens`` is set.
    """
    # Model forward passes hand over arbitrary kwargs; keep only the ones the
    # fused function actually understands.
    applicable_params = inspect.signature(liger_fused_linear_cross_entropy).parameters
    kwargs = {k: v for k, v in kwargs.items() if k in applicable_params}

    if shift_labels is None:
        # Pad one ignored position on the right, then drop the first token so
        # that position t is scored against token t + 1.
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # reshape (rather than view) also accepts non-contiguous inputs, e.g.
    # sliced hidden states or caller-supplied shift_labels.
    hidden_states = hidden_states.reshape(-1, hidden_size)
    shift_labels = shift_labels.reshape(-1).to(hidden_states.device)

    reduction = "sum" if num_items_in_batch is not None else "mean"
    result = liger_fused_linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        shift_labels,
        reduction=reduction,
        ignore_index=ignore_index,
        softcap=final_logit_softcapping,
        return_token_accuracy=return_token_accuracy,
        return_predicted_tokens=return_predicted_tokens,
        **kwargs,
    )

    if isinstance(result, CrossEntropyOutput):
        loss = result.loss
        token_accuracy = result.token_accuracy
        predicted_tokens = result.predicted_tokens
    else:
        loss = result
        token_accuracy = None
        predicted_tokens = None

    if reduction == "sum":
        # num_items_in_batch may arrive as a tensor living on another device
        # (multi-GPU / gradient accumulation); align it with the loss first.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch

    if return_token_accuracy or return_predicted_tokens:
        return CrossEntropyOutput(
            loss=loss,
            token_accuracy=token_accuracy,
            predicted_tokens=predicted_tokens,
        )
    return loss
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def liger_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply standard rotary positional embedding to ``q`` and ``k``.

    All arguments are handed to ``LigerRopeFunction`` in declaration order and
    its result is returned as is.
    """
    rope_args = (q, k, cos, sin, position_ids, unsqueeze_dim)
    return LigerRopeFunction.apply(*rope_args)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def liger_multimodal_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply Qwen2-VL multimodal rotary positional embedding (M-RoPE) to ``q`` and ``k``.

    All arguments are handed to ``LigerQwen2VLMRopeFunction`` in declaration
    order and its result is returned as is.
    """
    mrope_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*mrope_args)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
__all__ = [
|
| 447 |
+
"LigerRMSNorm",
|
| 448 |
+
"LigerLayerNorm",
|
| 449 |
+
"LigerGroupNorm",
|
| 450 |
+
"LigerDyT",
|
| 451 |
+
"LigerCrossEntropyLoss",
|
| 452 |
+
"LigerFusedLinearCrossEntropyLoss",
|
| 453 |
+
"LigerJSD",
|
| 454 |
+
"LigerKLDIVLoss",
|
| 455 |
+
"LigerTVDLoss",
|
| 456 |
+
"LigerSwiGLUMLP",
|
| 457 |
+
"LigerGEGLUMLP",
|
| 458 |
+
"CrossEntropyOutput",
|
| 459 |
+
"liger_fused_linear_cross_entropy",
|
| 460 |
+
"LigerForCausalLMLoss",
|
| 461 |
+
"liger_rotary_pos_emb",
|
| 462 |
+
"liger_multimodal_rotary_pos_emb",
|
| 463 |
+
]
|
build/torch-cuda/metadata.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"name": "liger-kernels",
|
| 3 |
-
"id": "
|
| 4 |
"version": 1,
|
| 5 |
"license": "BSD-2-Clause",
|
| 6 |
"python-depends": [],
|
|
|
|
| 1 |
{
|
| 2 |
"name": "liger-kernels",
|
| 3 |
+
"id": "_liger_kernels_cuda_08b4d53",
|
| 4 |
"version": 1,
|
| 5 |
"license": "BSD-2-Clause",
|
| 6 |
"python-depends": [],
|
build/torch-cuda/qwen2vl_mrope.py
CHANGED
|
@@ -219,4 +219,4 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
|
|
| 219 |
cos, sin = ctx.saved_tensors
|
| 220 |
mrope_section = ctx.mrope_section
|
| 221 |
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
|
| 222 |
-
return dq, dk, None, None, None, None
|
|
|
|
| 219 |
cos, sin = ctx.saved_tensors
|
| 220 |
mrope_section = ctx.mrope_section
|
| 221 |
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
|
| 222 |
+
return dq, dk, None, None, None, None
|
build/torch-cuda/rms_norm.py
CHANGED
|
@@ -20,9 +20,12 @@ import triton.language as tl
|
|
| 20 |
from .utils import calculate_settings
|
| 21 |
from .utils import compare_version
|
| 22 |
from .utils import ensure_contiguous
|
|
|
|
|
|
|
| 23 |
from .utils import torch_to_triton_dtype
|
|
|
|
| 24 |
|
| 25 |
-
if compare_version("triton", operator.ge, "3.0.0"):
|
| 26 |
try:
|
| 27 |
# typical import path with dispatch available
|
| 28 |
from triton.language.extra.libdevice import rsqrt
|
|
@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
|
|
| 52 |
eps,
|
| 53 |
offset,
|
| 54 |
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
|
|
|
| 55 |
BLOCK_SIZE: tl.constexpr,
|
| 56 |
):
|
| 57 |
"""
|
|
@@ -63,17 +67,18 @@ def _rms_norm_forward_kernel(
|
|
| 63 |
3. https://arxiv.org/pdf/1910.07467
|
| 64 |
"""
|
| 65 |
|
| 66 |
-
row_idx = tl.program_id(0)
|
| 67 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 68 |
mask = col_offsets < n_cols
|
| 69 |
|
| 70 |
-
Y_ptr +
|
| 71 |
-
X_ptr +
|
| 72 |
-
RSTD_ptr +
|
| 73 |
|
| 74 |
-
X_row = tl.load(
|
| 75 |
X_row_dtype = X_row.dtype
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
# On Llama, only rstd is computed on fp32
|
| 79 |
if casting_mode == _CASTING_MODE_LLAMA:
|
|
@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(
|
|
| 81 |
|
| 82 |
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
| 83 |
if casting_mode == _CASTING_MODE_GEMMA:
|
| 84 |
-
|
|
|
|
| 85 |
X_row = X_row.to(tl.float32)
|
| 86 |
|
| 87 |
if casting_mode == _CASTING_MODE_NONE:
|
|
@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
|
|
| 94 |
# We can save time by caching rms with minimal memory overhead
|
| 95 |
# because rms is much smaller compared to X_row, as rms is for each row.
|
| 96 |
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
| 97 |
-
tl.store(
|
| 98 |
|
| 99 |
X_row = X_row * rstd
|
| 100 |
|
|
@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
|
|
| 102 |
if casting_mode == _CASTING_MODE_LLAMA:
|
| 103 |
X_row = X_row.to(X_row_dtype)
|
| 104 |
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
if casting_mode == _CASTING_MODE_GEMMA:
|
| 108 |
Y_row = Y_row.to(X_row_dtype)
|
| 109 |
|
| 110 |
-
tl.store(
|
| 111 |
|
| 112 |
|
| 113 |
@triton.jit
|
|
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
|
|
| 128 |
n_rows,
|
| 129 |
n_cols,
|
| 130 |
offset,
|
| 131 |
-
rows_per_program
|
| 132 |
casting_mode: tl.constexpr,
|
|
|
|
| 133 |
BLOCK_SIZE: tl.constexpr,
|
| 134 |
):
|
| 135 |
"""
|
|
@@ -137,61 +147,256 @@ def _rms_norm_backward_kernel(
|
|
| 137 |
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
| 138 |
"""
|
| 139 |
|
| 140 |
-
row_block_id = tl.program_id(0)
|
| 141 |
row_start = row_block_id * rows_per_program
|
| 142 |
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
| 143 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 144 |
mask = col_offsets < n_cols
|
| 145 |
|
| 146 |
-
|
|
|
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
|
| 160 |
|
| 161 |
# Get cached rms
|
| 162 |
-
rstd_row = tl.load(
|
| 163 |
|
| 164 |
X_row = X_row.to(tl.float32)
|
| 165 |
|
| 166 |
# Different backward graphs for different casting modes
|
| 167 |
if casting_mode == _CASTING_MODE_LLAMA:
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
elif casting_mode == _CASTING_MODE_GEMMA:
|
| 171 |
dY_row = dY_row.to(tl.float32)
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
| 173 |
else:
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
dX_row = rstd_row * m
|
| 177 |
|
| 178 |
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
if casting_mode == _CASTING_MODE_LLAMA:
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
else:
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
|
|
|
|
|
|
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
_str_to_casting_mode = {
|
|
@@ -201,7 +406,7 @@ _str_to_casting_mode = {
|
|
| 201 |
}
|
| 202 |
|
| 203 |
|
| 204 |
-
def rms_norm_forward(X, W, eps, offset, casting_mode):
|
| 205 |
if not isinstance(casting_mode, int):
|
| 206 |
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
|
| 207 |
casting_mode = _str_to_casting_mode[casting_mode]
|
|
@@ -220,34 +425,64 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
|
|
| 220 |
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
|
| 221 |
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
# XPU-specific optimization
|
| 227 |
kernel_args = {}
|
| 228 |
if X.device.type == "xpu":
|
| 229 |
-
kernel_args
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
|
| 248 |
|
| 249 |
|
| 250 |
-
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
|
| 251 |
shape = dY.shape
|
| 252 |
dim = shape[-1]
|
| 253 |
dY = dY.view(-1, dim)
|
|
@@ -258,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
| 258 |
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
| 259 |
elif X.device.type == "xpu":
|
| 260 |
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
if n_cols > BLOCK_SIZE:
|
| 266 |
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
@@ -275,33 +517,65 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
| 275 |
# XPU-specific optimization
|
| 276 |
kernel_args = {}
|
| 277 |
if X.device.type == "xpu":
|
| 278 |
-
kernel_args
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
dX = dX.view(*shape)
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
return dX, dW
|
| 307 |
|
|
@@ -330,18 +604,30 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
| 330 |
|
| 331 |
@staticmethod
|
| 332 |
@ensure_contiguous
|
| 333 |
-
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
|
| 334 |
"""
|
| 335 |
X: (B, T, H) or (BxT, H)
|
| 336 |
W: (H,)
|
| 337 |
"""
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
ctx.offset = offset
|
| 340 |
ctx.casting_mode = casting_mode
|
| 341 |
ctx.in_place = in_place
|
|
|
|
| 342 |
ctx.BLOCK_SIZE = BLOCK_SIZE
|
| 343 |
ctx.num_warps = num_warps
|
| 344 |
-
ctx.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
return Y
|
| 346 |
|
| 347 |
@staticmethod
|
|
@@ -350,16 +636,19 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
| 350 |
"""
|
| 351 |
Y: (B, T, H) or (BxT, H)
|
| 352 |
"""
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
dX, dW = rms_norm_backward(
|
| 355 |
-
dY,
|
| 356 |
-
X,
|
| 357 |
-
W,
|
| 358 |
-
RSTD,
|
| 359 |
-
ctx.offset,
|
| 360 |
-
ctx.casting_mode,
|
| 361 |
-
ctx.BLOCK_SIZE,
|
| 362 |
-
ctx.num_warps,
|
| 363 |
-
ctx.in_place,
|
| 364 |
)
|
| 365 |
-
return dX, dW, None, None, None, None
|
|
|
|
| 20 |
from .utils import calculate_settings
|
| 21 |
from .utils import compare_version
|
| 22 |
from .utils import ensure_contiguous
|
| 23 |
+
from .utils import get_npu_core_count
|
| 24 |
+
from .utils import set_large_grf_mode
|
| 25 |
from .utils import torch_to_triton_dtype
|
| 26 |
+
from .utils import is_npu_available
|
| 27 |
|
| 28 |
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
| 29 |
try:
|
| 30 |
# typical import path with dispatch available
|
| 31 |
from triton.language.extra.libdevice import rsqrt
|
|
|
|
| 55 |
eps,
|
| 56 |
offset,
|
| 57 |
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
| 58 |
+
elementwise_affine: tl.constexpr,
|
| 59 |
BLOCK_SIZE: tl.constexpr,
|
| 60 |
):
|
| 61 |
"""
|
|
|
|
| 67 |
3. https://arxiv.org/pdf/1910.07467
|
| 68 |
"""
|
| 69 |
|
| 70 |
+
row_idx = tl.program_id(0).to(tl.int64)
|
| 71 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 72 |
mask = col_offsets < n_cols
|
| 73 |
|
| 74 |
+
y_base = Y_ptr + row_idx * Y_row_stride
|
| 75 |
+
x_base = X_ptr + row_idx * X_row_stride
|
| 76 |
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
| 77 |
|
| 78 |
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
|
| 79 |
X_row_dtype = X_row.dtype
|
| 80 |
+
if elementwise_affine:
|
| 81 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
| 82 |
|
| 83 |
# On Llama, only rstd is computed on fp32
|
| 84 |
if casting_mode == _CASTING_MODE_LLAMA:
|
|
|
|
| 86 |
|
| 87 |
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
| 88 |
if casting_mode == _CASTING_MODE_GEMMA:
|
| 89 |
+
if elementwise_affine:
|
| 90 |
+
W_row = W_row.to(tl.float32)
|
| 91 |
X_row = X_row.to(tl.float32)
|
| 92 |
|
| 93 |
if casting_mode == _CASTING_MODE_NONE:
|
|
|
|
| 100 |
# We can save time by caching rms with minimal memory overhead
|
| 101 |
# because rms is much smaller compared to X_row, as rms is for each row.
|
| 102 |
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
| 103 |
+
tl.store(rstd_base, rstd)
|
| 104 |
|
| 105 |
X_row = X_row * rstd
|
| 106 |
|
|
|
|
| 108 |
if casting_mode == _CASTING_MODE_LLAMA:
|
| 109 |
X_row = X_row.to(X_row_dtype)
|
| 110 |
|
| 111 |
+
if elementwise_affine:
|
| 112 |
+
Y_row = X_row * (offset + W_row)
|
| 113 |
+
else:
|
| 114 |
+
Y_row = X_row
|
| 115 |
|
| 116 |
if casting_mode == _CASTING_MODE_GEMMA:
|
| 117 |
Y_row = Y_row.to(X_row_dtype)
|
| 118 |
|
| 119 |
+
tl.store(y_base + col_offsets, Y_row, mask=mask)
|
| 120 |
|
| 121 |
|
| 122 |
@triton.jit
|
|
|
|
| 137 |
n_rows,
|
| 138 |
n_cols,
|
| 139 |
offset,
|
| 140 |
+
rows_per_program,
|
| 141 |
casting_mode: tl.constexpr,
|
| 142 |
+
elementwise_affine: tl.constexpr,
|
| 143 |
BLOCK_SIZE: tl.constexpr,
|
| 144 |
):
|
| 145 |
"""
|
|
|
|
| 147 |
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
| 148 |
"""
|
| 149 |
|
| 150 |
+
row_block_id = tl.program_id(0).to(tl.int64)
|
| 151 |
row_start = row_block_id * rows_per_program
|
| 152 |
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
| 153 |
col_offsets = tl.arange(0, BLOCK_SIZE)
|
| 154 |
mask = col_offsets < n_cols
|
| 155 |
|
| 156 |
+
if elementwise_affine:
|
| 157 |
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
| 158 |
|
| 159 |
+
if elementwise_affine:
|
| 160 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
| 161 |
+
W_row = W_row + offset
|
| 162 |
|
| 163 |
+
for row_idx in range(row_start, row_end):
|
| 164 |
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
| 165 |
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
| 166 |
|
| 167 |
+
x_base = X_ptr + row_idx * X_row_stride
|
| 168 |
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
| 169 |
|
| 170 |
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
|
| 171 |
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
|
|
|
|
| 172 |
|
| 173 |
# Get cached rms
|
| 174 |
+
rstd_row = tl.load(rstd_base)
|
| 175 |
|
| 176 |
X_row = X_row.to(tl.float32)
|
| 177 |
|
| 178 |
# Different backward graphs for different casting modes
|
| 179 |
if casting_mode == _CASTING_MODE_LLAMA:
|
| 180 |
+
if elementwise_affine:
|
| 181 |
+
m = (dY_row * W_row).to(tl.float32)
|
| 182 |
+
else:
|
| 183 |
+
m = dY_row.to(tl.float32)
|
| 184 |
|
| 185 |
elif casting_mode == _CASTING_MODE_GEMMA:
|
| 186 |
dY_row = dY_row.to(tl.float32)
|
| 187 |
+
if elementwise_affine:
|
| 188 |
+
m = dY_row * W_row
|
| 189 |
+
else:
|
| 190 |
+
m = dY_row
|
| 191 |
else:
|
| 192 |
+
if elementwise_affine:
|
| 193 |
+
m = dY_row * W_row
|
| 194 |
+
else:
|
| 195 |
+
m = dY_row
|
| 196 |
|
| 197 |
dX_row = rstd_row * m
|
| 198 |
|
| 199 |
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
| 200 |
|
| 201 |
+
if elementwise_affine:
|
| 202 |
+
# calculate the gradient of W
|
| 203 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
| 204 |
+
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
|
| 205 |
+
else:
|
| 206 |
+
# here X_row is already in fp32 (see previous if block)
|
| 207 |
+
dW_row += dY_row * (X_row * rstd_row)
|
| 208 |
+
|
| 209 |
+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
|
| 210 |
+
|
| 211 |
+
if elementwise_affine:
|
| 212 |
+
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@triton.jit
def _block_rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_rows,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    RMS-norm forward pass that handles BLOCK_ROW rows per program.

    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    Each program loads a (BLOCK_ROW, BLOCK_SIZE) tile of X, masked to the
    valid n_rows x n_cols region, stores one rstd per row into RSTD and the
    normalized (and, if `elementwise_affine`, scaled) tile into Y.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    # Cast the program id to int64 (as the other RMS-norm kernels in this file
    # do) so that `row_idx * row_stride` cannot overflow int32 on inputs with
    # more than 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_mask = row_idx < n_rows
    col_mask = col_offsets < n_cols

    X_row = tl.load(
        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0,
    )
    X_row_dtype = X_row.dtype
    if elementwise_affine:
        # The weight is shared by every row of the tile.
        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        if elementwise_affine:
            W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    # Per-row reduction over the column axis; masked lanes were loaded as 0.
    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
    rstd = rsqrt(mean_square + eps)

    # We can save time by caching rms with minimal memory overhead
    # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)

    X_row = X_row * rstd[:, None]

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    if elementwise_affine:
        Y_row = X_row * (offset + W_row)[None, :]
    else:
        Y_row = X_row

    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(
        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
        Y_row,
        mask=row_mask[:, None] & col_mask[None, :],
    )
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@triton.jit
def _block_rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    casting_mode: tl.constexpr,
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    RMSNorm backward pass that handles BLOCK_ROW rows per loop iteration.

    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].
    * means element-wise multiplication, whereas dot means dot product.
    dw = sum(dy * (x / RMS)), summed over the BxT dimension.

    Each program starts at row `pid * BLOCK_ROW` and advances by
    `num_programs * BLOCK_ROW` rows until `n_rows` is reached. When
    `elementwise_affine` is set, every program accumulates its own fp32
    partial dW and stores it in row `pid` of the buffer behind `dW_ptr`.
    """

    pid = tl.program_id(0).cast(tl.int64)
    NUM_SMS = tl.num_programs(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    if elementwise_affine:
        # fp32 accumulator for this program's share of dW
        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
        W_row = W_row + offset

    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
        row_idx = start + tl.arange(0, BLOCK_ROW)
        row_mask = row_idx < n_rows
        dY_row = tl.load(
            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        X_row = tl.load(
            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )

        # Get cached rms.
        # `other=0.0` gives the masked-out tail rows a defined value. Without it
        # the loaded value is undefined, and a NaN/Inf there would turn the
        # zero-filled dY/X contributions into NaN inside the dW reduction below.
        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
            if elementwise_affine:
                m = (dY_row * W_row[None, :]).to(tl.float32)
            else:
                m = dY_row.to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row
        else:
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row

        dX_row = rstd_row[:, None] * m

        dX_row += (rstd_row[:, None]) * (
            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
        )

        if elementwise_affine:
            if casting_mode == _CASTING_MODE_LLAMA:
                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
            else:
                # here X_row is already in fp32 (see the cast above)
                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

        tl.store(
            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
            dX_row,
            mask=row_mask[:, None] & col_mask[None, :],
        )

    if elementwise_affine:
        # One partial-dW row per program.
        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
|
| 400 |
|
| 401 |
|
| 402 |
_str_to_casting_mode = {
|
|
|
|
| 406 |
}
|
| 407 |
|
| 408 |
|
| 409 |
+
def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
|
| 410 |
if not isinstance(casting_mode, int):
|
| 411 |
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
|
| 412 |
casting_mode = _str_to_casting_mode[casting_mode]
|
|
|
|
| 425 |
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
|
| 426 |
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
|
| 427 |
|
| 428 |
+
if W is not None:
|
| 429 |
+
# Check constraints.
|
| 430 |
+
assert X.shape[1] == W.shape[0], (
|
| 431 |
+
"Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
|
| 432 |
+
)
|
| 433 |
+
elementwise_affine = True
|
| 434 |
+
else:
|
| 435 |
+
elementwise_affine = False
|
| 436 |
|
| 437 |
# XPU-specific optimization
|
| 438 |
kernel_args = {}
|
| 439 |
if X.device.type == "xpu":
|
| 440 |
+
set_large_grf_mode(kernel_args)
|
| 441 |
+
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
| 442 |
+
_rms_norm_forward_kernel[(n_rows,)](
|
| 443 |
+
Y,
|
| 444 |
+
Y.stride(0),
|
| 445 |
+
X,
|
| 446 |
+
X.stride(0),
|
| 447 |
+
W,
|
| 448 |
+
W.stride(0) if elementwise_affine else 0,
|
| 449 |
+
RSTD,
|
| 450 |
+
RSTD.stride(0),
|
| 451 |
+
n_cols,
|
| 452 |
+
eps,
|
| 453 |
+
offset,
|
| 454 |
+
casting_mode,
|
| 455 |
+
elementwise_affine=elementwise_affine,
|
| 456 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 457 |
+
num_warps=num_warps,
|
| 458 |
+
**kernel_args, # XPU-specific optimization
|
| 459 |
+
)
|
| 460 |
+
else:
|
| 461 |
+
BLOCK_ROW = 16
|
| 462 |
+
kernel_args["BLOCK_ROW"] = BLOCK_ROW
|
| 463 |
+
_block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
|
| 464 |
+
Y,
|
| 465 |
+
Y.stride(0),
|
| 466 |
+
X,
|
| 467 |
+
X.stride(0),
|
| 468 |
+
W,
|
| 469 |
+
W.stride(0) if elementwise_affine else 0,
|
| 470 |
+
RSTD,
|
| 471 |
+
RSTD.stride(0),
|
| 472 |
+
n_rows,
|
| 473 |
+
n_cols,
|
| 474 |
+
eps,
|
| 475 |
+
offset,
|
| 476 |
+
casting_mode,
|
| 477 |
+
elementwise_affine=elementwise_affine,
|
| 478 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 479 |
+
num_warps=num_warps,
|
| 480 |
+
**kernel_args, # XPU-specific optimization
|
| 481 |
+
)
|
| 482 |
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
|
| 483 |
|
| 484 |
|
| 485 |
+
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
|
| 486 |
shape = dY.shape
|
| 487 |
dim = shape[-1]
|
| 488 |
dY = dY.view(-1, dim)
|
|
|
|
| 493 |
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
| 494 |
elif X.device.type == "xpu":
|
| 495 |
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
| 496 |
+
elif X.device.type == "npu":
|
| 497 |
+
sm_count = get_npu_core_count()
|
| 498 |
|
| 499 |
+
if W is not None:
|
| 500 |
+
# fp32 for numerical stability especially.
|
| 501 |
+
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
| 502 |
+
elementwise_affine = True
|
| 503 |
+
else:
|
| 504 |
+
_dW = None
|
| 505 |
+
elementwise_affine = False
|
| 506 |
|
| 507 |
if n_cols > BLOCK_SIZE:
|
| 508 |
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
|
|
| 517 |
# XPU-specific optimization
|
| 518 |
kernel_args = {}
|
| 519 |
if X.device.type == "xpu":
|
| 520 |
+
set_large_grf_mode(kernel_args)
|
| 521 |
+
|
| 522 |
+
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
| 523 |
+
_rms_norm_backward_kernel[grid](
|
| 524 |
+
dY,
|
| 525 |
+
dY.stride(0),
|
| 526 |
+
dX,
|
| 527 |
+
dX.stride(0),
|
| 528 |
+
X,
|
| 529 |
+
X.stride(0),
|
| 530 |
+
torch_to_triton_dtype[X.dtype],
|
| 531 |
+
W,
|
| 532 |
+
W.stride(0) if elementwise_affine else 0,
|
| 533 |
+
RSTD,
|
| 534 |
+
RSTD.stride(0),
|
| 535 |
+
_dW,
|
| 536 |
+
_dW.stride(0) if elementwise_affine else 0,
|
| 537 |
+
n_rows,
|
| 538 |
+
n_cols,
|
| 539 |
+
offset,
|
| 540 |
+
rows_per_program,
|
| 541 |
+
casting_mode,
|
| 542 |
+
elementwise_affine=elementwise_affine,
|
| 543 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 544 |
+
num_warps=num_warps,
|
| 545 |
+
**kernel_args, # XPU-specific optimization
|
| 546 |
+
)
|
| 547 |
+
else:
|
| 548 |
+
BLOCK_ROW = 16
|
| 549 |
+
kernel_args["BLOCK_ROW"] = BLOCK_ROW
|
| 550 |
+
_block_rms_norm_backward_kernel[grid](
|
| 551 |
+
dY,
|
| 552 |
+
dY.stride(0),
|
| 553 |
+
dX,
|
| 554 |
+
dX.stride(0),
|
| 555 |
+
X,
|
| 556 |
+
X.stride(0),
|
| 557 |
+
torch_to_triton_dtype[X.dtype],
|
| 558 |
+
W,
|
| 559 |
+
W.stride(0) if elementwise_affine else 0,
|
| 560 |
+
RSTD,
|
| 561 |
+
RSTD.stride(0),
|
| 562 |
+
_dW,
|
| 563 |
+
_dW.stride(0) if elementwise_affine else 0,
|
| 564 |
+
n_rows,
|
| 565 |
+
n_cols,
|
| 566 |
+
offset,
|
| 567 |
+
casting_mode,
|
| 568 |
+
elementwise_affine=elementwise_affine,
|
| 569 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 570 |
+
num_warps=num_warps,
|
| 571 |
+
**kernel_args, # XPU-specific optimization
|
| 572 |
+
)
|
| 573 |
dX = dX.view(*shape)
|
| 574 |
+
|
| 575 |
+
if elementwise_affine:
|
| 576 |
+
dW = _dW.sum(dim=0).to(W.dtype)
|
| 577 |
+
else:
|
| 578 |
+
dW = None
|
| 579 |
|
| 580 |
return dX, dW
|
| 581 |
|
|
|
|
| 604 |
|
| 605 |
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
    """
    Run the RMSNorm forward pass and stash what backward needs on `ctx`.

    X: (B, T, H) or (BxT, H)
    W: (H,), or None when the norm has no learnable weight
    """
    if isinstance(X, torch.distributed.tensor.DTensor):
        # The input comes out of a tensor-parallel module; gather it into a
        # full local tensor so every TP worker normalises the same data.
        # TODO: support CP.
        X = X.full_tensor()

    Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)

    has_weight = W is not None

    # Non-tensor state consumed by backward.
    ctx.offset = offset
    ctx.casting_mode = casting_mode
    ctx.in_place = in_place
    ctx.row_mode = row_mode
    ctx.BLOCK_SIZE = BLOCK_SIZE
    ctx.num_warps = num_warps
    ctx.elementwise_affine = has_weight

    # W is only saved when it exists; backward unpacks according to
    # ctx.elementwise_affine.
    saved = (X, W, RSTD) if has_weight else (X, RSTD)
    ctx.save_for_backward(*saved)
    return Y
|
| 632 |
|
| 633 |
@staticmethod
|
|
|
|
| 636 |
"""
|
| 637 |
Y: (B, T, H) or (BxT, H)
|
| 638 |
"""
|
| 639 |
+
if ctx.elementwise_affine:
|
| 640 |
+
X, W, RSTD = ctx.saved_tensors
|
| 641 |
+
else:
|
| 642 |
+
X, RSTD = ctx.saved_tensors
|
| 643 |
+
W = None
|
| 644 |
+
|
| 645 |
+
if isinstance(dY, torch.distributed.tensor.DTensor):
|
| 646 |
+
# Gradients are output of a tensor parallel module and
|
| 647 |
+
# needs to be gathered to a local tensor for computing RMSE layer.
|
| 648 |
+
# TODO: support CP.
|
| 649 |
+
dY = dY.full_tensor()
|
| 650 |
+
|
| 651 |
dX, dW = rms_norm_backward(
|
| 652 |
+
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
)
|
| 654 |
+
return dX, dW, None, None, None, None, None
|
build/torch-cuda/rope.py
CHANGED
|
@@ -32,7 +32,7 @@ def _triton_rope(
|
|
| 32 |
|
| 33 |
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
| 34 |
# stride: (seq_len * head_dim, head_dim, 1)
|
| 35 |
-
pid = tl.program_id(0)
|
| 36 |
|
| 37 |
# locate start address
|
| 38 |
q_ptr = q_ptr + pid * q_row_stride
|
|
@@ -236,4 +236,4 @@ class LigerRopeFunction(torch.autograd.Function):
|
|
| 236 |
|
| 237 |
cos, sin = ctx.saved_tensors
|
| 238 |
dq, dk = rope_backward(dq, dk, cos, sin)
|
| 239 |
-
return dq, dk, None, None, None, None
|
|
|
|
| 32 |
|
| 33 |
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
| 34 |
# stride: (seq_len * head_dim, head_dim, 1)
|
| 35 |
+
pid = tl.program_id(0).to(tl.int64)
|
| 36 |
|
| 37 |
# locate start address
|
| 38 |
q_ptr = q_ptr + pid * q_row_stride
|
|
|
|
| 236 |
|
| 237 |
cos, sin = ctx.saved_tensors
|
| 238 |
dq, dk = rope_backward(dq, dk, cos, sin)
|
| 239 |
+
return dq, dk, None, None, None, None
|
build/torch-cuda/swiglu.py
CHANGED
|
@@ -12,7 +12,9 @@ def silu(x):
|
|
| 12 |
|
| 13 |
|
| 14 |
@triton.jit
|
| 15 |
-
def _swiglu_forward_kernel(
|
|
|
|
|
|
|
| 16 |
program_id = tl.program_id(0).to(tl.int64)
|
| 17 |
|
| 18 |
# locate start index
|
|
@@ -24,14 +26,16 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
|
|
| 24 |
mask = col_offsets < n_cols
|
| 25 |
|
| 26 |
# sigmoid requires type float32
|
| 27 |
-
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
| 28 |
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
| 29 |
-
c_row = silu(a_row) * b_row
|
| 30 |
tl.store(c_ptr + col_offsets, c_row, mask=mask)
|
| 31 |
|
| 32 |
|
| 33 |
@triton.jit
|
| 34 |
-
def _swiglu_backward_kernel(
|
|
|
|
|
|
|
| 35 |
program_id = tl.program_id(0).to(tl.int64)
|
| 36 |
|
| 37 |
# locate start index
|
|
@@ -44,20 +48,21 @@ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr,
|
|
| 44 |
|
| 45 |
dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
|
| 46 |
# sigmoid requires type float32
|
| 47 |
-
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
| 48 |
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
| 49 |
|
| 50 |
-
# recomputation to save memory
|
| 51 |
sig_a = tl.sigmoid(a_row)
|
| 52 |
silu_a = a_row * sig_a
|
| 53 |
db_row = dc_row * silu_a
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
tl.store(a_ptr + col_offsets, da_row, mask=mask)
|
| 57 |
tl.store(b_ptr + col_offsets, db_row, mask=mask)
|
| 58 |
|
| 59 |
|
| 60 |
-
def swiglu_forward(a, b):
|
| 61 |
ori_shape = a.shape
|
| 62 |
|
| 63 |
n_cols = ori_shape[-1]
|
|
@@ -73,6 +78,7 @@ def swiglu_forward(a, b):
|
|
| 73 |
b,
|
| 74 |
c,
|
| 75 |
c.stride(-2),
|
|
|
|
| 76 |
n_cols=n_cols,
|
| 77 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 78 |
num_warps=num_warps,
|
|
@@ -80,7 +86,7 @@ def swiglu_forward(a, b):
|
|
| 80 |
return a, b, c.view(*ori_shape)
|
| 81 |
|
| 82 |
|
| 83 |
-
def swiglu_backward(a, b, dc):
|
| 84 |
ori_shape = dc.shape
|
| 85 |
n_cols = ori_shape[-1]
|
| 86 |
dc = dc.view(-1, n_cols)
|
|
@@ -93,6 +99,7 @@ def swiglu_backward(a, b, dc):
|
|
| 93 |
a,
|
| 94 |
b,
|
| 95 |
dc.stride(-2),
|
|
|
|
| 96 |
n_cols=n_cols,
|
| 97 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 98 |
num_warps=num_warps,
|
|
@@ -103,14 +110,67 @@ def swiglu_backward(a, b, dc):
|
|
| 103 |
class LigerSiLUMulFunction(torch.autograd.Function):
|
| 104 |
@staticmethod
|
| 105 |
@ensure_contiguous
|
| 106 |
-
def forward(ctx, a, b):
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
@staticmethod
|
| 112 |
@ensure_contiguous
|
| 113 |
def backward(ctx, dc):
|
| 114 |
a, b = ctx.saved_tensors
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
@triton.jit
|
| 15 |
+
def _swiglu_forward_kernel(
|
| 16 |
+
a_ptr, b_ptr, c_ptr, stride, gate_multiplier, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
| 17 |
+
):
|
| 18 |
program_id = tl.program_id(0).to(tl.int64)
|
| 19 |
|
| 20 |
# locate start index
|
|
|
|
| 26 |
mask = col_offsets < n_cols
|
| 27 |
|
| 28 |
# sigmoid requires type float32
|
| 29 |
+
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) * gate_multiplier
|
| 30 |
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
| 31 |
+
c_row = silu(a_row).cast(b_row.dtype) * b_row
|
| 32 |
tl.store(c_ptr + col_offsets, c_row, mask=mask)
|
| 33 |
|
| 34 |
|
| 35 |
@triton.jit
|
| 36 |
+
def _swiglu_backward_kernel(
|
| 37 |
+
dc_ptr, a_ptr, b_ptr, stride, gate_multiplier, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
| 38 |
+
):
|
| 39 |
program_id = tl.program_id(0).to(tl.int64)
|
| 40 |
|
| 41 |
# locate start index
|
|
|
|
| 48 |
|
| 49 |
dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
|
| 50 |
# sigmoid requires type float32
|
| 51 |
+
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) * gate_multiplier
|
| 52 |
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
| 53 |
|
| 54 |
+
# recomputation to save memory. a_row already holds a * gate_multiplier.
|
| 55 |
sig_a = tl.sigmoid(a_row)
|
| 56 |
silu_a = a_row * sig_a
|
| 57 |
db_row = dc_row * silu_a
|
| 58 |
+
# chain rule pulls an extra factor of gate_multiplier through the pre-activation scaling
|
| 59 |
+
da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row * gate_multiplier
|
| 60 |
|
| 61 |
tl.store(a_ptr + col_offsets, da_row, mask=mask)
|
| 62 |
tl.store(b_ptr + col_offsets, db_row, mask=mask)
|
| 63 |
|
| 64 |
|
| 65 |
+
def swiglu_forward(a, b, gate_multiplier: float = 1.0):
|
| 66 |
ori_shape = a.shape
|
| 67 |
|
| 68 |
n_cols = ori_shape[-1]
|
|
|
|
| 78 |
b,
|
| 79 |
c,
|
| 80 |
c.stride(-2),
|
| 81 |
+
float(gate_multiplier),
|
| 82 |
n_cols=n_cols,
|
| 83 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 84 |
num_warps=num_warps,
|
|
|
|
| 86 |
return a, b, c.view(*ori_shape)
|
| 87 |
|
| 88 |
|
| 89 |
+
def swiglu_backward(a, b, dc, gate_multiplier: float = 1.0):
|
| 90 |
ori_shape = dc.shape
|
| 91 |
n_cols = ori_shape[-1]
|
| 92 |
dc = dc.view(-1, n_cols)
|
|
|
|
| 99 |
a,
|
| 100 |
b,
|
| 101 |
dc.stride(-2),
|
| 102 |
+
float(gate_multiplier),
|
| 103 |
n_cols=n_cols,
|
| 104 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 105 |
num_warps=num_warps,
|
|
|
|
| 110 |
class LigerSiLUMulFunction(torch.autograd.Function):
    """Autograd wrapper around the fused SwiGLU forward/backward helpers.

    `gate_multiplier` is forwarded to `swiglu_forward` / `swiglu_backward`;
    `down_multiplier` scales the forward output and, symmetrically, the
    incoming gradient in backward. Both default to 1.0.

    When either input is a DTensor, the computation runs on the local shards
    and the results are wrapped back into DTensors with the same device mesh
    and placements.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b, gate_multiplier: float = 1.0, down_multiplier: float = 1.0):
        """Compute the activation for `a` (gate) and `b`, saving both for backward."""
        gate_multiplier = float(gate_multiplier)
        down_multiplier = float(down_multiplier)
        ctx.gate_multiplier = gate_multiplier
        ctx.down_multiplier = down_multiplier

        if isinstance(a, torch.distributed.tensor.DTensor) or isinstance(b, torch.distributed.tensor.DTensor):
            # Take mesh/placements from whichever operand already is a DTensor.
            device_mesh, placements = (
                (a.device_mesh, a.placements)
                if isinstance(a, torch.distributed.tensor.DTensor)
                else (b.device_mesh, b.placements)
            )

            # Assume that full tensors are gathered before and identical across
            # the associated process groups.
            if not isinstance(a, torch.distributed.tensor.DTensor):
                a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements)
            if not isinstance(b, torch.distributed.tensor.DTensor):
                b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements)
            a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local(), gate_multiplier)
            if down_multiplier != 1.0:
                c_local = c_local * down_multiplier
            ctx.save_for_backward(a_local, b_local)
            ctx.dtensor_metadata = (device_mesh, placements)
            return torch.distributed.tensor.DTensor.from_local(c_local, device_mesh, placements)
        else:
            a, b, c = swiglu_forward(a, b, gate_multiplier)
            if down_multiplier != 1.0:
                c = c * down_multiplier
            ctx.save_for_backward(a, b)
            ctx.dtensor_metadata = None
            return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        """Return gradients for `a` and `b`; the two multipliers get `None`."""
        a, b = ctx.saved_tensors
        gate_multiplier = ctx.gate_multiplier
        down_multiplier = ctx.down_multiplier

        if ctx.dtensor_metadata is not None:
            device_mesh, placements = ctx.dtensor_metadata

            # Assume that full tensors are gathered before and identical across
            # the associated process groups.
            if isinstance(dc, torch.distributed.tensor.DTensor):
                dc_local = dc.to_local()
            else:
                # distribute_tensor returns a DTensor; the backward helper needs
                # the plain local shard that matches the saved local a/b.
                dc_local = torch.distributed.tensor.distribute_tensor(
                    dc, device_mesh=device_mesh, placements=placements
                ).to_local()
            if down_multiplier != 1.0:
                # forward scaled the output by down_multiplier, so the upstream
                # gradient picks up the same factor
                dc_local = dc_local * down_multiplier
            a_local, b_local = swiglu_backward(a, b, dc_local, gate_multiplier)
            return (
                torch.distributed.tensor.DTensor.from_local(a_local, device_mesh, placements),
                torch.distributed.tensor.DTensor.from_local(b_local, device_mesh, placements),
                None,
                None,
            )

        if down_multiplier != 1.0:
            dc = dc * down_multiplier
        a, b = swiglu_backward(a, b, dc, gate_multiplier)
        return a, b, None, None
|
build/torch-cuda/tvd.py
CHANGED
|
@@ -49,6 +49,7 @@ def _tv_distance_kernel(
|
|
| 49 |
label_ptr,
|
| 50 |
ignore_index: tl.constexpr,
|
| 51 |
n_cols,
|
|
|
|
| 52 |
BLOCK_SIZE: tl.constexpr,
|
| 53 |
HAS_LABEL: tl.constexpr,
|
| 54 |
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
|
@@ -84,7 +85,8 @@ def _tv_distance_kernel(
|
|
| 84 |
# TVD(P || Q) = 0.5 * |P - Q|
|
| 85 |
tv_loss = 0.5 * tl.abs(p - q)
|
| 86 |
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
tl.store(grads_ptr + offsets, grad_res, mask=mask)
|
| 90 |
|
|
@@ -94,7 +96,8 @@ def _tv_distance_kernel(
|
|
| 94 |
loss_sum += tl.sum(tv_loss, axis=0)
|
| 95 |
|
| 96 |
if reduction != _REDUCTION_MODE_NONE:
|
| 97 |
-
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
|
|
@@ -113,6 +116,14 @@ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_
|
|
| 113 |
|
| 114 |
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
_tv_distance_kernel[grid](
|
| 117 |
p,
|
| 118 |
p.stride(0),
|
|
@@ -125,18 +136,18 @@ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_
|
|
| 125 |
shift_labels if has_label else torch.empty(1, device=p.device),
|
| 126 |
ignore_index,
|
| 127 |
V,
|
|
|
|
| 128 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 129 |
HAS_LABEL=has_label,
|
| 130 |
num_warps=num_warps,
|
| 131 |
reduction=reduction,
|
| 132 |
)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
elif reduction == _REDUCTION_MODE_SUM.value:
|
| 137 |
return output_tensor.sum(dim=0), grads
|
| 138 |
-
elif reduction == _REDUCTION_MODE_MEAN.value:
|
| 139 |
-
return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
|
| 140 |
else:
|
| 141 |
return output_tensor, grads
|
| 142 |
|
|
@@ -204,4 +215,4 @@ class LigerTVDLossFunction(torch.autograd.Function):
|
|
| 204 |
(grads,) = ctx.saved_tensors
|
| 205 |
grads = tvd_backward_triton(grad_output, grads)
|
| 206 |
|
| 207 |
-
return grads, None, None, None, None
|
|
|
|
| 49 |
label_ptr,
|
| 50 |
ignore_index: tl.constexpr,
|
| 51 |
n_cols,
|
| 52 |
+
scale, # pre-computed reduction scale for gradients (fused into kernel)
|
| 53 |
BLOCK_SIZE: tl.constexpr,
|
| 54 |
HAS_LABEL: tl.constexpr,
|
| 55 |
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
|
|
|
| 85 |
# TVD(P || Q) = 0.5 * |P - Q|
|
| 86 |
tv_loss = 0.5 * tl.abs(p - q)
|
| 87 |
|
| 88 |
+
# Fuse reduction scaling into gradient computation (eliminates separate Python division)
|
| 89 |
+
grad_res = tl.where(p > q, 0.5 * scale, -0.5 * scale)
|
| 90 |
|
| 91 |
tl.store(grads_ptr + offsets, grad_res, mask=mask)
|
| 92 |
|
|
|
|
| 96 |
loss_sum += tl.sum(tv_loss, axis=0)
|
| 97 |
|
| 98 |
if reduction != _REDUCTION_MODE_NONE:
|
| 99 |
+
# Fuse reduction scaling into loss (same scale as gradients; avoids Python division)
|
| 100 |
+
tl.store(loss_ptr, loss_sum * scale)
|
| 101 |
|
| 102 |
|
| 103 |
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
|
|
|
|
| 116 |
|
| 117 |
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
|
| 118 |
|
| 119 |
+
# Pre-compute gradient scale factor (fused into kernel to avoid separate division)
|
| 120 |
+
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
|
| 121 |
+
scale = 1.0 / n_non_ignore
|
| 122 |
+
elif reduction == _REDUCTION_MODE_MEAN.value:
|
| 123 |
+
scale = 1.0 / (n_non_ignore * V)
|
| 124 |
+
else:
|
| 125 |
+
scale = 1.0
|
| 126 |
+
|
| 127 |
_tv_distance_kernel[grid](
|
| 128 |
p,
|
| 129 |
p.stride(0),
|
|
|
|
| 136 |
shift_labels if has_label else torch.empty(1, device=p.device),
|
| 137 |
ignore_index,
|
| 138 |
V,
|
| 139 |
+
scale,
|
| 140 |
BLOCK_SIZE=BLOCK_SIZE,
|
| 141 |
HAS_LABEL=has_label,
|
| 142 |
num_warps=num_warps,
|
| 143 |
reduction=reduction,
|
| 144 |
)
|
| 145 |
|
| 146 |
+
# Loss and gradients are already scaled inside the kernel — no separate division needed
|
| 147 |
+
if reduction in (_REDUCTION_MODE_BATCHMEAN.value, _REDUCTION_MODE_MEAN.value):
|
| 148 |
+
return output_tensor.sum(), grads
|
| 149 |
elif reduction == _REDUCTION_MODE_SUM.value:
|
| 150 |
return output_tensor.sum(dim=0), grads
|
|
|
|
|
|
|
| 151 |
else:
|
| 152 |
return output_tensor, grads
|
| 153 |
|
|
|
|
| 215 |
(grads,) = ctx.saved_tensors
|
| 216 |
grads = tvd_backward_triton(grad_output, grads)
|
| 217 |
|
| 218 |
+
return grads, None, None, None, None
|
build/torch-cuda/utils.py
CHANGED
|
@@ -22,17 +22,33 @@ import triton.language as tl
|
|
| 22 |
|
| 23 |
from packaging.version import Version
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def infer_device():
|
| 26 |
"""
|
| 27 |
Get current device name based on available devices
|
| 28 |
"""
|
| 29 |
if torch.cuda.is_available(): # Works for both Nvidia and AMD
|
| 30 |
return "cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
elif torch.xpu.is_available():
|
| 32 |
return "xpu"
|
| 33 |
else:
|
| 34 |
return "cpu"
|
| 35 |
|
|
|
|
| 36 |
def is_hip() -> bool:
|
| 37 |
return torch.version.hip is not None
|
| 38 |
|
|
@@ -86,6 +102,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
|
|
| 86 |
functools.partial(torch.amp.custom_fwd, device_type=device),
|
| 87 |
functools.partial(torch.amp.custom_bwd, device_type=device),
|
| 88 |
)
|
|
|
|
|
|
|
| 89 |
return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
|
| 90 |
|
| 91 |
|
|
@@ -132,4 +150,27 @@ def element_mul_kernel(
|
|
| 132 |
for i in range(0, n_cols, BLOCK_SIZE):
|
| 133 |
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
| 134 |
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
|
| 135 |
-
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
from packaging.version import Version
|
| 24 |
|
| 25 |
+
|
| 26 |
+
def is_npu_available() -> bool:
    """Report whether an Ascend NPU is usable, according to `transformers`.

    Best effort: any failure while importing or calling the probe (for
    example `transformers` not being installed) is treated as "no NPU".
    """
    available = False
    try:
        from transformers.utils import is_torch_npu_available as _npu_probe

        available = _npu_probe()
    except Exception:
        # Deliberately broad: detection must never break import of this module.
        pass
    return available
|
| 34 |
+
|
| 35 |
+
|
| 36 |
def infer_device():
    """
    Return the name of the first available accelerator backend.

    Checked in order: "cuda", "npu", "xpu"; falls back to "cpu".
    """
    # Works for both Nvidia and AMD
    if torch.cuda.is_available():
        return "cuda"
    # Ascend NPU (torch.npu)
    if is_npu_available():
        return "npu"
    # Intel XPU
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
|
| 50 |
|
| 51 |
+
|
| 52 |
def is_hip() -> bool:
    """Return True when `torch.version.hip` is set (i.e. not None)."""
    hip_version = torch.version.hip
    return hip_version is not None
|
| 54 |
|
|
|
|
| 102 |
functools.partial(torch.amp.custom_fwd, device_type=device),
|
| 103 |
functools.partial(torch.amp.custom_bwd, device_type=device),
|
| 104 |
)
|
| 105 |
+
if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
|
| 106 |
+
return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
|
| 107 |
return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
|
| 108 |
|
| 109 |
|
|
|
|
| 150 |
for i in range(0, n_cols, BLOCK_SIZE):
|
| 151 |
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
| 152 |
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
|
| 153 |
+
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_npu_core_count(default: int = 20) -> int:
    """Return the NPU vector core count reported by the active Triton driver.

    Reads the "num_vectorcore" entry of the properties for device 0. Returns
    `default` if that entry is missing, or if querying the Triton runtime
    fails for any reason.
    """
    try:
        device_props = triton.runtime.driver.active.utils.get_device_properties(0)
        core_count = int(device_props.get("num_vectorcore", default))
    except Exception:
        # No Triton runtime / no NPU device: fall back to the caller's default.
        core_count = default
    return core_count
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def set_large_grf_mode(kernel_args: dict):
    """Add the large-GRF launch option for XPU devices to `kernel_args` in place."""
    # On XPU, the triton that ships with pytorch-xpu is packaged as
    # `pytorch-triton-xpu`; a source build of triton XPU is packaged as `triton`.
    has_new_grf_api = compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version(
        "triton", operator.ge, "3.6.0"
    )
    # The accepted value changed in
    # https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
    kernel_args["grf_mode"] = "256" if has_new_grf_api else "large"
|