kernels-bot committed on
Commit
0c953b9
·
verified ·
1 Parent(s): 741aabc

Uploaded using `kernel-builder`.

Browse files
build/torch-cuda/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from . import layers
 
2
 
3
- __all__ = ["layers"]
 
1
  from . import layers
2
+ from .layers import CrossEntropyOutput, LigerForCausalLMLoss
3
 
4
+ __all__ = ["layers", "LigerForCausalLMLoss", "CrossEntropyOutput"]
build/torch-cuda/_ops.py CHANGED
@@ -22,7 +22,7 @@ def get_backend() -> str:
22
 
23
  def _find_ops_name() -> str:
24
  kernel_name = "liger_kernels"
25
- unique_id = "e29f7ec"
26
  backend = get_backend()
27
  return f"_{kernel_name}_{backend}_{unique_id}"
28
 
 
22
 
23
  def _find_ops_name() -> str:
24
  kernel_name = "liger_kernels"
25
+ unique_id = "08b4d53"
26
  backend = get_backend()
27
  return f"_{kernel_name}_{backend}_{unique_id}"
28
 
build/torch-cuda/cross_entropy.py CHANGED
@@ -10,8 +10,9 @@ from .utils import compare_version
10
  from .utils import element_mul_kernel
11
  from .utils import is_hip
12
  from .utils import infer_device
 
13
 
14
- if compare_version("triton", operator.ge, "3.0.0"):
15
  try:
16
  # typical import path with dispatch available
17
  from triton.language.extra.libdevice import tanh
@@ -32,6 +33,10 @@ def liger_cross_entropy_kernel(
32
  loss_ptr,
33
  z_loss_ptr,
34
  loss_stride,
 
 
 
 
35
  n_cols,
36
  n_non_ignore,
37
  sum_non_ignore_weight,
@@ -42,9 +47,12 @@ def liger_cross_entropy_kernel(
42
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
43
  softcap,
44
  RETURN_Z_LOSS: tl.constexpr,
 
 
45
  BLOCK_SIZE: tl.constexpr,
46
  HAS_WEIGHT: tl.constexpr,
47
  HAS_SOFTCAPPING: tl.constexpr,
 
48
  ):
49
  """
50
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +67,8 @@ def liger_cross_entropy_kernel(
59
  loss_ptr: Pointer to tensor to store the loss.
60
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
61
  loss_stride (int): The stride of the loss tensor.
 
 
62
  n_cols (int): The number of columns in the input tensor.
63
  n_non_ignore (float): The number of non-ignored elements in the batch.
64
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +78,12 @@ def liger_cross_entropy_kernel(
68
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
69
  reduction (str): The string for the reduction to apply
70
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
71
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
 
72
  BLOCK_SIZE (int): The block size for Triton operations.
73
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
74
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
 
75
  """
76
 
77
  # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +102,22 @@ def liger_cross_entropy_kernel(
90
  for i in range(0, n_cols, BLOCK_SIZE):
91
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
92
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
 
 
 
 
 
 
 
93
  return
94
 
95
  loss_ptr += program_id * loss_stride
96
  if RETURN_Z_LOSS:
97
  z_loss_ptr += program_id * loss_stride
 
 
 
 
98
 
99
  if HAS_WEIGHT:
100
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +128,7 @@ def liger_cross_entropy_kernel(
105
  # 3. [Online softmax] first pass: find max + sum
106
  m = float("-inf") # m is the max value. use the notation from the paper
107
  d = 0.0 # d is the sum. use the notation from the paper
 
108
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
109
  if HAS_SOFTCAPPING:
110
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +149,19 @@ def liger_cross_entropy_kernel(
125
  if HAS_SOFTCAPPING:
126
  X_block = softcap * tanh(X_block / softcap)
127
  block_max = tl.max(X_block)
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  if label_smoothing > 0:
129
  # scale X beforehand to avoid overflow
130
  if HAS_WEIGHT:
@@ -155,58 +192,58 @@ def liger_cross_entropy_kernel(
155
  # For 'sum' reduction, no normalization is applied:
156
  # dx_y = softmax(x_y) - 1
157
  # dx_i = softmax(x_i), for i ≠ y
158
-
159
- for i in range(0, n_cols, BLOCK_SIZE):
160
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
161
- X_block = tl.load(
162
- X_ptr + X_offsets,
163
- mask=X_offsets < n_cols,
164
- other=float("-inf"),
165
- # Ensure float32 precision for softmax calculation
166
- ).cast(tl.float32)
167
- if HAS_SOFTCAPPING:
168
- intermediate = tanh(X_block / softcap)
169
- X_block = softcap * intermediate
170
-
171
- if not HAS_WEIGHT:
172
- # softmax(x_i)
173
- X_block = tl.exp(X_block - m) / d
174
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
175
- X_block += 2 * lse_square_scale * lse * X_block
176
- # smoothing term
177
- X_block += -eps
178
- # special handle dx_y
179
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
180
- # reduction scale
181
- if reduction == "mean":
182
- X_block = X_block / n_non_ignore
183
- else:
184
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
185
- softmax_X = tl.exp(X_block - m) / d
186
- # derivative of original_loss
187
- dloss_ori = (1 - label_smoothing) * softmax_X
188
- # specially handle dx_y
189
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
190
- dloss_ori = dloss_ori * weight_y
191
- # derivative of smooth_loss
192
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
193
- # derivative of z-loss
194
- dz_loss = 2 * lse_square_scale * lse * softmax_X
195
- # reduction scale
196
- if reduction == "mean":
197
- dloss_ori = dloss_ori / sum_non_ignore_weight
198
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
199
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
200
- dz_loss = dz_loss / n_non_ignore
201
- # derivative of total_loss
202
- X_block = dloss_ori + dloss_smooth + dz_loss
203
-
204
- # chain rule softcapping
205
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
206
- if HAS_SOFTCAPPING:
207
- X_block = X_block * (1 - intermediate * intermediate)
208
-
209
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
210
 
211
  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
212
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,12 +291,24 @@ def liger_cross_entropy_kernel(
254
  tl.store(loss_ptr, loss)
255
  if RETURN_Z_LOSS:
256
  tl.store(z_loss_ptr, z_loss)
 
 
 
 
 
 
257
 
258
 
259
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
260
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
261
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
262
- MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
 
 
 
 
 
 
263
 
264
 
265
  def cross_entropy_forward(
@@ -272,8 +321,16 @@ def cross_entropy_forward(
272
  reduction,
273
  softcap,
274
  return_z_loss,
 
 
275
  ):
276
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
 
 
 
 
 
 
277
 
278
  BT, V = _input.shape
279
  n_rows = BT
@@ -283,6 +340,12 @@ def cross_entropy_forward(
283
  # unreduced loss
284
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
285
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
 
 
 
 
 
 
286
 
287
  target_mask = target != ignore_index
288
  n_non_ignore = target_mask.sum().item()
@@ -319,6 +382,14 @@ def cross_entropy_forward(
319
  loss_ptr=loss_1d,
320
  z_loss_ptr=z_loss_1d,
321
  loss_stride=loss_1d.stride(-1), # always 1
 
 
 
 
 
 
 
 
322
  n_cols=V,
323
  n_non_ignore=n_non_ignore,
324
  sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +400,12 @@ def cross_entropy_forward(
329
  reduction=reduction,
330
  softcap=softcap,
331
  RETURN_Z_LOSS=return_z_loss,
 
 
332
  BLOCK_SIZE=BLOCK_SIZE,
333
  HAS_WEIGHT=True if weight is not None else False,
334
  HAS_SOFTCAPPING=True if softcap is not None else False,
 
335
  # TODO: 32 seems to give the best performance
336
  # Performance is quite sensitive to num_warps
337
  num_warps=32 if not is_hip() else 16,
@@ -340,11 +414,16 @@ def cross_entropy_forward(
340
  if reduction == "none":
341
  loss = loss_1d
342
  z_loss = z_loss_1d if return_z_loss else None
 
343
  else:
344
  loss = torch.sum(loss_1d)
345
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
 
 
346
 
347
- return loss, z_loss, _input
 
 
348
 
349
 
350
  def cross_entropy_backward(_input, grad_output):
@@ -392,6 +471,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
392
  reduction: str = "mean",
393
  softcap: Optional[float] = None,
394
  return_z_loss: bool = False,
 
 
395
  ):
396
  """
397
  The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +487,16 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
406
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
407
  reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
408
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
409
- return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
 
410
 
411
  Returns:
412
- tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
413
  """
414
- loss, z_loss, _input = cross_entropy_forward(
 
 
415
  _input,
416
  target,
417
  weight,
@@ -421,29 +506,40 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
421
  reduction,
422
  softcap,
423
  return_z_loss,
 
 
424
  )
425
  # TODO: investigation
426
  # If we don't detach the _input tensor, the memory will double
427
  # Not sure why but seems that there will be a time both grad and value exist but in different location
428
- ctx.save_for_backward(_input.detach())
 
429
  ctx.return_z_loss = return_z_loss
 
 
430
 
431
- return loss, z_loss
432
 
433
  @staticmethod
434
- def backward(ctx, grad_output, grad_ouput2):
435
  """
436
  The backward pass of the Liger Cross Entropy loss.
437
 
438
  Parameters:
439
  ctx : The context object with saved tensors.
440
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
441
- grad_output2 (tenosr): No use.
 
 
442
  Returns:
443
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
444
  """
445
  if ctx.return_z_loss:
446
- del grad_ouput2 # z_loss is only for logging
 
 
 
 
447
 
448
  (_input,) = ctx.saved_tensors
449
  _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +553,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
457
  None,
458
  None,
459
  None,
460
- )
 
 
 
10
  from .utils import element_mul_kernel
11
  from .utils import is_hip
12
  from .utils import infer_device
13
+ from .utils import is_npu_available
14
 
15
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
16
  try:
17
  # typical import path with dispatch available
18
  from triton.language.extra.libdevice import tanh
 
33
  loss_ptr,
34
  z_loss_ptr,
35
  loss_stride,
36
+ token_accuracy_ptr,
37
+ token_accuracy_stride,
38
+ predicted_tokens_ptr,
39
+ predicted_tokens_stride,
40
  n_cols,
41
  n_non_ignore,
42
  sum_non_ignore_weight,
 
47
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
48
  softcap,
49
  RETURN_Z_LOSS: tl.constexpr,
50
+ RETURN_TOKEN_ACCURACY: tl.constexpr,
51
+ RETURN_PREDICTED_TOKENS: tl.constexpr,
52
  BLOCK_SIZE: tl.constexpr,
53
  HAS_WEIGHT: tl.constexpr,
54
  HAS_SOFTCAPPING: tl.constexpr,
55
+ HAS_GRADIENTS: tl.constexpr,
56
  ):
57
  """
58
  This kernel computes both cross entropy loss and the gradient of the input.
 
67
  loss_ptr: Pointer to tensor to store the loss.
68
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
69
  loss_stride (int): The stride of the loss tensor.
70
+ token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
71
+ token_accuracy_stride (int): The stride of the token accuracy tensor.
72
  n_cols (int): The number of columns in the input tensor.
73
  n_non_ignore (float): The number of non-ignored elements in the batch.
74
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
 
78
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
79
  reduction (str): The string for the reduction to apply
80
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
81
+ RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
82
+ RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
83
  BLOCK_SIZE (int): The block size for Triton operations.
84
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
85
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
86
+ HAS_GRADIENTS (bool): The boolean value to determine whether to calculate gradients in the forward pass.
87
  """
88
 
89
  # https://github.com/triton-lang/triton/issues/1058
 
102
  for i in range(0, n_cols, BLOCK_SIZE):
103
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
104
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
105
+ # For ignored tokens, set token accuracy to 0
106
+ if RETURN_TOKEN_ACCURACY:
107
+ token_accuracy_ptr += program_id * token_accuracy_stride
108
+ tl.store(token_accuracy_ptr, 0.0)
109
+ if RETURN_PREDICTED_TOKENS:
110
+ predicted_tokens_ptr += program_id * predicted_tokens_stride
111
+ tl.store(predicted_tokens_ptr, -1)
112
  return
113
 
114
  loss_ptr += program_id * loss_stride
115
  if RETURN_Z_LOSS:
116
  z_loss_ptr += program_id * loss_stride
117
+ if RETURN_TOKEN_ACCURACY:
118
+ token_accuracy_ptr += program_id * token_accuracy_stride
119
+ if RETURN_PREDICTED_TOKENS:
120
+ predicted_tokens_ptr += program_id * predicted_tokens_stride
121
 
122
  if HAS_WEIGHT:
123
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
 
128
  # 3. [Online softmax] first pass: find max + sum
129
  m = float("-inf") # m is the max value. use the notation from the paper
130
  d = 0.0 # d is the sum. use the notation from the paper
131
+ argmax_idx = 0 # Track the index of the maximum value for token accuracy / predicted tokens computation
132
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
133
  if HAS_SOFTCAPPING:
134
  ori_X_y = softcap * tanh(ori_X_y / softcap)
 
149
  if HAS_SOFTCAPPING:
150
  X_block = softcap * tanh(X_block / softcap)
151
  block_max = tl.max(X_block)
152
+
153
+ # Track argmax for accuracy / predicted tokens computation
154
+ if RETURN_TOKEN_ACCURACY or RETURN_PREDICTED_TOKENS:
155
+ # Find the index of the maximum value in this block
156
+ is_max_mask = X_block == block_max
157
+ # Mask out invalid indices with a value larger than n_cols
158
+ masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
159
+ # Get the first (smallest) index where max occurs
160
+ current_block_argmax_idx = tl.min(masked_offsets)
161
+
162
+ is_new_max = block_max > m
163
+ argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
164
+
165
  if label_smoothing > 0:
166
  # scale X beforehand to avoid overflow
167
  if HAS_WEIGHT:
 
192
  # For 'sum' reduction, no normalization is applied:
193
  # dx_y = softmax(x_y) - 1
194
  # dx_i = softmax(x_i), for i ≠ y
195
+ if HAS_GRADIENTS:
196
+ for i in range(0, n_cols, BLOCK_SIZE):
197
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
198
+ X_block = tl.load(
199
+ X_ptr + X_offsets,
200
+ mask=X_offsets < n_cols,
201
+ other=float("-inf"),
202
+ # Ensure float32 precision for softmax calculation
203
+ ).cast(tl.float32)
204
+ if HAS_SOFTCAPPING:
205
+ intermediate = tanh(X_block / softcap)
206
+ X_block = softcap * intermediate
207
+
208
+ if not HAS_WEIGHT:
209
+ # softmax(x_i)
210
+ X_block = tl.exp(X_block - m) / d
211
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
212
+ X_block += 2 * lse_square_scale * lse * X_block
213
+ # smoothing term
214
+ X_block += -eps
215
+ # special handle dx_y
216
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
217
+ # reduction scale
218
+ if reduction == "mean":
219
+ X_block = X_block / n_non_ignore
220
+ else:
221
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
222
+ softmax_X = tl.exp(X_block - m) / d
223
+ # derivative of original_loss
224
+ dloss_ori = (1 - label_smoothing) * softmax_X
225
+ # specially handle dx_y
226
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
227
+ dloss_ori = dloss_ori * weight_y
228
+ # derivative of smooth_loss
229
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
230
+ # derivative of z-loss
231
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
232
+ # reduction scale
233
+ if reduction == "mean":
234
+ dloss_ori = dloss_ori / sum_non_ignore_weight
235
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
236
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
237
+ dz_loss = dz_loss / n_non_ignore
238
+ # derivative of total_loss
239
+ X_block = dloss_ori + dloss_smooth + dz_loss
240
+
241
+ # chain rule softcapping
242
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
243
+ if HAS_SOFTCAPPING:
244
+ X_block = X_block * (1 - intermediate * intermediate)
245
+
246
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
247
 
248
  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
249
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
 
291
  tl.store(loss_ptr, loss)
292
  if RETURN_Z_LOSS:
293
  tl.store(z_loss_ptr, z_loss)
294
+ if RETURN_TOKEN_ACCURACY:
295
+ # Store 1.0 if prediction is correct, 0.0 otherwise
296
+ is_correct = 1.0 if argmax_idx == y else 0.0
297
+ tl.store(token_accuracy_ptr, is_correct)
298
+ if RETURN_PREDICTED_TOKENS:
299
+ tl.store(predicted_tokens_ptr, argmax_idx)
300
 
301
 
302
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
303
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
304
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
305
+ # the best size we found by manually tuning on xpu and npu.
306
+ if infer_device() == "xpu":
307
+ MAX_FUSED_SIZE = 4096
308
+ elif infer_device() == "npu":
309
+ MAX_FUSED_SIZE = 2048
310
+ else:
311
+ MAX_FUSED_SIZE = 65536 // 2
312
 
313
 
314
  def cross_entropy_forward(
 
321
  reduction,
322
  softcap,
323
  return_z_loss,
324
+ return_token_accuracy=False,
325
+ return_predicted_tokens=False,
326
  ):
327
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
328
+ assert isinstance(return_token_accuracy, bool), (
329
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
330
+ )
331
+ assert isinstance(return_predicted_tokens, bool), (
332
+ f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
333
+ )
334
 
335
  BT, V = _input.shape
336
  n_rows = BT
 
340
  # unreduced loss
341
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
342
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
343
+ token_accuracy_1d = (
344
+ torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
345
+ )
346
+ predicted_tokens_1d = (
347
+ torch.full((n_rows,), -1, dtype=torch.int64, device=_input.device) if return_predicted_tokens else None
348
+ )
349
 
350
  target_mask = target != ignore_index
351
  n_non_ignore = target_mask.sum().item()
 
382
  loss_ptr=loss_1d,
383
  z_loss_ptr=z_loss_1d,
384
  loss_stride=loss_1d.stride(-1), # always 1
385
+ token_accuracy_ptr=token_accuracy_1d,
386
+ token_accuracy_stride=token_accuracy_1d.stride(-1)
387
+ if return_token_accuracy
388
+ else 0, # always 1 if accuracy is enabled
389
+ predicted_tokens_ptr=predicted_tokens_1d,
390
+ predicted_tokens_stride=predicted_tokens_1d.stride(-1)
391
+ if return_predicted_tokens
392
+ else 0, # always 1 if predicted tokens is enabled
393
  n_cols=V,
394
  n_non_ignore=n_non_ignore,
395
  sum_non_ignore_weight=sum_non_ignore_weight,
 
400
  reduction=reduction,
401
  softcap=softcap,
402
  RETURN_Z_LOSS=return_z_loss,
403
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
404
+ RETURN_PREDICTED_TOKENS=return_predicted_tokens,
405
  BLOCK_SIZE=BLOCK_SIZE,
406
  HAS_WEIGHT=True if weight is not None else False,
407
  HAS_SOFTCAPPING=True if softcap is not None else False,
408
+ HAS_GRADIENTS=_input.requires_grad,
409
  # TODO: 32 seems to give the best performance
410
  # Performance is quite sensitive to num_warps
411
  num_warps=32 if not is_hip() else 16,
 
414
  if reduction == "none":
415
  loss = loss_1d
416
  z_loss = z_loss_1d if return_z_loss else None
417
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
418
  else:
419
  loss = torch.sum(loss_1d)
420
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
421
+ # For accuracy, we compute the mean across all non-ignored tokens
422
+ token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
423
 
424
+ predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
425
+
426
+ return loss, z_loss, token_accuracy, predicted_tokens, _input
427
 
428
 
429
  def cross_entropy_backward(_input, grad_output):
 
471
  reduction: str = "mean",
472
  softcap: Optional[float] = None,
473
  return_z_loss: bool = False,
474
+ return_token_accuracy: bool = False,
475
+ return_predicted_tokens: bool = False,
476
  ):
477
  """
478
  The forward pass of the Liger Cross Entropy loss.
 
487
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
488
  reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
489
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
490
+ return_z_loss (bool): When `return_z_loss` is `True`, the z_loss element of the returned tuple holds the z loss; otherwise that element is None. Default: `False`
491
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
492
+ return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
493
 
494
  Returns:
495
+ tuple: A tuple with the computed losses, accuracy, and predicted tokens: (loss, z_loss, token_accuracy, predicted_tokens). z_loss, token_accuracy, and predicted_tokens are None if not requested.
496
  """
497
+ input_requires_grad = _input.requires_grad
498
+
499
+ loss, z_loss, token_accuracy, predicted_tokens, _input = cross_entropy_forward(
500
  _input,
501
  target,
502
  weight,
 
506
  reduction,
507
  softcap,
508
  return_z_loss,
509
+ return_token_accuracy,
510
+ return_predicted_tokens,
511
  )
512
  # TODO: investigation
513
  # If we don't detach the _input tensor, the memory will double
514
  # Not sure why but seems that there will be a time both grad and value exist but in different location
515
+ if input_requires_grad:
516
+ ctx.save_for_backward(_input.detach())
517
  ctx.return_z_loss = return_z_loss
518
+ ctx.return_token_accuracy = return_token_accuracy
519
+ ctx.return_predicted_tokens = return_predicted_tokens
520
 
521
+ return loss, z_loss, token_accuracy, predicted_tokens
522
 
523
  @staticmethod
524
+ def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
525
  """
526
  The backward pass of the Liger Cross Entropy loss.
527
 
528
  Parameters:
529
  ctx : The context object with saved tensors.
530
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
531
+ grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
532
+ grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
533
+ grad_output4 (tensor): No use. Gradient for predicted_tokens (not used as predicted_tokens is only for metrics).
534
  Returns:
535
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
536
  """
537
  if ctx.return_z_loss:
538
+ del grad_output2 # z_loss is only for logging
539
+ if ctx.return_token_accuracy:
540
+ del grad_output3 # token_accuracy is only for metrics
541
+ if ctx.return_predicted_tokens:
542
+ del grad_output4 # predicted_tokens is only for metrics
543
 
544
  (_input,) = ctx.saved_tensors
545
  _input = cross_entropy_backward(_input, grad_output)
 
553
  None,
554
  None,
555
  None,
556
+ None,
557
+ None,
558
+ )
build/torch-cuda/dyt.py CHANGED
@@ -4,12 +4,13 @@ import torch
4
  import triton
5
  import triton.language as tl
6
 
7
- from .utils import calculate_settings
8
  from .utils import compare_version
9
  from .utils import ensure_contiguous
 
10
  from .utils import infer_device
 
11
 
12
- if compare_version("triton", operator.ge, "3.0.0"):
13
  try:
14
  # typical import path with dispatch available
15
  from triton.language.extra.libdevice import tanh
@@ -20,187 +21,131 @@ else:
20
  from triton.language.math import tanh
21
 
22
 
 
 
 
 
 
 
 
 
 
23
  @triton.jit
24
- def _dyt_fwd_kernel(
25
- x_ptr,
26
- x_row_stride,
27
- alpha_ptr,
28
- gamma_ptr,
29
- beta_ptr,
30
- y_ptr,
31
- y_row_stride,
32
- n_cols,
33
- BLOCK_SIZE: tl.constexpr,
34
- ):
35
- """
36
- Reference:
37
- https://arxiv.org/abs/2503.10622
38
-
39
- Shapes:
40
- - x: (BT, C)
41
- - alpha: (1)
42
- - gamma: (C)
43
- - beta: (C)
44
- """
45
- row_idx = tl.program_id(0)
46
- offsets = tl.arange(0, BLOCK_SIZE)
47
- mask = offsets < n_cols
48
-
49
- x_ptr += row_idx * x_row_stride
50
- y_ptr += row_idx * y_row_stride
51
-
52
- alpha = tl.load(alpha_ptr)
53
- gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
- beta = tl.load(beta_ptr + offsets, mask=mask)
55
- x = tl.load(x_ptr + offsets, mask=mask)
56
- y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
- tl.store(y_ptr + offsets, y, mask=mask)
58
-
59
-
60
  @triton.jit
61
  def _dyt_bwd_kernel(
62
- x_ptr,
63
- x_row_stride,
64
- dy_ptr,
65
- dy_row_stride,
66
- dx_ptr,
67
- dx_row_stride,
68
- alpha_ptr,
69
- dalpha_ptr,
70
- gamma_ptr,
71
- dgamma_ptr,
72
- dgamma_row_stride,
73
- n_cols,
74
- n_rows,
75
- ROWS_PER_PROGRAM: tl.constexpr,
76
- BLOCK_SIZE: tl.constexpr,
77
  ):
78
- """
79
- Reference:
80
- https://arxiv.org/abs/2503.10622
81
-
82
- Shapes:
83
- - x: (BT, C)
84
- - alpha: (1)
85
- - gamma: (C)
86
- - dx: (BT, C)
87
- - dy: (BT, C)
88
- - dgamma: (sm_count, C)
89
- - dalpha: (sm_count,)
90
- """
91
- # d(gamma * tanh(alpha * x) + beta) / dx
92
- # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
- # d(gamma * tanh(alpha * x) + beta) / dalpha
94
- # = gamma * (1 - tanh^2(alpha * x)) * x
95
- # d(gamma * tanh(alpha * x) + beta) / dgamma
96
- # = tanh(alpha * x)
97
- # d(gamma * tanh(alpha * x)) / dbeta = 1
98
- pid = tl.program_id(0)
99
-
100
- row_start = pid * ROWS_PER_PROGRAM
101
- row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
- offsets = tl.arange(0, BLOCK_SIZE)
103
- mask = offsets < n_cols
104
-
105
- dalpha = 0.0
106
- dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
-
108
- x_ptr += row_start * x_row_stride
109
- dx_ptr += row_start * dx_row_stride
110
- dy_ptr += row_start * dy_row_stride
111
- alpha = tl.load(alpha_ptr)
112
- gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
-
114
- for _ in tl.range(row_start, row_end):
115
- dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
- tanh_ax = tanh((alpha * x).cast(tl.float32))
118
- sech2_ax = 1 - tanh_ax * tanh_ax
119
-
120
- dx = dy * gamma * sech2_ax * alpha
121
- dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
- dgamma += dy * tanh_ax
123
- tl.store(dx_ptr + offsets, dx, mask=mask)
124
-
125
- dy_ptr += dy_row_stride
126
- x_ptr += x_row_stride
127
- dx_ptr += dx_row_stride
128
-
129
- tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
- tl.store(dalpha_ptr + pid, dalpha)
131
-
132
- pass
133
 
134
 
135
  def liger_dyt_fwd(x, alpha, gamma, beta):
136
- shape = x.shape
137
- dim = shape[-1]
138
- x = x.view(-1, dim)
139
- n_rows, n_cols = x.shape
 
 
140
  y = torch.empty_like(x)
141
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
- _dyt_fwd_kernel[(n_rows,)](
143
- x_ptr=x,
144
- alpha_ptr=alpha,
145
- gamma_ptr=gamma,
146
- beta_ptr=beta,
147
- y_ptr=y,
148
- x_row_stride=x.stride(0),
149
- y_row_stride=y.stride(0),
150
- n_cols=n_cols,
151
- BLOCK_SIZE=BLOCK_SIZE,
152
- num_warps=num_warps,
153
  )
154
- return y.view(*shape)
155
-
156
-
157
- def liger_dyt_bwd(dy, x, alpha, gamma):
158
- shape = dy.shape
159
- dtype = x.dtype
160
- dim = shape[-1]
161
- dy = dy.view(-1, dim)
162
- x = x.view(-1, dim)
163
- n_rows, n_cols = dy.shape
164
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
- sm_count = 1
166
  device = infer_device()
167
  if device == "cuda":
168
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
169
  elif device == "xpu":
170
- sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
- if n_cols > BLOCK_SIZE:
172
- raise RuntimeError(
173
- f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
- )
175
-
176
- dx = torch.empty_like(x, dtype=torch.float32)
177
- _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
- _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
-
180
- grid = (sm_count,)
181
- rows_per_program = triton.cdiv(n_rows, sm_count)
182
- _dyt_bwd_kernel[grid](
183
- x_ptr=x,
184
- x_row_stride=x.stride(0),
185
- dy_ptr=dy,
186
- dy_row_stride=dy.stride(0),
187
- dx_ptr=dx,
188
- dx_row_stride=dx.stride(0),
189
- alpha_ptr=alpha,
190
- dalpha_ptr=_dalpha,
191
- gamma_ptr=gamma,
192
- dgamma_ptr=_dgamma,
193
- dgamma_row_stride=_dgamma.stride(0),
194
- n_cols=n_cols,
195
- n_rows=n_rows,
196
- ROWS_PER_PROGRAM=rows_per_program,
197
- BLOCK_SIZE=BLOCK_SIZE,
198
- num_warps=num_warps,
199
- )
200
- dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
- dgamma = _dgamma.sum(dim=0).to(dtype)
202
- dbeta = dy.sum(dim=0).to(dtype)
203
- return dx.view(*shape), dalpha, dgamma, dbeta
204
 
205
 
206
  class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +153,12 @@ class LigerDyTFunction(torch.autograd.Function):
208
  @ensure_contiguous
209
  def forward(ctx, x, alpha, gamma, beta):
210
  y = liger_dyt_fwd(x, alpha, gamma, beta)
211
- ctx.save_for_backward(x, alpha, gamma)
212
  return y
213
 
214
  @staticmethod
215
  @ensure_contiguous
216
- def backward(ctx, grad_output):
217
- x, alpha, gamma = ctx.saved_tensors
218
- dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
- grad_output,
220
- x,
221
- alpha,
222
- gamma,
223
- )
224
-
225
- return (dx, dalpha, dgamma, dbeta)
 
4
  import triton
5
  import triton.language as tl
6
 
 
7
  from .utils import compare_version
8
  from .utils import ensure_contiguous
9
+ from .utils import get_npu_core_count
10
  from .utils import infer_device
11
+ from .utils import is_npu_available
12
 
13
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
14
  try:
15
  # typical import path with dispatch available
16
  from triton.language.extra.libdevice import tanh
 
21
  from triton.language.math import tanh
22
 
23
 
24
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
        for block_n in [1024, 2048, 4096]
        for stages in [1, 2]
        for warps in [4, 8, 16]
    ],
    key=["N"],
)
@triton.jit
def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    """
    DyT forward: ``y = tanh(alpha * x) * gamma`` plus ``beta`` when HAVE_BETA is set.

    Each program handles one BLOCK_N-wide tile of columns (program_id(0)) of one
    row (program_id(1)). Rows are addressed as ``row * N``, so X and Y are treated
    as densely packed rows of N elements.

    Parameters:
    X: Pointer to the input.
    Y: Pointer to the output.
    Alpha: Pointer to alpha; a single element is loaded from it.
    Gamma: Pointer to the per-column scale.
    Beta: Pointer to the per-column shift. Only dereferenced when HAVE_BETA is true.
    HAVE_BETA: Compile-time switch that enables the beta term.
    N: Number of columns per row (compile-time constant and autotune key).
    BLOCK_N: Column tile width, selected by the autotuner.
    """
    # Indices are widened to int64 before being multiplied.
    row = tl.cast(tl.program_id(1), tl.int64)
    offsets = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = offsets < N
    row_start = row * N

    # All arithmetic is done in fp32 regardless of the storage dtype.
    alpha_val = tl.load(Alpha).to(tl.float32)
    scale = tl.load(Gamma + offsets, mask=in_bounds, other=0.0).to(tl.float32)
    inp = tl.load(X + row_start + offsets, mask=in_bounds, other=0.0).to(tl.float32)

    out = tanh(alpha_val * inp) * scale
    if HAVE_BETA:
        shift = tl.load(Beta + offsets, mask=in_bounds, other=0.0).to(tl.float32)
        out = out + shift
    tl.store(Y + row_start + offsets, out, mask=in_bounds)
53
+
54
+
55
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n}, num_stages=stages, num_warps=warps)
        for block_n in [1024, 2048, 4096]
        for stages in [1, 2]
        for warps in [4, 8, 16]
    ],
    key=["N"],
    # DA slots are addressed with program_id(0), so the number of slots a launch
    # writes depends on the BLOCK_N under trial. Autotune does not clear outputs
    # between trials; without a reset, slots written by an earlier config would
    # still be present when the host sums DA. Zeroing DA per trial isolates them.
    reset_to_zero=["DA"],
)
@triton.jit
def _dyt_bwd_kernel(
    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr
):
    """
    DyT backward: writes dx per element and per-program partial sums for the
    alpha, gamma and (optionally) beta gradients.

    program_id(0) selects a BLOCK_N-wide column tile. program_id(1) selects the
    first row; the program then strides over rows by num_programs(1) until M.
    Partial results are stored at row ``program_id(1)`` of DG / DB / DA and are
    expected to be reduced by the caller.

    Parameters:
    DY: Pointer to the upstream gradient, addressed as rows of N elements.
    DX: Pointer to the input gradient, same addressing as DY.
    DA: Pointer to the alpha partials; row stride is ``cdiv(N, 512)``.
    DG: Pointer to the gamma partials; row stride is N.
    DB: Pointer to the beta partials; only written when HAVE_BETA is true.
    X: Pointer to the forward input, same addressing as DY.
    Alpha: Pointer to alpha; a single element is loaded from it.
    Gamma: Pointer to the per-column scale.
    HAVE_BETA: Compile-time switch that enables the beta gradient.
    M: Number of rows.
    N: Number of columns per row (compile-time constant and autotune key).
    BLOCK_N: Column tile width, selected by the autotuner.
    """
    tile_id = tl.program_id(0)
    cols = tl.cast(tile_id, tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
    valid = cols < N
    first_row = tl.cast(tl.program_id(1), tl.int64)

    alpha_val = tl.load(Alpha).to(tl.float32)
    alpha_acc = 0.0
    scale = tl.load(Gamma + cols, mask=valid, other=0.0).to(tl.float32)
    gamma_acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAVE_BETA:
        beta_acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for row in range(first_row, M, tl.num_programs(1)):
        elem = row * N + cols
        inp = tl.load(X + elem, mask=valid, other=0.0).to(tl.float32)
        grad_out = tl.load(DY + elem, mask=valid, other=0.0).to(tl.float32)
        act = tanh(alpha_val * inp)
        if HAVE_BETA:
            beta_acc += grad_out
        gamma_acc += grad_out * act
        # d/dz tanh(z) = 1 - tanh(z)^2, chained with the upstream grad and gamma.
        chain = (1 - act * act) * grad_out * scale
        alpha_acc += tl.sum(inp * chain, 0)
        tl.store(DX + elem, alpha_val * chain, mask=valid)

    partial = first_row * N + cols
    tl.store(DG + partial, gamma_acc, mask=valid)
    if HAVE_BETA:
        tl.store(DB + partial, beta_acc, mask=valid)
    # One scalar per (row-program, column tile).
    tl.store(DA + first_row * tl.cdiv(N, 512) + tile_id, alpha_acc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
def liger_dyt_fwd(x, alpha, gamma, beta):
    """
    Launch the DyT forward kernel and return the result in the shape of ``x``.

    Args:
        x: Input tensor; must be contiguous. It is flattened to 2D over its last dim.
        alpha: Alpha tensor passed straight to the kernel.
        gamma: Per-feature scale passed straight to the kernel.
        beta: Per-feature shift, or None to skip the additive term.

    Returns:
        A tensor with the same shape as ``x`` (allocated with ``torch.empty_like``).
    """
    assert x.is_contiguous()
    orig_shape = x.shape
    x = x.view(-1, orig_shape[-1])
    n_rows, n_cols = x.shape
    has_beta = beta is not None

    y = torch.empty_like(x)

    def grid(meta):
        # One program per (column tile, row); the tile width comes from autotune.
        # NOTE(review): rows land on grid axis 1, which CUDA caps at 65535 blocks —
        # confirm callers never flatten to more rows than that.
        return triton.cdiv(n_cols, meta["BLOCK_N"]), n_rows

    _dyt_fwd_kernel[grid](x, y, alpha, gamma, beta, has_beta, n_cols)
    return y.view(orig_shape)
121
+
122
+
123
def liger_dyt_bwd(dy, x, alpha, gamma, beta):
    """
    Launch the DyT backward kernel and reduce its per-program partial sums.

    Args:
        dy: Upstream gradient; must be contiguous.
        x: Forward input. It is flattened to 2D over its last dim.
        alpha: Alpha tensor used in the forward pass.
        gamma: Per-feature scale used in the forward pass.
        beta: Per-feature shift used in the forward pass, or None.

    Returns:
        Tuple ``(dx, dalpha, dgamma, dbeta)``. ``dx`` has the shape of ``x``,
        ``dalpha`` has shape ``(1,)``, and ``dbeta`` is None when ``beta`` is None.
        Each parameter gradient is cast to the dtype of its own parameter.
    """
    assert dy.is_contiguous()
    input_shape = x.shape
    x = x.view(-1, input_shape[-1])
    M, N = x.shape
    HAVE_BETA = beta is not None

    # Fall back to a single row-program when the backend is not one of the
    # devices queried below; previously NUM_SMS was left unbound in that case.
    NUM_SMS = 1
    device = infer_device()
    if device == "cuda":
        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    elif device == "npu":
        NUM_SMS = get_npu_core_count()

    # The kernel addresses DA rows with a stride of cdiv(N, 512); this width must
    # stay in sync with it. DA is zero-initialized because only the slots that
    # correspond to launched column tiles are written, and all of it is summed.
    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
    dx = torch.empty_like(dy)

    # Axis 0: column tiles (width chosen by autotune). Axis 1: NUM_SMS programs
    # that stride over the M rows.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N)

    # Reduce the fp32 partials, then cast each gradient to its parameter's dtype
    # so higher-precision parameters are not rounded through the activation dtype.
    if HAVE_BETA:
        db = db.sum(0).to(beta.dtype)
    dg = dg.sum(0).to(gamma.dtype)
    da = da.sum().to(alpha.dtype).unsqueeze(0)
    return dx.view(input_shape), da, dg, db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  class LigerDyTFunction(torch.autograd.Function):
 
153
  @ensure_contiguous
154
  def forward(ctx, x, alpha, gamma, beta):
155
  y = liger_dyt_fwd(x, alpha, gamma, beta)
156
+ ctx.save_for_backward(x, alpha, gamma, beta)
157
  return y
158
 
159
  @staticmethod
160
  @ensure_contiguous
161
+ def backward(ctx, dy):
162
+ x, alpha, gamma, beta = ctx.saved_tensors
163
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
164
+ return dx, dalpha, dgamma, dbeta
 
 
 
 
 
 
build/torch-cuda/fused_linear_cross_entropy.py CHANGED
@@ -6,11 +6,12 @@ from .utils import amp_custom_bwd
6
  from .utils import amp_custom_fwd
7
  from .utils import element_mul_kernel
8
  from .utils import is_hip
 
9
 
10
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
11
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
12
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
13
- MAX_FUSED_SIZE = 65536 // 2
14
 
15
 
16
  def fused_linear_cross_entropy_forward(
@@ -25,10 +26,22 @@ def fused_linear_cross_entropy_forward(
25
  reduction="mean",
26
  softcap=None,
27
  return_z_loss=False,
 
 
 
 
28
  ):
29
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
 
 
 
 
 
 
30
  device = _input.device
31
 
 
 
32
  # inputs have shape: BT x H
33
  # materialized activations will have shape: BT x V
34
  # the increase in memory = BT x V
@@ -44,12 +57,24 @@ def fused_linear_cross_entropy_forward(
44
  chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
45
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
46
 
47
- grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
48
  grad_input = torch.zeros_like(_input, device=device)
49
- grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
50
- # we use fp32 for loss accumulator
 
 
 
 
 
 
 
 
 
 
 
51
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
52
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
 
 
53
 
54
  # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
55
  target_mask = target != ignore_index
@@ -82,9 +107,41 @@ def fused_linear_cross_entropy_forward(
82
 
83
  n_rows = logits_chunk.shape[0]
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # unreduced loss
86
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
87
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
 
 
88
 
89
  # ensure _input and target are contiguous
90
  logits_chunk = logits_chunk.contiguous()
@@ -100,6 +157,14 @@ def fused_linear_cross_entropy_forward(
100
  loss_ptr=loss_1d_slice,
101
  z_loss_ptr=z_loss_1d_slice,
102
  loss_stride=loss_1d_slice.stride(-1), # always 1
 
 
 
 
 
 
 
 
103
  n_cols=V,
104
  n_non_ignore=total_n_non_ignore,
105
  sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -110,35 +175,46 @@ def fused_linear_cross_entropy_forward(
110
  reduction=reduction,
111
  softcap=softcap,
112
  RETURN_Z_LOSS=return_z_loss,
 
 
113
  HAS_WEIGHT=True if ce_weight is not None else False,
114
  HAS_SOFTCAPPING=True if softcap is not None else False,
 
115
  BLOCK_SIZE=BLOCK_SIZE,
116
  num_warps=32 if not is_hip() else 16,
117
  )
118
 
 
 
 
 
 
 
119
  loss_1d[start_idx:end_idx] = loss_1d_slice
120
  if return_z_loss:
121
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
 
 
 
 
122
  grad_logits_chunk = logits_chunk # chunk_size x V
123
 
124
- grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
 
 
 
125
 
126
- if grad_weight is not None:
127
- torch.addmm(
128
- input=grad_weight,
129
- mat1=logits_chunk.t().to(
130
- _input_chunk.dtype
131
- ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
132
- mat2=_input_chunk,
133
- out=grad_weight,
134
- alpha=1.0,
135
- beta=1.0,
136
- )
137
 
138
- if bias is not None:
 
 
 
139
  torch.add(
140
  input=grad_bias,
141
- other=logits_chunk.sum(dim=0),
142
  out=grad_bias,
143
  alpha=1.0,
144
  )
@@ -148,10 +224,24 @@ def fused_linear_cross_entropy_forward(
148
  # loss = loss_1d
149
  # z_loss = z_loss_1d if return_z_loss else None
150
 
 
 
 
 
 
151
  else:
152
  loss = torch.sum(loss_1d)
153
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
154
- return loss, z_loss, grad_input, grad_weight, grad_bias
 
 
 
 
 
 
 
 
 
155
 
156
 
157
  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -217,6 +307,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
217
  reduction="mean",
218
  softcap=None,
219
  return_z_loss: bool = False,
 
 
 
 
220
  ):
221
  """
222
  Fusing the last linear layer with cross-entropy loss
@@ -235,35 +329,54 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
235
  ignore_index: the index to ignore in the target
236
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
237
  reduction: reduction to apply
 
 
 
 
 
 
 
238
  """
239
 
240
- loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
241
- _input=_input,
242
- weight=weight,
243
- target=target,
244
- bias=bias,
245
- ce_weight=ce_weight,
246
- ignore_index=ignore_index,
247
- lse_square_scale=lse_square_scale,
248
- label_smoothing=label_smoothing,
249
- reduction=reduction,
250
- softcap=softcap,
251
- return_z_loss=return_z_loss,
 
 
 
 
 
 
252
  )
253
  # downcast to dtype and store for backward
254
  ctx.save_for_backward(
255
  grad_input.detach(),
256
  grad_weight.detach() if grad_weight is not None else None,
257
- grad_bias.detach() if bias is not None else None,
258
  )
259
  ctx.return_z_loss = return_z_loss
260
- return loss, z_loss
 
 
261
 
262
  @staticmethod
263
  @amp_custom_bwd
264
- def backward(ctx, grad_output, grad_output2):
265
  if ctx.return_z_loss:
266
  del grad_output2 # z_loss is only for logging
 
 
 
 
267
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
268
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
269
  grad_output, grad_input, grad_weight, grad_bias
@@ -280,4 +393,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
280
  None,
281
  None,
282
  None,
283
- )
 
 
 
 
 
6
  from .utils import amp_custom_fwd
7
  from .utils import element_mul_kernel
8
  from .utils import is_hip
9
+ from .utils import infer_device
10
 
11
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
12
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
13
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
14
+ MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
15
 
16
 
17
  def fused_linear_cross_entropy_forward(
 
26
  reduction="mean",
27
  softcap=None,
28
  return_z_loss=False,
29
+ accum_dtype=None,
30
+ use_token_scaling=False,
31
+ return_token_accuracy=False,
32
+ return_predicted_tokens=False,
33
  ):
34
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
35
+ assert isinstance(return_token_accuracy, bool), (
36
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
37
+ )
38
+ assert isinstance(return_predicted_tokens, bool), (
39
+ f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
40
+ )
41
  device = _input.device
42
 
43
+ input_requires_grad = _input.requires_grad
44
+
45
  # inputs have shape: BT x H
46
  # materialized activations will have shape: BT x V
47
  # the increase in memory = BT x V
 
57
  chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
58
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
59
 
 
60
  grad_input = torch.zeros_like(_input, device=device)
61
+
62
+ # we use fp32 for loss and gradients accumulator
63
+ if input_requires_grad:
64
+ if accum_dtype is None:
65
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
66
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
67
+ else:
68
+ grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
69
+ grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
70
+ else:
71
+ grad_weight = None
72
+ grad_bias = None
73
+
74
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
75
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
76
+ token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
77
+ predicted_tokens_1d = torch.full((BT,), -1, dtype=torch.int64, device=device) if return_predicted_tokens else None
78
 
79
  # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
80
  target_mask = target != ignore_index
 
107
 
108
  n_rows = logits_chunk.shape[0]
109
 
110
+ # Compute predicted probabilities for token scaling if needed
111
+ if use_token_scaling:
112
+ # Compute softmax probabilities for scaling
113
+ # We need to compute this before the cross entropy kernel modifies logits_chunk
114
+ logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
115
+ if softcap is not None:
116
+ logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
117
+
118
+ # Compute softmax to get predicted probabilities
119
+ probs = torch.softmax(logits_for_softmax, dim=-1)
120
+
121
+ # Get predicted probabilities for token scaling, handling ignored targets
122
+ valid_target_mask = target_chunk != ignore_index
123
+ valid_targets = target_chunk[valid_target_mask]
124
+
125
+ if len(valid_targets) > 0:
126
+ # Gather probabilities only for valid targets
127
+ valid_probs = probs[valid_target_mask]
128
+ pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
129
+
130
+ # Create full tensor with zeros for ignored targets
131
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
132
+ pred_probs[valid_target_mask] = pred_probs_valid
133
+ else:
134
+ # All targets are ignored
135
+ pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
136
+
137
+ # Store the scaling factors
138
+ scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
139
+
140
  # unreduced loss
141
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
142
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
143
+ token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
144
+ predicted_tokens_1d_slice = predicted_tokens_1d[start_idx:end_idx] if return_predicted_tokens else None
145
 
146
  # ensure _input and target are contiguous
147
  logits_chunk = logits_chunk.contiguous()
 
157
  loss_ptr=loss_1d_slice,
158
  z_loss_ptr=z_loss_1d_slice,
159
  loss_stride=loss_1d_slice.stride(-1), # always 1
160
+ token_accuracy_ptr=token_accuracy_1d_slice,
161
+ token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
162
+ if return_token_accuracy
163
+ else 0, # always 1 if accuracy is enabled
164
+ predicted_tokens_ptr=predicted_tokens_1d_slice,
165
+ predicted_tokens_stride=predicted_tokens_1d_slice.stride(-1)
166
+ if return_predicted_tokens
167
+ else 0, # always 1 if predicted tokens is enabled
168
  n_cols=V,
169
  n_non_ignore=total_n_non_ignore,
170
  sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
 
175
  reduction=reduction,
176
  softcap=softcap,
177
  RETURN_Z_LOSS=return_z_loss,
178
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
179
+ RETURN_PREDICTED_TOKENS=return_predicted_tokens,
180
  HAS_WEIGHT=True if ce_weight is not None else False,
181
  HAS_SOFTCAPPING=True if softcap is not None else False,
182
+ HAS_GRADIENTS=input_requires_grad,
183
  BLOCK_SIZE=BLOCK_SIZE,
184
  num_warps=32 if not is_hip() else 16,
185
  )
186
 
187
+ # Apply token scaling if requested
188
+ if use_token_scaling:
189
+ loss_1d_slice = loss_1d_slice * scaling_factors
190
+ if return_z_loss:
191
+ z_loss_1d_slice = z_loss_1d_slice * scaling_factors
192
+
193
  loss_1d[start_idx:end_idx] = loss_1d_slice
194
  if return_z_loss:
195
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
196
+ if return_token_accuracy:
197
+ token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
198
+ if return_predicted_tokens:
199
+ predicted_tokens_1d[start_idx:end_idx] = predicted_tokens_1d_slice
200
  grad_logits_chunk = logits_chunk # chunk_size x V
201
 
202
+ # Apply token scaling to gradients if requested
203
+ if use_token_scaling:
204
+ # Expand scaling factors to match gradient dimensions
205
+ scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
206
+ grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
207
 
208
+ if input_requires_grad:
209
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
 
 
 
 
 
 
 
 
210
 
211
+ if grad_weight is not None and input_requires_grad:
212
+ grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
213
+
214
+ if bias is not None and input_requires_grad:
215
  torch.add(
216
  input=grad_bias,
217
+ other=grad_logits_chunk.sum(dim=0),
218
  out=grad_bias,
219
  alpha=1.0,
220
  )
 
224
  # loss = loss_1d
225
  # z_loss = z_loss_1d if return_z_loss else None
226
 
227
+ if reduction == "none":
228
+ # Return per-token losses
229
+ loss = loss_1d
230
+ z_loss = z_loss_1d if return_z_loss else None
231
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
232
  else:
233
  loss = torch.sum(loss_1d)
234
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
235
+ # For accuracy, we compute the mean across all non-ignored tokens
236
+ token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
237
+
238
+ predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
239
+
240
+ # Cast back to original dtype
241
+ grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
242
+ grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
243
+
244
+ return loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias
245
 
246
 
247
  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
 
307
  reduction="mean",
308
  softcap=None,
309
  return_z_loss: bool = False,
310
+ accum_dtype=None,
311
+ use_token_scaling: bool = False,
312
+ return_token_accuracy: bool = False,
313
+ return_predicted_tokens: bool = False,
314
  ):
315
  """
316
  Fusing the last linear layer with cross-entropy loss
 
329
  ignore_index: the index to ignore in the target
330
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
331
  reduction: reduction to apply
332
+ accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
333
+ Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
334
+ use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
335
+ When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
336
+ Default: False.
337
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
338
+ return_predicted_tokens (bool): When `return_predicted_tokens` is `True`, returns per-token predicted class indices (argmax) without materializing logits. Default: `False`
339
  """
340
 
341
+ loss, z_loss, token_accuracy, predicted_tokens, grad_input, grad_weight, grad_bias = (
342
+ fused_linear_cross_entropy_forward(
343
+ _input=_input,
344
+ weight=weight,
345
+ target=target,
346
+ bias=bias,
347
+ ce_weight=ce_weight,
348
+ ignore_index=ignore_index,
349
+ lse_square_scale=lse_square_scale,
350
+ label_smoothing=label_smoothing,
351
+ reduction=reduction,
352
+ softcap=softcap,
353
+ return_z_loss=return_z_loss,
354
+ accum_dtype=accum_dtype,
355
+ use_token_scaling=use_token_scaling,
356
+ return_token_accuracy=return_token_accuracy,
357
+ return_predicted_tokens=return_predicted_tokens,
358
+ )
359
  )
360
  # downcast to dtype and store for backward
361
  ctx.save_for_backward(
362
  grad_input.detach(),
363
  grad_weight.detach() if grad_weight is not None else None,
364
+ grad_bias.detach() if grad_bias is not None else None,
365
  )
366
  ctx.return_z_loss = return_z_loss
367
+ ctx.return_token_accuracy = return_token_accuracy
368
+ ctx.return_predicted_tokens = return_predicted_tokens
369
+ return loss, z_loss, token_accuracy, predicted_tokens
370
 
371
  @staticmethod
372
  @amp_custom_bwd
373
+ def backward(ctx, grad_output, grad_output2, grad_output3, grad_output4):
374
  if ctx.return_z_loss:
375
  del grad_output2 # z_loss is only for logging
376
+ if ctx.return_token_accuracy:
377
+ del grad_output3 # token_accuracy is only for metrics
378
+ if ctx.return_predicted_tokens:
379
+ del grad_output4 # predicted_tokens is only for metrics
380
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
381
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
382
  grad_output, grad_input, grad_weight, grad_bias
 
393
  None,
394
  None,
395
  None,
396
+ None,
397
+ None, # use_token_scaling
398
+ None, # return_token_accuracy
399
+ None, # return_predicted_tokens
400
+ )
build/torch-cuda/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
7
  from .utils import calculate_settings
8
  from .utils import compare_version
9
  from .utils import ensure_contiguous
 
10
 
11
- if compare_version("triton", operator.ge, "3.0.0"):
12
  try:
13
  # typical import path with dispatch available
14
  from triton.language.extra.libdevice import tanh
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
40
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
  tanh_result = tanh(tanh_arg)
42
  geglu_a = 0.5 * a_row * (1 + tanh_result)
43
- c_row = geglu_a * b_row
44
  tl.store(c + col_offsets, c_row, mask=mask)
45
 
46
 
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
66
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
67
  tanh_result = tanh(tanh_arg)
68
  geglu_a = 0.5 * a_row * (1 + tanh_result)
 
69
 
70
- db_row = dc_row * geglu_a
71
 
72
  # Gradient w.r.t. a can be computed with:
73
  # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
78
  da_row = dc_row * b_row * (term1 + term2)
79
 
80
  tl.store(a + col_offsets, da_row, mask=mask)
81
- tl.store(b + col_offsets, db_row, mask=mask)
82
 
83
 
84
  def geglu_forward(a, b):
@@ -138,4 +140,4 @@ class LigerGELUMulFunction(torch.autograd.Function):
138
  def backward(ctx, dc):
139
  a, b = ctx.saved_tensors
140
  a, b = geglu_backward(a, b, dc)
141
- return a, b
 
7
  from .utils import calculate_settings
8
  from .utils import compare_version
9
  from .utils import ensure_contiguous
10
+ from .utils import is_npu_available
11
 
12
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
13
  try:
14
  # typical import path with dispatch available
15
  from triton.language.extra.libdevice import tanh
 
41
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
42
  tanh_result = tanh(tanh_arg)
43
  geglu_a = 0.5 * a_row * (1 + tanh_result)
44
+ c_row = geglu_a.cast(b_row.dtype) * b_row
45
  tl.store(c + col_offsets, c_row, mask=mask)
46
 
47
 
 
67
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
68
  tanh_result = tanh(tanh_arg)
69
  geglu_a = 0.5 * a_row * (1 + tanh_result)
70
+ geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
71
 
72
+ db_row = dc_row.cast(tl.float32) * geglu_a
73
 
74
  # Gradient w.r.t. a can be computed with:
75
  # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
 
80
  da_row = dc_row * b_row * (term1 + term2)
81
 
82
  tl.store(a + col_offsets, da_row, mask=mask)
83
+ tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
84
 
85
 
86
  def geglu_forward(a, b):
 
140
  def backward(ctx, dc):
141
  a, b = ctx.saved_tensors
142
  a, b = geglu_backward(a, b, dc)
143
+ return a, b
build/torch-cuda/group_norm.py CHANGED
@@ -6,8 +6,10 @@ import triton.language as tl
6
 
7
  from .utils import compare_version
8
  from .utils import ensure_contiguous
 
 
9
 
10
- if compare_version("triton", operator.ge, "3.0.0"):
11
  try:
12
  # typical import path with dispatch available
13
  from triton.language.extra.libdevice import rsqrt
@@ -17,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0"):
17
  else:
18
  from triton.language.math import rsqrt
19
 
20
- MAX_FUSED_SIZE = 65536
 
 
 
21
 
22
 
23
  @triton.jit
@@ -72,20 +77,21 @@ def _group_norm_forward_kernel(
72
  # 1/std
73
  rstd = rsqrt(variance + eps)
74
 
75
- # Normalize
 
 
76
  hidden_size_per_channel = hidden_size // channels_per_group
77
- for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
78
- W = tl.load(W_ptr + channel_idx)
79
- B = tl.load(B_ptr + channel_idx)
80
- for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
81
- hidden_size_offsets = i + block_range
82
- mask = hidden_size_offsets < hidden_size_per_channel
83
- X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
84
- Y = (X - m) * rstd * W + B
85
- tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
86
-
87
- X_ptr += hidden_size_per_channel
88
- Y_ptr += hidden_size_per_channel
89
 
90
  tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
91
  tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
@@ -302,4 +308,4 @@ class LigerGroupNormFunction(torch.autograd.Function):
302
  def backward(ctx, dY):
303
  X, W, B, Mean, RSTD = ctx.saved_tensors
304
  DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
305
- return DX, DW, DB, None, None, None
 
6
 
7
  from .utils import compare_version
8
  from .utils import ensure_contiguous
9
+ from .utils import infer_device
10
+ from .utils import is_npu_available
11
 
12
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
13
  try:
14
  # typical import path with dispatch available
15
  from triton.language.extra.libdevice import rsqrt
 
19
  else:
20
  from triton.language.math import rsqrt
21
 
22
# Module-level size cap: 16384 on NPU, 65536 on every other device.
MAX_FUSED_SIZE = 16384 if infer_device() == "npu" else 65536
26
 
27
 
28
  @triton.jit
 
77
  # 1/std
78
  rstd = rsqrt(variance + eps)
79
 
80
+ # Normalize — flat loop over full hidden_size (not per-channel)
81
+ # This avoids the nested channel × per_channel_hidden loop where
82
+ # BLOCK_SIZE >> hidden_size_per_channel causes massive padding waste.
83
  hidden_size_per_channel = hidden_size // channels_per_group
84
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
85
+ hidden_size_offsets = i + block_range
86
+ mask = hidden_size_offsets < hidden_size
87
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
88
+ # Determine which channel each element belongs to, then load W/B
89
+ local_channel = hidden_size_offsets // hidden_size_per_channel
90
+ global_channel = group_idx * channels_per_group + local_channel
91
+ W = tl.load(W_ptr + global_channel, mask=mask)
92
+ B = tl.load(B_ptr + global_channel, mask=mask)
93
+ Y = (X - m) * rstd * W + B
94
+ tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
 
95
 
96
  tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
97
  tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
 
308
  def backward(ctx, dY):
309
  X, W, B, Mean, RSTD = ctx.saved_tensors
310
  DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
311
+ return DX, DW, DB, None, None, None
build/torch-cuda/jsd.py CHANGED
@@ -198,4 +198,4 @@ class LigerJSDFunction(torch.autograd.Function):
198
  None,
199
  None,
200
  None,
201
- )
 
198
  None,
199
  None,
200
  None,
201
+ )
build/torch-cuda/kl_div.py CHANGED
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
21
  return num_warps
22
 
23
 
24
- MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
 
 
 
 
 
25
 
26
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
27
 
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
116
 
117
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
118
  BT, V = y_pred.shape
119
- BLOCK_SIZE = (
120
- min(8192, triton.next_power_of_2(V))
121
- if infer_device() == "xpu"
122
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
- )
124
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
125
 
126
  grid = (BT,)
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
159
 
160
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
161
  BT, V = target.shape
162
- BLOCK_SIZE = (
163
- min(8192, triton.next_power_of_2(V))
164
- if infer_device() == "xpu"
165
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
- )
167
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
168
 
169
  grid = (BT,)
@@ -259,4 +256,4 @@ class LigerKLDivLossFunction(torch.autograd.Function):
259
  None,
260
  None,
261
  None,
262
- )
 
21
  return num_warps
22
 
23
 
24
+ if infer_device() == "xpu":
25
+ MAX_FUSED_SIZE = 8192
26
+ elif infer_device() == "npu":
27
+ MAX_FUSED_SIZE = 8192
28
+ else:
29
+ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
30
 
31
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
32
 
 
121
 
122
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
123
  BT, V = y_pred.shape
124
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
 
 
 
125
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
126
 
127
  grid = (BT,)
 
160
 
161
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
162
  BT, V = target.shape
163
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
 
 
 
164
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
165
 
166
  grid = (BT,)
 
256
  None,
257
  None,
258
  None,
259
+ )
build/torch-cuda/layer_norm.py CHANGED
@@ -8,8 +8,11 @@ import triton.language as tl
8
  from .utils import calculate_settings
9
  from .utils import compare_version
10
  from .utils import ensure_contiguous
 
 
 
11
 
12
- if compare_version("triton", operator.ge, "3.0.0"):
13
  try:
14
  # typical import path with dispatch available
15
  from triton.language.extra.libdevice import rsqrt
@@ -43,111 +46,151 @@ def _layer_norm_forward_kernel(
43
  https://arxiv.org/abs/1607.06450
44
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
45
  """
46
- row_idx = tl.program_id(0)
47
  col_offsets = tl.arange(0, BLOCK_SIZE)
48
  mask = col_offsets < n_cols
49
 
50
- Y_ptr += row_idx * Y_row_stride
51
- X_ptr += row_idx * X_row_stride
52
- Mean_ptr += row_idx * Mean_row_stride
53
- RSTD_ptr += row_idx * RSTD_row_stride
54
-
55
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
56
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
57
- B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
-
59
- mean = tl.sum(X_row, axis=0) / n_cols
60
- Xmm = tl.where(mask, X_row - mean, 0)
61
- var = tl.sum(Xmm * Xmm, axis=0) / n_cols
 
 
 
 
 
 
 
 
 
 
62
  rstd = rsqrt(var + eps)
63
 
64
- tl.store(Mean_ptr, mean)
65
- tl.store(RSTD_ptr, rstd)
 
66
 
67
- Y_row = Xmm * rstd * W_row + B_row
 
 
68
 
69
- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
70
 
71
 
72
  @triton.jit
73
  def _layer_norm_backward_kernel(
74
  X_ptr, # pointer to input, shape (n_rows, n_cols)
 
75
  W_ptr, # pointer to weights, shape (n_cols,)
76
  Mean_ptr, # pointer to mean, shape (n_rows,)
 
77
  RSTD_ptr, # pointer to rstd, shape (n_rows,)
 
78
  DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
79
- DW_ptr, # pointer to weights grad, shape (n_cols,)
80
- DB_ptr, # pointer to bias grad, shape (n_cols,)
81
- DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
82
- stride_x, # stride of each row in input
83
  stride_dx, # stride of each row in input grad
 
84
  stride_dw, # stride of each row in weights grad
 
85
  stride_db, # stride of each row in bias grad
 
86
  stride_dy, # stride of each row in output grad
87
  n_rows,
88
  n_cols,
89
  rows_per_program: tl.constexpr,
90
  BLOCK_SIZE: tl.constexpr,
91
- dtype: tl.constexpr,
92
  ):
93
  """
94
  References:
95
  https://arxiv.org/abs/1607.06450
96
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
97
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
98
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
99
  """
100
- row_block_id = tl.program_id(0)
101
  row_start = row_block_id * rows_per_program
102
  row_end = min((row_block_id + 1) * rows_per_program, n_rows)
103
  cols = tl.arange(0, BLOCK_SIZE)
104
  mask = cols < n_cols
105
 
106
- dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
  db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
108
 
109
- X_ptr += row_start * stride_x
110
- Mean_ptr += row_start
111
- RSTD_ptr += row_start
112
- DX_ptr += row_start * stride_dx
113
- DY_ptr += row_start * stride_dy
114
-
115
- for _ in range(row_start, row_end):
116
- x = tl.load(X_ptr + cols, mask=mask, other=0.0)
117
- w = tl.load(W_ptr + cols, mask=mask, other=0.0)
118
- dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
119
- mean = tl.load(Mean_ptr)
120
- rstd = tl.load(RSTD_ptr)
121
-
122
- x_hat = (x - mean) * rstd
123
- wdy = w * dy
 
 
 
 
 
 
 
 
 
 
 
 
124
  c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
125
  c2 = tl.sum(wdy, axis=0) / n_cols
126
- dx = (wdy - (x_hat * c1 + c2)) * rstd
127
- tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
128
 
129
- dw_row += dy * x_hat
130
- db_row += dy
131
 
132
- X_ptr += stride_x
133
- Mean_ptr += 1
134
- RSTD_ptr += 1
135
- DX_ptr += stride_dx
136
- DY_ptr += stride_dy
137
 
138
- tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
139
- tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
140
 
141
 
142
  def layer_norm_forward(X, W, B, eps):
 
 
 
 
 
 
 
 
 
 
143
  shape = X.shape
144
  dim = shape[-1]
145
  X = X.view(-1, dim)
146
  n_rows, n_cols = X.shape
 
 
147
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
 
148
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
149
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
 
 
151
  if X.shape[1] != W.shape[0]:
152
  raise ValueError(
153
  f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -157,9 +200,11 @@ def layer_norm_forward(X, W, B, eps):
157
  # XPU-specific optimization
158
  kernel_args = {}
159
  if X.device.type == "xpu":
160
- kernel_args["grf_mode"] = "large"
161
 
162
- _layer_norm_forward_kernel[(n_rows,)](
 
 
163
  Y,
164
  Y.stride(0),
165
  X,
@@ -176,12 +221,25 @@ def layer_norm_forward(X, W, B, eps):
176
  eps,
177
  BLOCK_SIZE=BLOCK_SIZE,
178
  num_warps=num_warps,
179
- **kernel_args, # XPU-specific optimization
180
  )
 
181
  return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
182
 
183
 
184
  def layer_norm_backward(dY, X, W, B, Mean, RSTD):
 
 
 
 
 
 
 
 
 
 
 
 
185
  shape = dY.shape
186
  dim = shape[-1]
187
  dY = dY.view(-1, dim)
@@ -192,60 +250,57 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
192
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
193
  elif X.device.type == "xpu":
194
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
 
195
 
196
- DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
197
- _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
198
- _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
199
 
 
200
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
201
  if n_cols > BLOCK_SIZE:
202
- raise RuntimeError(
203
- f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
204
- )
205
-
206
  rows_per_program = math.ceil(n_rows / sm_count)
207
  grid = (sm_count,)
208
- triton_dtype = (
209
- tl.float32
210
- if X.dtype == torch.float32
211
- else tl.bfloat16
212
- if X.dtype == torch.bfloat16
213
- else tl.float16
214
- if X.dtype == torch.float16
215
- else tl.float32 # fallback to float32 for other types
216
- )
217
 
 
 
 
 
218
  # XPU-specific optimization
219
- kernel_args = {}
220
  if X.device.type == "xpu":
221
- kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
222
 
 
223
  _layer_norm_backward_kernel[grid](
224
  X,
 
225
  W,
226
  Mean,
 
227
  RSTD,
 
228
  DX,
229
- _DW,
230
- _DB,
231
- dY,
232
- X.stride(0),
233
  DX.stride(0),
 
234
  _DW.stride(0),
 
235
  _DB.stride(0),
 
236
  dY.stride(0),
237
  n_rows,
238
  n_cols,
239
- rows_per_program,
240
  BLOCK_SIZE=BLOCK_SIZE,
241
- dtype=triton_dtype,
242
- **kernel_args, # XPU-specific optimization
243
  )
244
 
 
245
  DW = _DW.sum(dim=0).to(W.dtype)
246
- DB = _DB.sum(dim=0).to(W.dtype)
247
 
248
- DX = DX.view(*shape)
249
  return DX, DW, DB
250
 
251
 
@@ -262,4 +317,4 @@ class LigerLayerNormFunction(torch.autograd.Function):
262
  def backward(ctx, dY):
263
  X, W, B, Mean, RSTD = ctx.saved_tensors
264
  DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
265
- return DX, DW, DB, None
 
8
  from .utils import calculate_settings
9
  from .utils import compare_version
10
  from .utils import ensure_contiguous
11
+ from .utils import get_npu_core_count
12
+ from .utils import set_large_grf_mode
13
+ from .utils import is_npu_available
14
 
15
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
16
  try:
17
  # typical import path with dispatch available
18
  from triton.language.extra.libdevice import rsqrt
 
46
  https://arxiv.org/abs/1607.06450
47
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
48
  """
49
+ row_idx = tl.program_id(0).to(tl.int64)
50
  col_offsets = tl.arange(0, BLOCK_SIZE)
51
  mask = col_offsets < n_cols
52
 
53
+ # Pre-load weights and bias in fp32 to avoid repeated conversions
54
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
55
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
56
+ W_f32 = W_row.to(tl.float32)
57
+ B_f32 = B_row.to(tl.float32)
58
+
59
+ # Calculate pointers for this row
60
+ row_X_ptr = X_ptr + row_idx * X_row_stride
61
+ row_Y_ptr = Y_ptr + row_idx * Y_row_stride
62
+ row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
63
+ row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
64
+
65
+ # Load input data and convert to fp32 for numerical stability
66
+ X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
67
+ X_f32 = X_row.to(tl.float32)
68
+
69
+ # Compute statistics in fp32 for numerical stability
70
+ mean = tl.sum(X_f32, axis=0) / n_cols
71
+ X_centered = X_f32 - mean
72
+ # Apply mask to variance calculation to exclude contributions from masked elements
73
+ X_centered_masked = tl.where(mask, X_centered, 0.0)
74
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
75
  rstd = rsqrt(var + eps)
76
 
77
+ # Store statistics (convert back to original dtype only once)
78
+ tl.store(row_Mean_ptr, mean.to(X_row.dtype))
79
+ tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
80
 
81
+ # Fused normalization and affine transformation
82
+ # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
83
+ Y_f32 = X_centered * rstd * W_f32 + B_f32
84
 
85
+ # Store output (single conversion back to original dtype)
86
+ tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
87
 
88
 
89
@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    stride_mean,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    stride_rstd,  # stride of each row in rstd
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    stride_dx,  # stride of each row in input grad
    DW_ptr,  # pointer to partial weights grad; this program writes the row at index program_id (n_cols values)
    stride_dw,  # stride of each row in weights grad
    DB_ptr,  # pointer to partial bias grad; this program writes the row at index program_id (n_cols values)
    stride_db,  # stride of each row in bias grad
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Layer-norm backward pass for a contiguous chunk of rows.

    Each program handles rows [program_id * rows_per_program,
    min((program_id + 1) * rows_per_program, n_rows)). For every row it writes
    the input gradient to DX, and it accumulates that chunk's weight/bias
    gradient contributions in fp32. The accumulated partials are stored once,
    after the loop, into row `program_id` of DW / DB.

    A whole row is processed in a single block of BLOCK_SIZE lanes; lanes with
    index >= n_cols are masked out.

    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    # Cast to int64 so the row-offset products below are computed in 64-bit.
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    # Clamp the end so the last program does not run past n_rows.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # fp32 accumulators for this program's partial weight / bias gradients.
    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Pre-load weights once (same optimization as forward pass)
    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
    w_f32 = w.to(tl.float32)

    for row_idx in range(row_start, row_end):
        # Calculate pointers for this specific row
        row_X_ptr = X_ptr + row_idx * stride_x
        row_DX_ptr = DX_ptr + row_idx * stride_dx
        row_DY_ptr = DY_ptr + row_idx * stride_dy
        row_Mean_ptr = Mean_ptr + row_idx * stride_mean
        row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd

        # Load data for this row; masked lanes read 0.0
        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
        mean = tl.load(row_Mean_ptr)
        rstd = tl.load(row_RSTD_ptr)

        # Convert to fp32 for numerical stability
        x_f32 = x.to(tl.float32)
        dy_f32 = dy.to(tl.float32)
        mean_f32 = mean.to(tl.float32)
        rstd_f32 = rstd.to(tl.float32)

        # Compute backward pass for this row
        x_hat = (x_f32 - mean_f32) * rstd_f32
        wdy = w_f32 * dy_f32
        # Masked lanes have wdy == 0 (w and dy were loaded as 0.0), so they add
        # nothing to either sum even though x_hat is non-zero there.
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

        # Store input gradient
        tl.store(row_DX_ptr + cols, dx, mask=mask)

        # Accumulate weight and bias gradients for this thread block's assigned rows
        dw = dy_f32 * x_hat
        db = dy_f32
        dW_row += dw
        db_row += db

    # One store per program: the partials land in row `row_block_id` of DW / DB.
    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
167
 
168
 
169
  def layer_norm_forward(X, W, B, eps):
170
+ """
171
+ Args:
172
+ X: Input tensor of shape (..., hidden_size)
173
+ W: Weight tensor of shape (hidden_size,)
174
+ B: Bias tensor of shape (hidden_size,)
175
+ eps: Small constant for numerical stability
176
+
177
+ Returns:
178
+ Tuple of (output, input, mean, rstd, block_size, num_warps)
179
+ """
180
  shape = X.shape
181
  dim = shape[-1]
182
  X = X.view(-1, dim)
183
  n_rows, n_cols = X.shape
184
+
185
+ # Calculate optimal block size and warp configuration
186
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
187
+
188
+ # Allocate output tensors
189
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
190
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
191
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
192
+
193
+ # Validate input dimensions
194
  if X.shape[1] != W.shape[0]:
195
  raise ValueError(
196
  f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
 
200
  # XPU-specific optimization
201
  kernel_args = {}
202
  if X.device.type == "xpu":
203
+ set_large_grf_mode(kernel_args)
204
 
205
+ # Launch kernel with one thread block per row for optimal performance
206
+ grid = (n_rows,)
207
+ _layer_norm_forward_kernel[grid](
208
  Y,
209
  Y.stride(0),
210
  X,
 
221
  eps,
222
  BLOCK_SIZE=BLOCK_SIZE,
223
  num_warps=num_warps,
224
+ **kernel_args,
225
  )
226
+
227
  return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
228
 
229
 
230
  def layer_norm_backward(dY, X, W, B, Mean, RSTD):
231
+ """
232
+ Args:
233
+ dY: Gradient of output
234
+ X: Input tensor
235
+ W: Weight tensor
236
+ B: Bias tensor
237
+ Mean: Pre-computed mean
238
+ RSTD: Pre-computed reciprocal standard deviation
239
+
240
+ Returns:
241
+ Tuple of (input_grad, weight_grad, bias_grad)
242
+ """
243
  shape = dY.shape
244
  dim = shape[-1]
245
  dY = dY.view(-1, dim)
 
250
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
251
  elif X.device.type == "xpu":
252
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
253
+ elif X.device.type == "npu":
254
+ sm_count = get_npu_core_count()
255
 
256
+ # Keep the per-program partial weight/bias gradients in fp32 for numerical stability.
257
+ _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
258
+ _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
259
 
260
+ # Calculate optimal block size and warp configuration
261
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
262
  if n_cols > BLOCK_SIZE:
263
+ raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
 
 
 
264
  rows_per_program = math.ceil(n_rows / sm_count)
265
  grid = (sm_count,)
 
 
 
 
 
 
 
 
 
266
 
267
+ # Allocate gradient tensors
268
+ DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
269
+
270
+ kernel_args = {"num_warps": num_warps}
271
  # XPU-specific optimization
 
272
  if X.device.type == "xpu":
273
+ kernel_args.update({"num_warps": 32, "num_stages": 4})
274
+ set_large_grf_mode(kernel_args)
275
 
276
+ # Launch sm_count programs; each one processes up to rows_per_program rows
277
  _layer_norm_backward_kernel[grid](
278
  X,
279
+ X.stride(0),
280
  W,
281
  Mean,
282
+ Mean.stride(0),
283
  RSTD,
284
+ RSTD.stride(0),
285
  DX,
 
 
 
 
286
  DX.stride(0),
287
+ _DW,
288
  _DW.stride(0),
289
+ _DB,
290
  _DB.stride(0),
291
+ dY,
292
  dY.stride(0),
293
  n_rows,
294
  n_cols,
295
+ rows_per_program=rows_per_program,
296
  BLOCK_SIZE=BLOCK_SIZE,
297
+ **kernel_args,
 
298
  )
299
 
300
+ DX = DX.view(*shape)
301
  DW = _DW.sum(dim=0).to(W.dtype)
302
+ DB = _DB.sum(dim=0).to(B.dtype)
303
 
 
304
  return DX, DW, DB
305
 
306
 
 
317
  def backward(ctx, dY):
318
  X, W, B, Mean, RSTD = ctx.saved_tensors
319
  DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
320
+ return DX, DW, DB, None
build/torch-cuda/layers.py CHANGED
@@ -1,39 +1,463 @@
 
 
 
 
1
  import torch
 
 
 
 
 
 
 
 
 
 
 
2
  from .rms_norm import LigerRMSNormFunction
 
 
 
 
3
 
4
- class LigerRMSNorm(torch.nn.Module):
5
- """
6
- RMSNorm module that uses the optimized LigerRMSNormFunction.
7
-
8
- Args:
9
- hidden_size (int): The size of the hidden dimension.
10
- eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
11
- offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
12
- casting_mode (str, optional): The casting mode to use. Defaults to "llama".
13
- in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
14
- """
15
-
16
-
17
- weight: torch.Tensor
18
- variance_epsilon: float
19
-
20
- def forward(self, hidden_states):
21
- """
22
- Apply RMS normalization to the input tensor.
23
-
24
- Args:
25
- hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
26
-
27
- Returns:
28
- torch.Tensor: Normalized tensor of the same shape as input
29
- """
 
 
30
  return LigerRMSNormFunction.apply(
31
- hidden_states,
32
- self.weight,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  self.variance_epsilon,
34
- 0,
35
- "llama",
36
- True
37
  )
38
-
39
- __all__ = ["LigerRMSNorm"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple
4
+
5
  import torch
6
+ import torch.nn as nn
7
+
8
+ from .cross_entropy import LigerCrossEntropyFunction
9
+ from .dyt import LigerDyTFunction
10
+ from .fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
11
+ from .geglu import LigerGELUMulFunction
12
+ from .group_norm import LigerGroupNormFunction
13
+ from .jsd import LigerJSDFunction
14
+ from .kl_div import LigerKLDivLossFunction
15
+ from .layer_norm import LigerLayerNormFunction
16
+ from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
17
  from .rms_norm import LigerRMSNormFunction
18
+ from .rope import LigerRopeFunction
19
+ from .swiglu import LigerSiLUMulFunction
20
+ from .tvd import LigerTVDLossFunction
21
+
22
 
23
class LigerRMSNorm(nn.Module):
    """RMS normalization module that delegates to ``LigerRMSNormFunction``.

    The constructor only records configuration and, when
    ``elementwise_affine`` is true, creates a learnable ``weight`` vector of
    length ``hidden_size``. All computation happens in
    ``LigerRMSNormFunction.apply``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        offset: float = 0.0,
        casting_mode: str = "llama",
        init_fn: str = "ones",
        in_place: bool = True,
        row_mode: Optional[bool] = None,
        elementwise_affine: bool = True,
    ):
        """Store the configuration and create the optional weight.

        Args:
            hidden_size: Length of the ``weight`` vector.
            eps: Stored as ``variance_epsilon`` and forwarded to the function.
            offset: Forwarded to the function unchanged.
            casting_mode: Forwarded to the function unchanged.
            init_fn: ``"ones"`` or ``"zeros"``; initial fill of ``weight``.
            in_place: Forwarded to the function unchanged.
            row_mode: Forwarded to the function unchanged.
            elementwise_affine: When false, ``weight`` is registered as ``None``.

        Raises:
            AssertionError: If ``init_fn`` is not ``"ones"`` or ``"zeros"``.
        """
        super().__init__()
        assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.offset = offset
        self.casting_mode = casting_mode
        self.in_place = in_place
        self.row_mode = row_mode
        self.elementwise_affine = elementwise_affine

        if not elementwise_affine:
            # Keep the attribute present so forward() can always read self.weight.
            self.register_parameter("weight", None)
        else:
            make_fill = torch.ones if init_fn == "ones" else torch.zeros
            self.weight = nn.Parameter(make_fill(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LigerRMSNormFunction.apply`` on the input and stored settings."""
        call_args = (
            hidden_states,
            self.weight,
            self.variance_epsilon,
            self.offset,
            self.casting_mode,
            self.in_place,
            self.row_mode,
        )
        return LigerRMSNormFunction.apply(*call_args)

    def extra_repr(self) -> str:
        """Summarise the size and main settings for ``repr(module)``."""
        size_part = f"{self.hidden_size}, eps={self.variance_epsilon}, offset={self.offset}, "
        mode_part = f"in_place={self.in_place}, row_mode={self.row_mode}"
        return size_part + mode_part
66
+
67
+
68
class LigerLayerNorm(nn.Module):
    """Layer normalization module that delegates to ``LigerLayerNormFunction``.

    Both ``weight`` and ``bias`` are always created as parameters of length
    ``hidden_size``; the ``bias`` flag only selects how ``bias`` is initialised.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        bias: bool = False,
        init_fn: str = "ones",
    ):
        """Create the weight and bias parameters.

        Args:
            hidden_size: Length of ``weight`` and ``bias``.
            eps: Stored as ``variance_epsilon`` and forwarded to the function.
            bias: When true ``bias`` starts from ``torch.randn``, otherwise zeros.
            init_fn: ``"ones"`` or ``"zeros"``; initial fill of ``weight``.

        Raises:
            AssertionError: If ``init_fn`` is not ``"ones"`` or ``"zeros"``.
        """
        super().__init__()
        assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"

        self.hidden_size = hidden_size
        self.variance_epsilon = eps

        if init_fn == "ones":
            weight_init = torch.ones(hidden_size)
        else:
            weight_init = torch.zeros(hidden_size)
        self.weight = nn.Parameter(weight_init)

        if bias:
            bias_init = torch.randn(hidden_size)
        else:
            bias_init = torch.zeros(hidden_size)
        self.bias = nn.Parameter(bias_init)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LigerLayerNormFunction.apply`` on the input and parameters."""
        call_args = (hidden_states, self.weight, self.bias, self.variance_epsilon)
        return LigerLayerNormFunction.apply(*call_args)

    def extra_repr(self) -> str:
        """Summarise the size and epsilon for ``repr(module)``."""
        size, eps = self.hidden_size, self.variance_epsilon
        return f"{size}, eps={eps}"
88
+
89
+
90
class LigerGroupNorm(nn.Module):
    """Group normalization module that delegates to ``LigerGroupNormFunction``.

    ``weight`` and ``bias`` are always created with ``num_channels`` entries;
    the ``bias`` flag only selects how ``bias`` is initialised.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: int,
        eps: float = 1e-6,
        bias: bool = False,
        init_fn: str = "ones",
    ):
        """Validate the configuration and create the parameters.

        Args:
            num_channels: Number of channels; length of ``weight`` and ``bias``.
            num_groups: Number of groups; must divide ``num_channels``.
            eps: Stored as ``variance_epsilon`` and forwarded to the function.
            bias: When true ``bias`` starts from ``torch.randn``, otherwise zeros.
            init_fn: ``"ones"`` or ``"zeros"``; initial fill of ``weight``.

        Raises:
            AssertionError: On an unknown ``init_fn`` or when ``num_channels``
                is not divisible by ``num_groups``.
        """
        super().__init__()
        assert init_fn in ("ones", "zeros"), f"init_fn must be 'ones' or 'zeros', got {init_fn}"
        assert num_channels % num_groups == 0, (
            f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
        )

        self.num_channels = num_channels
        self.num_groups = num_groups
        self.variance_epsilon = eps

        if init_fn == "ones":
            weight_init = torch.ones(num_channels)
        else:
            weight_init = torch.zeros(num_channels)
        self.weight = nn.Parameter(weight_init)

        if bias:
            bias_init = torch.randn(num_channels)
        else:
            bias_init = torch.zeros(num_channels)
        self.bias = nn.Parameter(bias_init)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Check the input layout, then call ``LigerGroupNormFunction.apply``.

        The input must have at least three dimensions, and dimension 1 must
        equal ``num_channels``.

        Raises:
            AssertionError: If either layout check fails.
        """
        assert hidden_states.dim() >= 3, f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
        assert hidden_states.size(1) == self.num_channels, (
            f"Input must have {self.num_channels} channels, got {hidden_states.size(1)}"
        )
        call_args = (
            hidden_states,
            self.weight,
            self.bias,
            self.num_channels,
            self.num_groups,
            self.variance_epsilon,
        )
        return LigerGroupNormFunction.apply(*call_args)

    def extra_repr(self) -> str:
        """Summarise channels, groups and epsilon for ``repr(module)``."""
        channels, groups, eps = self.num_channels, self.num_groups, self.variance_epsilon
        return f"num_channels={channels}, num_groups={groups}, eps={eps}"
126
+
127
+
128
class LigerDyT(nn.Module):
    """Module wrapper around ``LigerDyTFunction``.

    Owns a single-element ``alpha``, a per-feature ``gamma`` and an optional
    per-feature ``beta``; ``forward`` hands them to ``LigerDyTFunction.apply``.
    """

    def __init__(self, hidden_size: int, beta: bool = True, init_alpha: float = 0.5):
        """Create the parameters.

        Args:
            hidden_size: Length of ``gamma`` (and ``beta`` when enabled).
            beta: When false no ``beta`` parameter is created and the
                attribute is ``None``.
            init_alpha: Initial value of the single-element ``alpha``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha

        alpha_init = torch.ones(1) * init_alpha
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        if beta:
            self.beta = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.beta = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``LigerDyTFunction.apply(x, alpha, gamma, beta)``."""
        params = (self.alpha, self.gamma, self.beta)
        return LigerDyTFunction.apply(x, *params)

    def extra_repr(self) -> str:
        """Summarise size, initial alpha and whether beta exists."""
        has_beta = self.beta is not None
        return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={has_beta}"
142
+
143
+
144
class LigerCrossEntropyLoss(nn.Module):
    """Cross-entropy loss module backed by ``LigerCrossEntropyFunction``.

    The settings are validated and stored at construction. ``forward`` returns
    only the first element of the 4-tuple produced by the function.
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        """Validate and store the loss configuration.

        Args:
            weight: Optional tensor forwarded to the function; stored as a
                plain attribute (not registered as a parameter or buffer).
            ignore_index: Forwarded to the function unchanged.
            lse_square_scale: Forwarded to the function unchanged.
            label_smoothing: Must lie in ``[0, 1]``.
            reduction: One of ``"mean"``, ``"sum"`` or ``"none"``.
            softcap: ``None`` or a value greater than zero.

        Raises:
            AssertionError: If any of the three checks above fails.
        """
        super().__init__()
        assert 0.0 <= label_smoothing <= 1.0, f"label_smoothing must be in [0, 1], got {label_smoothing}"
        assert reduction in ("mean", "sum", "none"), f"reduction must be 'mean', 'sum', or 'none', got {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be > 0 or None, got {softcap}"

        self.weight = weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap

    def forward(self, _input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss for ``_input`` against ``target``.

        The three trailing ``False`` arguments are passed to the function as
        constants, and the three extra return values are discarded.
        """
        call_args = (
            _input,
            target,
            self.weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
            False,
            False,
            False,
        )
        loss, _, _, _ = LigerCrossEntropyFunction.apply(*call_args)
        return loss
180
+
181
+
182
class LigerFusedLinearCrossEntropyLoss(nn.Module):
    """Loss module backed by ``LigerFusedLinearCrossEntropyFunction``.

    The settings are validated and stored at construction. ``forward`` returns
    only the first element of the 4-tuple produced by the function.
    """

    def __init__(
        self,
        ce_weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        accum_dtype: Optional[torch.dtype] = None,
        use_token_scaling: bool = False,
    ):
        """Validate and store the loss configuration.

        Args:
            ce_weight: Optional tensor forwarded to the function; stored as a
                plain attribute.
            ignore_index: Forwarded to the function unchanged.
            lse_square_scale: Forwarded to the function unchanged.
            label_smoothing: Must lie in ``[0, 1]``.
            reduction: One of ``"mean"``, ``"sum"`` or ``"none"``.
            softcap: ``None`` or a value greater than zero.
            accum_dtype: Forwarded to the function unchanged.
            use_token_scaling: Forwarded to the function unchanged.

        Raises:
            AssertionError: If ``label_smoothing``, ``reduction`` or
                ``softcap`` is out of range.
        """
        super().__init__()
        assert 0.0 <= label_smoothing <= 1.0, f"label_smoothing must be in [0, 1], got {label_smoothing}"
        assert reduction in ("mean", "sum", "none"), f"reduction must be 'mean', 'sum', or 'none', got {reduction}"
        assert softcap is None or softcap > 0, f"softcap must be > 0 or None, got {softcap}"

        self.ce_weight = ce_weight
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap
        self.accum_dtype = accum_dtype
        self.use_token_scaling = use_token_scaling

    def forward(
        self,
        lin_weight: torch.Tensor,
        _input: torch.Tensor,
        target: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the fused loss.

        This method takes ``lin_weight`` first, but the underlying function is
        called with ``_input`` first and ``lin_weight`` second. The literal
        ``False`` arguments are constants, and the three extra return values
        are discarded.
        """
        call_args = (
            _input,
            lin_weight,
            target,
            bias,
            self.ce_weight,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
            False,
            self.accum_dtype,
            self.use_token_scaling,
            False,
            False,
        )
        loss, _, _, _ = LigerFusedLinearCrossEntropyFunction.apply(*call_args)
        return loss
232
+
233
+
234
class LigerJSD(nn.Module):
    """Module wrapper around ``LigerJSDFunction``.

    Stores ``beta`` and ``ignore_index`` and forwards them with the inputs.
    """

    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
        """Record ``beta`` and ``ignore_index`` for use in ``forward``."""
        super().__init__()
        self.beta, self.ignore_index = beta, ignore_index

    def forward(
        self,
        log_q: torch.Tensor,
        log_p: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``LigerJSDFunction.apply`` on the inputs and stored settings.

        Arguments are passed in the order
        ``(log_q, log_p, shift_labels, beta, ignore_index)``.
        """
        call_args = (log_q, log_p, shift_labels, self.beta, self.ignore_index)
        return LigerJSDFunction.apply(*call_args)
247
+
248
+
249
class LigerKLDIVLoss(nn.KLDivLoss):
    """``nn.KLDivLoss`` subclass whose forward pass uses ``LigerKLDivLossFunction``.

    ``reduction`` and ``log_target`` are handled by the parent constructor;
    this subclass adds ``eps``.
    """

    def __init__(self, eps: float = 1e-10, *args, **kwargs):
        """Initialise the parent loss and remember ``eps``.

        ``eps`` is the first positional parameter; any further positional or
        keyword arguments go to ``nn.KLDivLoss.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return ``LigerKLDivLossFunction.apply`` on the inputs and settings."""
        settings = (self.reduction, self.log_target, self.eps)
        return LigerKLDivLossFunction.apply(y_pred, y_true, *settings)
256
+
257
+
258
class LigerTVDLoss(nn.Module):
    """Module wrapper around ``LigerTVDLossFunction``.

    Stores ``reduction`` and ``ignore_index`` and forwards them with the inputs.
    """

    def __init__(self, reduction: str = "batchmean", ignore_index: int = -100):
        """Record ``reduction`` and ``ignore_index`` for use in ``forward``."""
        super().__init__()
        self.reduction, self.ignore_index = reduction, ignore_index

    def forward(
        self,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``LigerTVDLossFunction.apply`` on the inputs and stored settings.

        Arguments are passed in the order
        ``(p, q, shift_labels, reduction, ignore_index)``.
        """
        call_args = (p, q, shift_labels, self.reduction, self.ignore_index)
        return LigerTVDLossFunction.apply(*call_args)
271
+
272
+
273
class LigerSwiGLUMLP(nn.Module):
    """SwiGLU MLP block. ``config`` must expose ``hidden_size``, ``intermediate_size``,
    and ``hidden_act`` (must be ``silu`` or ``swish``)."""

    def __init__(self, config):
        """Build the three bias-free projections described by ``config``.

        Raises:
            ValueError: If ``config.hidden_act`` is neither ``silu`` nor ``swish``.
        """
        super().__init__()
        supported_acts = ("silu", "swish")
        if config.hidden_act not in supported_acts:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        d_model, d_inner = self.hidden_size, self.intermediate_size
        # Creation order is gate, up, down; it fixes parameter and RNG order.
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(LigerSiLUMulFunction.apply(gate_proj(x), up_proj(x)))``."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        fused = LigerSiLUMulFunction.apply(gate, up)
        return self.down_proj(fused)
290
+
291
+
292
class LigerGEGLUMLP(nn.Module):
    """GEGLU MLP block. ``config`` must expose ``hidden_size`` and ``intermediate_size``.
    Uses the tanh approximation of GELU (matches Gemma 1/1.1/2)."""

    def __init__(self, config):
        """Build the three bias-free projections described by ``config``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        d_model, d_inner = self.hidden_size, self.intermediate_size
        # Creation order is gate, up, down; it fixes parameter and RNG order.
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(LigerGELUMulFunction.apply(gate_proj(x), up_proj(x)))``."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        fused = LigerGELUMulFunction.apply(gate, up)
        return self.down_proj(fused)
307
+
308
+
309
@dataclass
class CrossEntropyOutput:
    """Result container returned by ``liger_fused_linear_cross_entropy`` when any
    optional output was requested. Only ``loss`` is required; the other fields
    default to ``None``."""

    # First value returned by LigerFusedLinearCrossEntropyFunction.apply.
    loss: torch.Tensor
    # Second returned value; presumably only meaningful when return_z_loss is set (confirm in the function).
    z_loss: Optional[torch.Tensor] = None
    # Third returned value; presumably only meaningful when return_token_accuracy is set (confirm in the function).
    token_accuracy: Optional[torch.Tensor] = None
    # Fourth returned value; presumably only meaningful when return_predicted_tokens is set (confirm in the function).
    predicted_tokens: Optional[torch.Tensor] = None
315
+
316
+
317
def liger_fused_linear_cross_entropy(
    input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    ce_weight: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
    accum_dtype: Optional[torch.dtype] = None,
    use_token_scaling: bool = False,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
):
    """Functional front end for ``LigerFusedLinearCrossEntropyFunction``.

    Every argument is passed to ``LigerFusedLinearCrossEntropyFunction.apply``
    positionally, in signature order.

    Returns:
        The bare ``loss`` when none of ``return_z_loss``,
        ``return_token_accuracy`` and ``return_predicted_tokens`` is truthy.
        Otherwise a ``CrossEntropyOutput`` holding all four values returned by
        the function.
    """
    call_args = (
        input,
        weight,
        target,
        bias,
        ce_weight,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
        accum_dtype,
        use_token_scaling,
        return_token_accuracy,
        return_predicted_tokens,
    )
    loss, z_loss, token_accuracy, predicted_tokens = LigerFusedLinearCrossEntropyFunction.apply(*call_args)

    wants_extras = return_z_loss or return_token_accuracy or return_predicted_tokens
    if wants_extras:
        return CrossEntropyOutput(
            loss=loss,
            z_loss=z_loss,
            token_accuracy=token_accuracy,
            predicted_tokens=predicted_tokens,
        )
    return loss
359
+
360
+
361
def LigerForCausalLMLoss(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    labels: torch.Tensor,
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    final_logit_softcapping: Optional[float] = None,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
    **kwargs,
):
    """Drop-in replacement for ``transformers.loss.ForCausalLMLoss`` that fuses the
    final ``lm_head`` projection with the cross-entropy loss.

    Args:
        hidden_states: Activations fed to the LM head; flattened to
            ``(-1, hidden_size)`` before the fused call.
        lm_head_weight: Weight passed as ``weight`` to
            ``liger_fused_linear_cross_entropy``.
        labels: Token labels. Only used when ``shift_labels`` is ``None``: they
            are right-padded with ``ignore_index`` and shifted left by one.
        hidden_size: Size of the last dimension of ``hidden_states``.
        num_items_in_batch: When given, the loss is computed with
            ``reduction="sum"`` and divided by this value; otherwise
            ``reduction="mean"`` is used. A tensor is moved to the loss device
            before dividing.
        ignore_index: Label value used for padding and forwarded to the loss.
        shift_labels: Pre-shifted labels; skips the internal shift when given.
        final_logit_softcapping: Forwarded as ``softcap``.
        return_token_accuracy: Forwarded; also switches the return type.
        return_predicted_tokens: Forwarded; also switches the return type.
        **kwargs: Extra keyword arguments. Only those accepted by
            ``liger_fused_linear_cross_entropy`` and not already supplied by
            this function are forwarded; everything else is ignored.

    Returns:
        A scalar ``loss`` by default; a :class:`CrossEntropyOutput` (without
        ``z_loss``) when ``return_token_accuracy`` or
        ``return_predicted_tokens`` is set.
    """
    # Arguments this function passes itself. A duplicate arriving through
    # **kwargs would otherwise raise "got multiple values for argument".
    supplied_here = (
        "input",
        "weight",
        "target",
        "reduction",
        "ignore_index",
        "softcap",
        "return_token_accuracy",
        "return_predicted_tokens",
    )
    applicable_params = inspect.signature(liger_fused_linear_cross_entropy).parameters
    kwargs = {k: v for k, v in kwargs.items() if k in applicable_params and k not in supplied_here}

    if shift_labels is None:
        # Pad one ignore_index on the right, then drop the first position so
        # that position t is scored against token t + 1.
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # reshape returns a view whenever view() would have worked and only copies
    # for non-contiguous inputs, where view() raises.
    hidden_states = hidden_states.reshape(-1, hidden_size)
    shift_labels = shift_labels.reshape(-1).to(hidden_states.device)

    reduction = "sum" if num_items_in_batch is not None else "mean"
    result = liger_fused_linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        shift_labels,
        reduction=reduction,
        ignore_index=ignore_index,
        softcap=final_logit_softcapping,
        return_token_accuracy=return_token_accuracy,
        return_predicted_tokens=return_predicted_tokens,
        **kwargs,
    )

    # The wrapper returns a CrossEntropyOutput as soon as any extra output was
    # requested (including return_z_loss forwarded through kwargs).
    if isinstance(result, CrossEntropyOutput):
        loss = result.loss
        token_accuracy = result.token_accuracy
        predicted_tokens = result.predicted_tokens
    else:
        loss = result
        token_accuracy = None
        predicted_tokens = None

    if reduction == "sum":
        # num_items_in_batch may be a tensor living on another device.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch

    if return_token_accuracy or return_predicted_tokens:
        return CrossEntropyOutput(
            loss=loss,
            token_accuracy=token_accuracy,
            predicted_tokens=predicted_tokens,
        )
    return loss
420
+
421
+
422
def liger_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply standard rotary positional embedding to ``q`` and ``k``.

    All six arguments are handed, unchanged and in order, to
    ``LigerRopeFunction.apply``; its result is returned as is.
    """
    rope_args = (q, k, cos, sin, position_ids, unsqueeze_dim)
    return LigerRopeFunction.apply(*rope_args)
432
+
433
+
434
def liger_multimodal_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply Qwen2-VL multimodal rotary positional embedding (M-RoPE) to ``q`` and ``k``.

    All six arguments are handed, unchanged and in order, to
    ``LigerQwen2VLMRopeFunction.apply``; its result is returned as is.
    """
    mrope_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*mrope_args)
444
+
445
+
446
# Names exported by a star-import of this module (order preserved).
__all__ = [
    "LigerRMSNorm", "LigerLayerNorm", "LigerGroupNorm", "LigerDyT",
    "LigerCrossEntropyLoss", "LigerFusedLinearCrossEntropyLoss",
    "LigerJSD", "LigerKLDIVLoss", "LigerTVDLoss",
    "LigerSwiGLUMLP", "LigerGEGLUMLP",
    "CrossEntropyOutput", "liger_fused_linear_cross_entropy", "LigerForCausalLMLoss",
    "liger_rotary_pos_emb", "liger_multimodal_rotary_pos_emb",
]
build/torch-cuda/metadata.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "name": "liger-kernels",
3
- "id": "_liger_kernels_cuda_e29f7ec",
4
  "version": 1,
5
  "license": "BSD-2-Clause",
6
  "python-depends": [],
 
1
  {
2
  "name": "liger-kernels",
3
+ "id": "_liger_kernels_cuda_08b4d53",
4
  "version": 1,
5
  "license": "BSD-2-Clause",
6
  "python-depends": [],
build/torch-cuda/qwen2vl_mrope.py CHANGED
@@ -219,4 +219,4 @@ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
219
  cos, sin = ctx.saved_tensors
220
  mrope_section = ctx.mrope_section
221
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
222
- return dq, dk, None, None, None, None
 
219
  cos, sin = ctx.saved_tensors
220
  mrope_section = ctx.mrope_section
221
  dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
222
+ return dq, dk, None, None, None, None
build/torch-cuda/rms_norm.py CHANGED
@@ -20,9 +20,12 @@ import triton.language as tl
20
  from .utils import calculate_settings
21
  from .utils import compare_version
22
  from .utils import ensure_contiguous
 
 
23
  from .utils import torch_to_triton_dtype
 
24
 
25
- if compare_version("triton", operator.ge, "3.0.0"):
26
  try:
27
  # typical import path with dispatch available
28
  from triton.language.extra.libdevice import rsqrt
@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
52
  eps,
53
  offset,
54
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
 
55
  BLOCK_SIZE: tl.constexpr,
56
  ):
57
  """
@@ -63,17 +67,18 @@ def _rms_norm_forward_kernel(
63
  3. https://arxiv.org/pdf/1910.07467
64
  """
65
 
66
- row_idx = tl.program_id(0)
67
  col_offsets = tl.arange(0, BLOCK_SIZE)
68
  mask = col_offsets < n_cols
69
 
70
- Y_ptr += row_idx * Y_row_stride
71
- X_ptr += row_idx * X_row_stride
72
- RSTD_ptr += row_idx * RSTD_row_stride
73
 
74
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
75
  X_row_dtype = X_row.dtype
76
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
77
 
78
  # On Llama, only rstd is computed on fp32
79
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(
81
 
82
  # Gemma computes everything on fp32, and then casts back the output to the original dtype
83
  if casting_mode == _CASTING_MODE_GEMMA:
84
- W_row = W_row.to(tl.float32)
 
85
  X_row = X_row.to(tl.float32)
86
 
87
  if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
94
  # We can save time by caching rms with minimal memory overhead
95
  # because rms is much smaller compared to X_row, as rms is for each row.
96
  # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
97
- tl.store(RSTD_ptr, rstd)
98
 
99
  X_row = X_row * rstd
100
 
@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
102
  if casting_mode == _CASTING_MODE_LLAMA:
103
  X_row = X_row.to(X_row_dtype)
104
 
105
- Y_row = X_row * (offset + W_row)
 
 
 
106
 
107
  if casting_mode == _CASTING_MODE_GEMMA:
108
  Y_row = Y_row.to(X_row_dtype)
109
 
110
- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
111
 
112
 
113
  @triton.jit
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
128
  n_rows,
129
  n_cols,
130
  offset,
131
- rows_per_program: tl.constexpr,
132
  casting_mode: tl.constexpr,
 
133
  BLOCK_SIZE: tl.constexpr,
134
  ):
135
  """
@@ -137,61 +147,256 @@ def _rms_norm_backward_kernel(
137
  dw = sum(dy * (x / RMS)). summation over BxT dimension
138
  """
139
 
140
- row_block_id = tl.program_id(0)
141
  row_start = row_block_id * rows_per_program
142
  row_end = min((row_block_id + 1) * rows_per_program, n_rows)
143
  col_offsets = tl.arange(0, BLOCK_SIZE)
144
  mask = col_offsets < n_cols
145
 
146
- dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
147
 
148
- dY_ptr += row_start * dY_row_stride
149
- dX_ptr += row_start * dX_row_stride
 
150
 
151
- X_ptr += row_start * X_row_stride
152
- RSTD_ptr += row_start
 
153
 
154
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
155
- W_row = W_row + offset
156
 
157
- for _ in range(row_start, row_end):
158
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
159
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
160
 
161
  # Get cached rms
162
- rstd_row = tl.load(RSTD_ptr)
163
 
164
  X_row = X_row.to(tl.float32)
165
 
166
  # Different backward graphs for different casting modes
167
  if casting_mode == _CASTING_MODE_LLAMA:
168
- m = (dY_row * W_row).to(tl.float32)
 
 
 
169
 
170
  elif casting_mode == _CASTING_MODE_GEMMA:
171
  dY_row = dY_row.to(tl.float32)
172
- m = dY_row * W_row
 
 
 
173
  else:
174
- m = dY_row * W_row
 
 
 
175
 
176
  dX_row = rstd_row * m
177
 
178
  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
179
 
180
- # calculate the gradient of W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  if casting_mode == _CASTING_MODE_LLAMA:
182
- dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
 
 
 
 
 
 
 
 
 
 
183
  else:
184
- # here X_row is already in fp32 (see previous if block)
185
- dW_row += dY_row * (X_row * rstd_row)
 
 
 
 
186
 
187
- tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
 
188
 
189
- dY_ptr += dY_row_stride
190
- dX_ptr += dX_row_stride
191
- X_ptr += X_row_stride
192
- RSTD_ptr += RSTD_row_stride
 
 
 
 
 
 
 
 
 
193
 
194
- tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
195
 
196
 
197
  _str_to_casting_mode = {
@@ -201,7 +406,7 @@ _str_to_casting_mode = {
201
  }
202
 
203
 
204
- def rms_norm_forward(X, W, eps, offset, casting_mode):
205
  if not isinstance(casting_mode, int):
206
  assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
207
  casting_mode = _str_to_casting_mode[casting_mode]
@@ -220,34 +425,64 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
220
  rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
221
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
222
 
223
- # Check constraints.
224
- assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
 
 
 
 
 
225
 
226
  # XPU-specific optimization
227
  kernel_args = {}
228
  if X.device.type == "xpu":
229
- kernel_args["grf_mode"] = "large"
230
- _rms_norm_forward_kernel[(n_rows,)](
231
- Y,
232
- Y.stride(0),
233
- X,
234
- X.stride(0),
235
- W,
236
- W.stride(0),
237
- RSTD,
238
- RSTD.stride(0),
239
- n_cols,
240
- eps,
241
- offset,
242
- casting_mode,
243
- BLOCK_SIZE=BLOCK_SIZE,
244
- num_warps=num_warps,
245
- **kernel_args, # XPU-specific optimization
246
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
248
 
249
 
250
- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
251
  shape = dY.shape
252
  dim = shape[-1]
253
  dY = dY.view(-1, dim)
@@ -258,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
258
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
259
  elif X.device.type == "xpu":
260
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
 
261
 
262
- # fp32 for numerical stability especially.
263
- _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
 
 
 
 
264
 
265
  if n_cols > BLOCK_SIZE:
266
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -275,33 +517,65 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
275
  # XPU-specific optimization
276
  kernel_args = {}
277
  if X.device.type == "xpu":
278
- kernel_args["grf_mode"] = "large"
279
-
280
- _rms_norm_backward_kernel[grid](
281
- dY,
282
- dY.stride(0),
283
- dX,
284
- dX.stride(0),
285
- X,
286
- X.stride(0),
287
- torch_to_triton_dtype[X.dtype],
288
- W,
289
- W.stride(0),
290
- RSTD,
291
- RSTD.stride(0),
292
- _dW,
293
- _dW.stride(0),
294
- n_rows,
295
- n_cols,
296
- offset,
297
- rows_per_program,
298
- casting_mode,
299
- BLOCK_SIZE=BLOCK_SIZE,
300
- num_warps=num_warps,
301
- **kernel_args, # XPU-specific optimization
302
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  dX = dX.view(*shape)
304
- dW = _dW.sum(dim=0).to(W.dtype)
 
 
 
 
305
 
306
  return dX, dW
307
 
@@ -330,18 +604,30 @@ class LigerRMSNormFunction(torch.autograd.Function):
330
 
331
  @staticmethod
332
  @ensure_contiguous
333
- def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
334
  """
335
  X: (B, T, H) or (BxT, H)
336
  W: (H,)
337
  """
338
- Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
 
 
 
 
 
 
 
339
  ctx.offset = offset
340
  ctx.casting_mode = casting_mode
341
  ctx.in_place = in_place
 
342
  ctx.BLOCK_SIZE = BLOCK_SIZE
343
  ctx.num_warps = num_warps
344
- ctx.save_for_backward(X, W, RSTD)
 
 
 
 
345
  return Y
346
 
347
  @staticmethod
@@ -350,16 +636,19 @@ class LigerRMSNormFunction(torch.autograd.Function):
350
  """
351
  Y: (B, T, H) or (BxT, H)
352
  """
353
- X, W, RSTD = ctx.saved_tensors
 
 
 
 
 
 
 
 
 
 
 
354
  dX, dW = rms_norm_backward(
355
- dY,
356
- X,
357
- W,
358
- RSTD,
359
- ctx.offset,
360
- ctx.casting_mode,
361
- ctx.BLOCK_SIZE,
362
- ctx.num_warps,
363
- ctx.in_place,
364
  )
365
- return dX, dW, None, None, None, None
 
20
  from .utils import calculate_settings
21
  from .utils import compare_version
22
  from .utils import ensure_contiguous
23
+ from .utils import get_npu_core_count
24
+ from .utils import set_large_grf_mode
25
  from .utils import torch_to_triton_dtype
26
+ from .utils import is_npu_available
27
 
28
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
29
  try:
30
  # typical import path with dispatch available
31
  from triton.language.extra.libdevice import rsqrt
 
55
  eps,
56
  offset,
57
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
58
+ elementwise_affine: tl.constexpr,
59
  BLOCK_SIZE: tl.constexpr,
60
  ):
61
  """
 
67
  3. https://arxiv.org/pdf/1910.07467
68
  """
69
 
70
+ row_idx = tl.program_id(0).to(tl.int64)
71
  col_offsets = tl.arange(0, BLOCK_SIZE)
72
  mask = col_offsets < n_cols
73
 
74
+ y_base = Y_ptr + row_idx * Y_row_stride
75
+ x_base = X_ptr + row_idx * X_row_stride
76
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
77
 
78
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
79
  X_row_dtype = X_row.dtype
80
+ if elementwise_affine:
81
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
82
 
83
  # On Llama, only rstd is computed on fp32
84
  if casting_mode == _CASTING_MODE_LLAMA:
 
86
 
87
  # Gemma computes everything on fp32, and then casts back the output to the original dtype
88
  if casting_mode == _CASTING_MODE_GEMMA:
89
+ if elementwise_affine:
90
+ W_row = W_row.to(tl.float32)
91
  X_row = X_row.to(tl.float32)
92
 
93
  if casting_mode == _CASTING_MODE_NONE:
 
100
  # We can save time by caching rms with minimal memory overhead
101
  # because rms is much smaller compared to X_row, as rms is for each row.
102
  # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
103
+ tl.store(rstd_base, rstd)
104
 
105
  X_row = X_row * rstd
106
 
 
108
  if casting_mode == _CASTING_MODE_LLAMA:
109
  X_row = X_row.to(X_row_dtype)
110
 
111
+ if elementwise_affine:
112
+ Y_row = X_row * (offset + W_row)
113
+ else:
114
+ Y_row = X_row
115
 
116
  if casting_mode == _CASTING_MODE_GEMMA:
117
  Y_row = Y_row.to(X_row_dtype)
118
 
119
+ tl.store(y_base + col_offsets, Y_row, mask=mask)
120
 
121
 
122
  @triton.jit
 
137
  n_rows,
138
  n_cols,
139
  offset,
140
+ rows_per_program,
141
  casting_mode: tl.constexpr,
142
+ elementwise_affine: tl.constexpr,
143
  BLOCK_SIZE: tl.constexpr,
144
  ):
145
  """
 
147
  dw = sum(dy * (x / RMS)). summation over BxT dimension
148
  """
149
 
150
+ row_block_id = tl.program_id(0).to(tl.int64)
151
  row_start = row_block_id * rows_per_program
152
  row_end = min((row_block_id + 1) * rows_per_program, n_rows)
153
  col_offsets = tl.arange(0, BLOCK_SIZE)
154
  mask = col_offsets < n_cols
155
 
156
+ if elementwise_affine:
157
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
158
 
159
+ if elementwise_affine:
160
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
161
+ W_row = W_row + offset
162
 
163
+ for row_idx in range(row_start, row_end):
164
+ dy_base = dY_ptr + row_idx * dY_row_stride
165
+ dx_base = dX_ptr + row_idx * dX_row_stride
166
 
167
+ x_base = X_ptr + row_idx * X_row_stride
168
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
169
 
170
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
171
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
172
 
173
  # Get cached rms
174
+ rstd_row = tl.load(rstd_base)
175
 
176
  X_row = X_row.to(tl.float32)
177
 
178
  # Different backward graphs for different casting modes
179
  if casting_mode == _CASTING_MODE_LLAMA:
180
+ if elementwise_affine:
181
+ m = (dY_row * W_row).to(tl.float32)
182
+ else:
183
+ m = dY_row.to(tl.float32)
184
 
185
  elif casting_mode == _CASTING_MODE_GEMMA:
186
  dY_row = dY_row.to(tl.float32)
187
+ if elementwise_affine:
188
+ m = dY_row * W_row
189
+ else:
190
+ m = dY_row
191
  else:
192
+ if elementwise_affine:
193
+ m = dY_row * W_row
194
+ else:
195
+ m = dY_row
196
 
197
  dX_row = rstd_row * m
198
 
199
  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
200
 
201
+ if elementwise_affine:
202
+ # calculate the gradient of W
203
+ if casting_mode == _CASTING_MODE_LLAMA:
204
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
205
+ else:
206
+ # here X_row is already in fp32 (see previous if block)
207
+ dW_row += dY_row * (X_row * rstd_row)
208
+
209
+ tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
210
+
211
+ if elementwise_affine:
212
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
213
+
214
+
215
@triton.jit
def _block_rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_rows,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    Each program normalizes a tile of BLOCK_ROW rows by BLOCK_SIZE columns.
    Rows at or beyond n_rows and columns at or beyond n_cols are masked out.
    The per-row reciprocal RMS is written to RSTD_ptr. W_ptr is only read when
    elementwise_affine is set; W_row_stride is not read by this kernel.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    # Widen to int64 before multiplying by the row strides below: with a very
    # large number of rows, row_idx * row_stride does not fit in int32.
    row_idx = tl.program_id(0).to(tl.int64) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_mask = row_idx < n_rows
    col_mask = col_offsets < n_cols

    X_row = tl.load(
        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
        mask=row_mask[:, None] & col_mask[None, :],
        other=0,
    )
    X_row_dtype = X_row.dtype
    if elementwise_affine:
        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        if elementwise_affine:
            W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    # Masked-out columns were loaded as 0, so they do not contribute to the sum.
    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
    rstd = rsqrt(mean_square + eps)

    # We can save time by caching rms with minimal memory overhead
    # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)

    X_row = X_row * rstd[:, None]

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    if elementwise_affine:
        Y_row = X_row * (offset + W_row)[None, :]
    else:
        Y_row = X_row

    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(
        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
        Y_row,
        mask=row_mask[:, None] & col_mask[None, :],
    )
298
+
299
+
300
@triton.jit
def _block_rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    casting_mode: tl.constexpr,
    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_ROW: tl.constexpr,
):
    """
    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
    dw = sum(dy * (x / RMS)). summation over BxT dimension

    Each program walks over tiles of BLOCK_ROW rows, starting at its own
    program id and stepping by the number of launched programs. When
    elementwise_affine is set, it accumulates a partial dW in fp32 and writes
    it to row `pid` of dW_ptr. W_row_stride is not read by this kernel.
    """

    pid = tl.program_id(0).cast(tl.int64)
    # Number of programs launched along axis 0; used as the tile-loop stride.
    NUM_SMS = tl.num_programs(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    if elementwise_affine:
        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
        W_row = W_row + offset

    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
        row_idx = start + tl.arange(0, BLOCK_ROW)
        row_mask = row_idx < n_rows
        dY_row = tl.load(
            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        X_row = tl.load(
            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
            mask=row_mask[:, None] & col_mask[None, :],
            other=0.0,
        )

        # Get cached rms
        # Padded rows must read a defined value: they take part in the dW
        # reduction below, and an undefined rstd (e.g. NaN/Inf bits) would turn
        # their 0-valued contribution into NaN.
        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, mask=row_mask, other=0.0)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
            if elementwise_affine:
                m = (dY_row * W_row[None, :]).to(tl.float32)
            else:
                m = dY_row.to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row
        else:
            if elementwise_affine:
                m = dY_row * W_row[None, :]
            else:
                m = dY_row

        dX_row = rstd_row[:, None] * m

        dX_row += (rstd_row[:, None]) * (
            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
        )

        if elementwise_affine:
            if casting_mode == _CASTING_MODE_LLAMA:
                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
            else:
                # here X_row is already in fp32 (see previous if block)
                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

        tl.store(
            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
            dX_row,
            mask=row_mask[:, None] & col_mask[None, :],
        )

    if elementwise_affine:
        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
400
 
401
 
402
  _str_to_casting_mode = {
 
406
  }
407
 
408
 
409
+ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
410
  if not isinstance(casting_mode, int):
411
  assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
412
  casting_mode = _str_to_casting_mode[casting_mode]
 
425
  rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
426
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
427
 
428
+ if W is not None:
429
+ # Check constraints.
430
+ assert X.shape[1] == W.shape[0], (
431
+ "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
432
+ )
433
+ elementwise_affine = True
434
+ else:
435
+ elementwise_affine = False
436
 
437
  # XPU-specific optimization
438
  kernel_args = {}
439
  if X.device.type == "xpu":
440
+ set_large_grf_mode(kernel_args)
441
+ if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
442
+ _rms_norm_forward_kernel[(n_rows,)](
443
+ Y,
444
+ Y.stride(0),
445
+ X,
446
+ X.stride(0),
447
+ W,
448
+ W.stride(0) if elementwise_affine else 0,
449
+ RSTD,
450
+ RSTD.stride(0),
451
+ n_cols,
452
+ eps,
453
+ offset,
454
+ casting_mode,
455
+ elementwise_affine=elementwise_affine,
456
+ BLOCK_SIZE=BLOCK_SIZE,
457
+ num_warps=num_warps,
458
+ **kernel_args, # XPU-specific optimization
459
+ )
460
+ else:
461
+ BLOCK_ROW = 16
462
+ kernel_args["BLOCK_ROW"] = BLOCK_ROW
463
+ _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
464
+ Y,
465
+ Y.stride(0),
466
+ X,
467
+ X.stride(0),
468
+ W,
469
+ W.stride(0) if elementwise_affine else 0,
470
+ RSTD,
471
+ RSTD.stride(0),
472
+ n_rows,
473
+ n_cols,
474
+ eps,
475
+ offset,
476
+ casting_mode,
477
+ elementwise_affine=elementwise_affine,
478
+ BLOCK_SIZE=BLOCK_SIZE,
479
+ num_warps=num_warps,
480
+ **kernel_args, # XPU-specific optimization
481
+ )
482
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
483
 
484
 
485
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
486
  shape = dY.shape
487
  dim = shape[-1]
488
  dY = dY.view(-1, dim)
 
493
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
494
  elif X.device.type == "xpu":
495
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
496
+ elif X.device.type == "npu":
497
+ sm_count = get_npu_core_count()
498
 
499
+ if W is not None:
500
+ # fp32 for numerical stability especially.
501
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
502
+ elementwise_affine = True
503
+ else:
504
+ _dW = None
505
+ elementwise_affine = False
506
 
507
  if n_cols > BLOCK_SIZE:
508
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
 
517
  # XPU-specific optimization
518
  kernel_args = {}
519
  if X.device.type == "xpu":
520
+ set_large_grf_mode(kernel_args)
521
+
522
+ if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
523
+ _rms_norm_backward_kernel[grid](
524
+ dY,
525
+ dY.stride(0),
526
+ dX,
527
+ dX.stride(0),
528
+ X,
529
+ X.stride(0),
530
+ torch_to_triton_dtype[X.dtype],
531
+ W,
532
+ W.stride(0) if elementwise_affine else 0,
533
+ RSTD,
534
+ RSTD.stride(0),
535
+ _dW,
536
+ _dW.stride(0) if elementwise_affine else 0,
537
+ n_rows,
538
+ n_cols,
539
+ offset,
540
+ rows_per_program,
541
+ casting_mode,
542
+ elementwise_affine=elementwise_affine,
543
+ BLOCK_SIZE=BLOCK_SIZE,
544
+ num_warps=num_warps,
545
+ **kernel_args, # XPU-specific optimization
546
+ )
547
+ else:
548
+ BLOCK_ROW = 16
549
+ kernel_args["BLOCK_ROW"] = BLOCK_ROW
550
+ _block_rms_norm_backward_kernel[grid](
551
+ dY,
552
+ dY.stride(0),
553
+ dX,
554
+ dX.stride(0),
555
+ X,
556
+ X.stride(0),
557
+ torch_to_triton_dtype[X.dtype],
558
+ W,
559
+ W.stride(0) if elementwise_affine else 0,
560
+ RSTD,
561
+ RSTD.stride(0),
562
+ _dW,
563
+ _dW.stride(0) if elementwise_affine else 0,
564
+ n_rows,
565
+ n_cols,
566
+ offset,
567
+ casting_mode,
568
+ elementwise_affine=elementwise_affine,
569
+ BLOCK_SIZE=BLOCK_SIZE,
570
+ num_warps=num_warps,
571
+ **kernel_args, # XPU-specific optimization
572
+ )
573
  dX = dX.view(*shape)
574
+
575
+ if elementwise_affine:
576
+ dW = _dW.sum(dim=0).to(W.dtype)
577
+ else:
578
+ dW = None
579
 
580
  return dX, dW
581
 
 
604
 
605
  @staticmethod
606
  @ensure_contiguous
607
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
    """
    X: (B, T, H) or (BxT, H)
    W: (H,), or None to skip the elementwise weight

    Runs ``rms_norm_forward`` and stashes on ``ctx`` everything that is read
    back from it in the backward pass.
    """
    # A DTensor input (output of a tensor parallel module) is gathered to a
    # full local tensor so RMS norm is computed on every TP worker.
    # TODO: support CP.
    if isinstance(X, torch.distributed.tensor.DTensor):
        X = X.full_tensor()

    Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)

    has_weight = W is not None
    ctx.offset = offset
    ctx.casting_mode = casting_mode
    ctx.in_place = in_place
    ctx.row_mode = row_mode
    ctx.BLOCK_SIZE = BLOCK_SIZE
    ctx.num_warps = num_warps
    ctx.elementwise_affine = has_weight

    # W is only saved when it exists; ctx.elementwise_affine records which
    # layout was saved.
    saved = (X, W, RSTD) if has_weight else (X, RSTD)
    ctx.save_for_backward(*saved)
    return Y
632
 
633
  @staticmethod
 
636
  """
637
  Y: (B, T, H) or (BxT, H)
638
  """
639
+ if ctx.elementwise_affine:
640
+ X, W, RSTD = ctx.saved_tensors
641
+ else:
642
+ X, RSTD = ctx.saved_tensors
643
+ W = None
644
+
645
+ if isinstance(dY, torch.distributed.tensor.DTensor):
646
+ # Gradients are output of a tensor parallel module and
647
+ # needs to be gathered to a local tensor for computing the RMS norm layer.
648
+ # TODO: support CP.
649
+ dY = dY.full_tensor()
650
+
651
  dX, dW = rms_norm_backward(
652
+ dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
 
 
 
 
 
 
 
 
653
  )
654
+ return dX, dW, None, None, None, None, None
build/torch-cuda/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
32
 
33
  # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
34
  # stride: (seq_len * head_dim, head_dim, 1)
35
- pid = tl.program_id(0)
36
 
37
  # locate start address
38
  q_ptr = q_ptr + pid * q_row_stride
@@ -236,4 +236,4 @@ class LigerRopeFunction(torch.autograd.Function):
236
 
237
  cos, sin = ctx.saved_tensors
238
  dq, dk = rope_backward(dq, dk, cos, sin)
239
- return dq, dk, None, None, None, None
 
32
 
33
  # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
34
  # stride: (seq_len * head_dim, head_dim, 1)
35
+ pid = tl.program_id(0).to(tl.int64)
36
 
37
  # locate start address
38
  q_ptr = q_ptr + pid * q_row_stride
 
236
 
237
  cos, sin = ctx.saved_tensors
238
  dq, dk = rope_backward(dq, dk, cos, sin)
239
+ return dq, dk, None, None, None, None
build/torch-cuda/swiglu.py CHANGED
@@ -12,7 +12,9 @@ def silu(x):
12
 
13
 
14
  @triton.jit
15
- def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
 
 
16
  program_id = tl.program_id(0).to(tl.int64)
17
 
18
  # locate start index
@@ -24,14 +26,16 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
24
  mask = col_offsets < n_cols
25
 
26
  # sigmoid requires type float32
27
- a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
- c_row = silu(a_row) * b_row
30
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
 
32
 
33
  @triton.jit
34
- def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
 
 
35
  program_id = tl.program_id(0).to(tl.int64)
36
 
37
  # locate start index
@@ -44,20 +48,21 @@ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr,
44
 
45
  dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
46
  # sigmoid requires type float32
47
- a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
48
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
49
 
50
- # recomputation to save memory
51
  sig_a = tl.sigmoid(a_row)
52
  silu_a = a_row * sig_a
53
  db_row = dc_row * silu_a
54
- da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
 
55
 
56
  tl.store(a_ptr + col_offsets, da_row, mask=mask)
57
  tl.store(b_ptr + col_offsets, db_row, mask=mask)
58
 
59
 
60
- def swiglu_forward(a, b):
61
  ori_shape = a.shape
62
 
63
  n_cols = ori_shape[-1]
@@ -73,6 +78,7 @@ def swiglu_forward(a, b):
73
  b,
74
  c,
75
  c.stride(-2),
 
76
  n_cols=n_cols,
77
  BLOCK_SIZE=BLOCK_SIZE,
78
  num_warps=num_warps,
@@ -80,7 +86,7 @@ def swiglu_forward(a, b):
80
  return a, b, c.view(*ori_shape)
81
 
82
 
83
- def swiglu_backward(a, b, dc):
84
  ori_shape = dc.shape
85
  n_cols = ori_shape[-1]
86
  dc = dc.view(-1, n_cols)
@@ -93,6 +99,7 @@ def swiglu_backward(a, b, dc):
93
  a,
94
  b,
95
  dc.stride(-2),
 
96
  n_cols=n_cols,
97
  BLOCK_SIZE=BLOCK_SIZE,
98
  num_warps=num_warps,
@@ -103,14 +110,67 @@ def swiglu_backward(a, b, dc):
103
  class LigerSiLUMulFunction(torch.autograd.Function):
104
  @staticmethod
105
  @ensure_contiguous
106
- def forward(ctx, a, b):
107
- a, b, c = swiglu_forward(a, b)
108
- ctx.save_for_backward(a, b)
109
- return c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  @staticmethod
112
  @ensure_contiguous
113
  def backward(ctx, dc):
114
  a, b = ctx.saved_tensors
115
- a, b = swiglu_backward(a, b, dc)
116
- return a, b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  @triton.jit
15
+ def _swiglu_forward_kernel(
16
+ a_ptr, b_ptr, c_ptr, stride, gate_multiplier, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
17
+ ):
18
  program_id = tl.program_id(0).to(tl.int64)
19
 
20
  # locate start index
 
26
  mask = col_offsets < n_cols
27
 
28
  # sigmoid requires type float32
29
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) * gate_multiplier
30
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
31
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
32
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
33
 
34
 
35
  @triton.jit
36
+ def _swiglu_backward_kernel(
37
+ dc_ptr, a_ptr, b_ptr, stride, gate_multiplier, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
38
+ ):
39
  program_id = tl.program_id(0).to(tl.int64)
40
 
41
  # locate start index
 
48
 
49
  dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
50
  # sigmoid requires type float32
51
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) * gate_multiplier
52
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
53
 
54
+ # recomputation to save memory. a_row already holds a * gate_multiplier.
55
  sig_a = tl.sigmoid(a_row)
56
  silu_a = a_row * sig_a
57
  db_row = dc_row * silu_a
58
+ # chain rule pulls an extra factor of gate_multiplier through the pre-activation scaling
59
+ da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row * gate_multiplier
60
 
61
  tl.store(a_ptr + col_offsets, da_row, mask=mask)
62
  tl.store(b_ptr + col_offsets, db_row, mask=mask)
63
 
64
 
65
+ def swiglu_forward(a, b, gate_multiplier: float = 1.0):
66
  ori_shape = a.shape
67
 
68
  n_cols = ori_shape[-1]
 
78
  b,
79
  c,
80
  c.stride(-2),
81
+ float(gate_multiplier),
82
  n_cols=n_cols,
83
  BLOCK_SIZE=BLOCK_SIZE,
84
  num_warps=num_warps,
 
86
  return a, b, c.view(*ori_shape)
87
 
88
 
89
+ def swiglu_backward(a, b, dc, gate_multiplier: float = 1.0):
90
  ori_shape = dc.shape
91
  n_cols = ori_shape[-1]
92
  dc = dc.view(-1, n_cols)
 
99
  a,
100
  b,
101
  dc.stride(-2),
102
+ float(gate_multiplier),
103
  n_cols=n_cols,
104
  BLOCK_SIZE=BLOCK_SIZE,
105
  num_warps=num_warps,
 
110
class LigerSiLUMulFunction(torch.autograd.Function):
    """Autograd function wrapping the fused SwiGLU forward/backward launchers.

    Forward calls ``swiglu_forward(a, b, gate_multiplier)`` and multiplies the
    result by ``down_multiplier``.  Backward multiplies the incoming gradient
    by ``down_multiplier`` and delegates to ``swiglu_backward``.

    When either input is a ``DTensor`` the kernels run on the local shards and
    the outputs/gradients are wrapped back into ``DTensor`` objects that reuse
    the device mesh and placements of the ``DTensor`` input.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b, gate_multiplier: float = 1.0, down_multiplier: float = 1.0):
        """Run the fused SwiGLU forward pass.

        Args:
            ctx: Autograd context; receives the multipliers, the saved
                (local) inputs and the DTensor metadata.
            a: Gate input (plain tensor or ``DTensor``).
            b: Second input multiplied with the activated gate (plain tensor
                or ``DTensor``).
            gate_multiplier: Scalar forwarded to ``swiglu_forward``.
            down_multiplier: Scalar applied to the kernel output.

        Returns:
            The output tensor; a ``DTensor`` if either input was one.
        """
        gate_multiplier = float(gate_multiplier)
        down_multiplier = float(down_multiplier)
        # Stored on ctx (not save_for_backward) because they are plain floats.
        ctx.gate_multiplier = gate_multiplier
        ctx.down_multiplier = down_multiplier

        dtensor_cls = torch.distributed.tensor.DTensor
        a_is_dtensor = isinstance(a, dtensor_cls)
        b_is_dtensor = isinstance(b, dtensor_cls)

        if a_is_dtensor or b_is_dtensor:
            # Prefer a's layout; fall back to b's when only b is a DTensor.
            device_mesh, placements = (
                (a.device_mesh, a.placements) if a_is_dtensor else (b.device_mesh, b.placements)
            )

            # Assume that full tensors are gathered before and identical across
            # the associated process groups.
            if not a_is_dtensor:
                a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements)
            if not b_is_dtensor:
                b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements)

            a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local(), gate_multiplier)
            if down_multiplier != 1.0:
                c_local = c_local * down_multiplier
            ctx.save_for_backward(a_local, b_local)
            ctx.dtensor_metadata = (device_mesh, placements)
            return dtensor_cls.from_local(c_local, device_mesh, placements)

        a, b, c = swiglu_forward(a, b, gate_multiplier)
        if down_multiplier != 1.0:
            c = c * down_multiplier
        ctx.save_for_backward(a, b)
        # None marks the plain-tensor path for backward.
        ctx.dtensor_metadata = None
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        """Run the fused SwiGLU backward pass.

        Args:
            ctx: Autograd context populated by ``forward``.
            dc: Gradient with respect to the forward output.

        Returns:
            ``(grad_a, grad_b, None, None)``; the two ``None`` entries match
            the non-differentiable ``gate_multiplier`` / ``down_multiplier``
            arguments of ``forward``.
        """
        a, b = ctx.saved_tensors
        gate_multiplier = ctx.gate_multiplier
        down_multiplier = ctx.down_multiplier

        if ctx.dtensor_metadata is not None:
            device_mesh, placements = ctx.dtensor_metadata
            dtensor_cls = torch.distributed.tensor.DTensor

            # Assume that full tensors are gathered before and identical across
            # the associated process groups.
            if isinstance(dc, dtensor_cls):
                dc_local = dc.to_local()
            else:
                # distribute_tensor returns a DTensor; the kernel launcher works
                # on local shards (the saved a/b are local), so unwrap it.
                dc_local = torch.distributed.tensor.distribute_tensor(
                    dc, device_mesh=device_mesh, placements=placements
                ).to_local()

            # The forward output was scaled by down_multiplier, so the upstream
            # gradient picks up the same factor.
            if down_multiplier != 1.0:
                dc_local = dc_local * down_multiplier
            a_local, b_local = swiglu_backward(a, b, dc_local, gate_multiplier)
            return (
                dtensor_cls.from_local(a_local, device_mesh, placements),
                dtensor_cls.from_local(b_local, device_mesh, placements),
                None,
                None,
            )

        if down_multiplier != 1.0:
            dc = dc * down_multiplier
        a, b = swiglu_backward(a, b, dc, gate_multiplier)
        return a, b, None, None
build/torch-cuda/tvd.py CHANGED
@@ -49,6 +49,7 @@ def _tv_distance_kernel(
49
  label_ptr,
50
  ignore_index: tl.constexpr,
51
  n_cols,
 
52
  BLOCK_SIZE: tl.constexpr,
53
  HAS_LABEL: tl.constexpr,
54
  reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
@@ -84,7 +85,8 @@ def _tv_distance_kernel(
84
  # TVD(P || Q) = 0.5 * |P - Q|
85
  tv_loss = 0.5 * tl.abs(p - q)
86
 
87
- grad_res = tl.where(p > q, 0.5, -0.5)
 
88
 
89
  tl.store(grads_ptr + offsets, grad_res, mask=mask)
90
 
@@ -94,7 +96,8 @@ def _tv_distance_kernel(
94
  loss_sum += tl.sum(tv_loss, axis=0)
95
 
96
  if reduction != _REDUCTION_MODE_NONE:
97
- tl.store(loss_ptr, loss_sum)
 
98
 
99
 
100
  def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
@@ -113,6 +116,14 @@ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_
113
 
114
  n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
115
 
 
 
 
 
 
 
 
 
116
  _tv_distance_kernel[grid](
117
  p,
118
  p.stride(0),
@@ -125,18 +136,18 @@ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_
125
  shift_labels if has_label else torch.empty(1, device=p.device),
126
  ignore_index,
127
  V,
 
128
  BLOCK_SIZE=BLOCK_SIZE,
129
  HAS_LABEL=has_label,
130
  num_warps=num_warps,
131
  reduction=reduction,
132
  )
133
 
134
- if reduction == _REDUCTION_MODE_BATCHMEAN.value:
135
- return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
 
136
  elif reduction == _REDUCTION_MODE_SUM.value:
137
  return output_tensor.sum(dim=0), grads
138
- elif reduction == _REDUCTION_MODE_MEAN.value:
139
- return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
140
  else:
141
  return output_tensor, grads
142
 
@@ -204,4 +215,4 @@ class LigerTVDLossFunction(torch.autograd.Function):
204
  (grads,) = ctx.saved_tensors
205
  grads = tvd_backward_triton(grad_output, grads)
206
 
207
- return grads, None, None, None, None
 
49
  label_ptr,
50
  ignore_index: tl.constexpr,
51
  n_cols,
52
+ scale, # pre-computed reduction scale for gradients (fused into kernel)
53
  BLOCK_SIZE: tl.constexpr,
54
  HAS_LABEL: tl.constexpr,
55
  reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
 
85
  # TVD(P || Q) = 0.5 * |P - Q|
86
  tv_loss = 0.5 * tl.abs(p - q)
87
 
88
+ # Fuse reduction scaling into gradient computation (eliminates separate Python division)
89
+ grad_res = tl.where(p > q, 0.5 * scale, -0.5 * scale)
90
 
91
  tl.store(grads_ptr + offsets, grad_res, mask=mask)
92
 
 
96
  loss_sum += tl.sum(tv_loss, axis=0)
97
 
98
  if reduction != _REDUCTION_MODE_NONE:
99
+ # Fuse reduction scaling into loss (same scale as gradients; avoids Python division)
100
+ tl.store(loss_ptr, loss_sum * scale)
101
 
102
 
103
  def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
 
116
 
117
  n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
118
 
119
+ # Pre-compute gradient scale factor (fused into kernel to avoid separate division)
120
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
121
+ scale = 1.0 / n_non_ignore
122
+ elif reduction == _REDUCTION_MODE_MEAN.value:
123
+ scale = 1.0 / (n_non_ignore * V)
124
+ else:
125
+ scale = 1.0
126
+
127
  _tv_distance_kernel[grid](
128
  p,
129
  p.stride(0),
 
136
  shift_labels if has_label else torch.empty(1, device=p.device),
137
  ignore_index,
138
  V,
139
+ scale,
140
  BLOCK_SIZE=BLOCK_SIZE,
141
  HAS_LABEL=has_label,
142
  num_warps=num_warps,
143
  reduction=reduction,
144
  )
145
 
146
+ # Loss and gradients are already scaled inside the kernel — no separate division needed
147
+ if reduction in (_REDUCTION_MODE_BATCHMEAN.value, _REDUCTION_MODE_MEAN.value):
148
+ return output_tensor.sum(), grads
149
  elif reduction == _REDUCTION_MODE_SUM.value:
150
  return output_tensor.sum(dim=0), grads
 
 
151
  else:
152
  return output_tensor, grads
153
 
 
215
  (grads,) = ctx.saved_tensors
216
  grads = tvd_backward_triton(grad_output, grads)
217
 
218
+ return grads, None, None, None, None
build/torch-cuda/utils.py CHANGED
@@ -22,17 +22,33 @@ import triton.language as tl
22
 
23
  from packaging.version import Version
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  def infer_device():
26
  """
27
  Get current device name based on available devices
28
  """
29
  if torch.cuda.is_available(): # Works for both Nvidia and AMD
30
  return "cuda"
 
 
 
 
31
  elif torch.xpu.is_available():
32
  return "xpu"
33
  else:
34
  return "cpu"
35
 
 
36
  def is_hip() -> bool:
37
  return torch.version.hip is not None
38
 
@@ -86,6 +102,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
86
  functools.partial(torch.amp.custom_fwd, device_type=device),
87
  functools.partial(torch.amp.custom_bwd, device_type=device),
88
  )
 
 
89
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
90
 
91
 
@@ -132,4 +150,27 @@ def element_mul_kernel(
132
  for i in range(0, n_cols, BLOCK_SIZE):
133
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
134
  X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
135
- tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  from packaging.version import Version
24
 
25
+
26
def is_npu_available() -> bool:
    """Report whether an Ascend NPU is usable.

    Delegates to ``transformers.utils.is_torch_npu_available``.  Any failure
    while importing or calling that helper is treated as "no NPU".
    """
    npu_present = False
    try:
        from transformers.utils import is_torch_npu_available as _probe_npu

        npu_present = _probe_npu()
    except Exception:
        # Best-effort probe: a missing or broken transformers install must not
        # break device inference.
        npu_present = False
    return npu_present
34
+
35
+
36
def infer_device():
    """
    Return the name of the device backend to use.

    Backends are probed in a fixed priority order and the first available one
    wins: "cuda", then "npu", then "xpu", falling back to "cpu".
    """
    # Works for both Nvidia and AMD
    if torch.cuda.is_available():
        return "cuda"
    # Use Ascend NPU if available (torch.npu)
    if is_npu_available():
        return "npu"
    # XPU (Intel) if available
    if torch.xpu.is_available():
        return "xpu"
    return "cpu"
50
 
51
+
52
def is_hip() -> bool:
    """Return True when ``torch.version.hip`` is set (i.e. not ``None``)."""
    hip_version = torch.version.hip
    return hip_version is not None
54
 
 
102
  functools.partial(torch.amp.custom_fwd, device_type=device),
103
  functools.partial(torch.amp.custom_bwd, device_type=device),
104
  )
105
+ if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
106
+ return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
107
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
108
 
109
 
 
150
  for i in range(0, n_cols, BLOCK_SIZE):
151
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
152
  X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
153
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
154
+
155
+
156
def get_npu_core_count(default: int = 20) -> int:
    """Return the NPU vector core count.

    Reads the ``"num_vectorcore"`` entry from the properties the active Triton
    driver reports for device 0.  If that entry is absent, or if anything in
    the lookup raises (e.g. Triton runtime or NPU device unavailable),
    ``default`` is returned instead.
    """
    try:
        device_props = triton.runtime.driver.active.utils.get_device_properties(0)
        core_count = device_props.get("num_vectorcore", default)
        return int(core_count)
    except Exception:
        # Best-effort query: fall back rather than fail kernel launch setup.
        return default
166
+
167
+
168
def set_large_grf_mode(kernel_args: dict):
    """Set large GRF mode for XPU devices.

    Mutates ``kernel_args`` in place by writing its ``"grf_mode"`` entry; the
    value depends on the installed Triton version.
    """
    # On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
    # triton XPU installed from source will be called `triton`.
    uses_new_api = any(
        compare_version(package, operator.ge, "3.6.0") for package in ("pytorch-triton-xpu", "triton")
    )
    # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430:
    # newer versions take "256", older ones take "large".
    kernel_args["grf_mode"] = "256" if uses_new_api else "large"