kernels-bot commited on
Commit
2fe9f62
·
verified ·
1 Parent(s): 8b442f4

Uploaded using `kernel-builder`.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch-cuda/__init__.py +2 -10
  2. build/torch-cuda/_ops.py +33 -3
  3. build/torch-cuda/functional/__init__.py +171 -218
  4. build/torch-cuda/functional/backward.py +249 -308
  5. build/torch-cuda/functional/forward.py +72 -120
  6. build/torch-cuda/functional/grouped_gemm.py +0 -0
  7. build/torch-cuda/functional/moe_config.py +0 -581
  8. build/torch-cuda/functional/reduction_over_k_gather.py +0 -3
  9. build/torch-cuda/functional/{topk_softmax.py → topk.py} +158 -13
  10. build/torch-cuda/functional/utils.py +0 -25
  11. build/torch-cuda/metadata.json +2 -0
  12. build/torch-cuda/quack/__init__.py +2 -2
  13. build/torch-cuda/quack/_compile_worker.py +102 -0
  14. build/torch-cuda/quack/activation.py +108 -65
  15. build/torch-cuda/quack/autotuner.py +184 -3
  16. build/torch-cuda/quack/blockscaled_gemm_utils.py +752 -0
  17. build/torch-cuda/quack/broadcast_utils.py +1 -1
  18. build/torch-cuda/quack/cache_utils.py +195 -0
  19. build/torch-cuda/quack/copy_utils.py +635 -66
  20. build/torch-cuda/quack/cross_entropy.py +716 -0
  21. build/torch-cuda/quack/cute_dsl_ptxas.py +105 -19
  22. build/torch-cuda/quack/cute_dsl_utils.py +124 -52
  23. build/torch-cuda/quack/epi_composable.py +187 -0
  24. build/torch-cuda/quack/epi_ops.py +648 -0
  25. build/torch-cuda/quack/epi_utils.py +64 -0
  26. build/torch-cuda/quack/fast_math.py +29 -76
  27. build/torch-cuda/quack/gemm.py +225 -137
  28. build/torch-cuda/quack/gemm_act.py +396 -387
  29. build/torch-cuda/quack/gemm_blockscaled_interface.py +326 -0
  30. build/torch-cuda/quack/gemm_config.py +131 -72
  31. build/torch-cuda/quack/gemm_dact.py +417 -124
  32. build/torch-cuda/quack/gemm_default_epi.py +57 -204
  33. build/torch-cuda/quack/gemm_interface.py +1318 -200
  34. build/torch-cuda/quack/gemm_norm_act.py +400 -0
  35. build/torch-cuda/quack/gemm_sm100.py +0 -0
  36. build/torch-cuda/quack/gemm_sm120.py +626 -0
  37. build/torch-cuda/quack/gemm_sm90.py +316 -355
  38. build/torch-cuda/quack/gemm_sq_reduce.py +259 -0
  39. build/torch-cuda/quack/gemm_symmetric.py +236 -172
  40. build/torch-cuda/quack/gemm_tvm_ffi_utils.py +229 -0
  41. build/torch-cuda/quack/gemm_wrapper_utils.py +0 -317
  42. build/torch-cuda/quack/layout_utils.py +117 -28
  43. build/torch-cuda/quack/linear.py +368 -0
  44. build/torch-cuda/quack/linear_cross_entropy.py +275 -0
  45. build/torch-cuda/quack/mlp.py +331 -0
  46. build/torch-cuda/quack/mx_utils.py +269 -0
  47. build/torch-cuda/quack/nvmmh_heuristic.py +172 -0
  48. build/torch-cuda/quack/pipeline.py +395 -100
  49. build/torch-cuda/quack/reduce.py +2 -2
  50. build/torch-cuda/quack/rms_final_reduce.py +181 -0
build/torch-cuda/__init__.py CHANGED
@@ -2,23 +2,15 @@
2
  # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
  # ********************************************************************************
4
 
5
- from functools import lru_cache
6
-
7
- __version__ = "0.1.1"
8
 
9
  from .enums import KernelBackendMoE
10
-
11
  from .moe import MoE
12
- from .functional import (
13
- enable_quack_gemm,
14
- moe_general_routing_inputs,
15
- moe_TC_softmax_topk_layer,
16
- )
17
 
18
  __all__ = [
19
  "KernelBackendMoE",
20
  "MoE",
21
- "enable_quack_gemm",
22
  "moe_general_routing_inputs",
23
  "moe_TC_softmax_topk_layer",
24
  ]
 
2
  # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
  # ********************************************************************************
4
 
5
+ __version__ = "0.1.2.post1"
 
 
6
 
7
  from .enums import KernelBackendMoE
8
+ from .functional import moe_general_routing_inputs, moe_TC_softmax_topk_layer
9
  from .moe import MoE
 
 
 
 
 
10
 
11
  __all__ = [
12
  "KernelBackendMoE",
13
  "MoE",
 
14
  "moe_general_routing_inputs",
15
  "moe_TC_softmax_topk_layer",
16
  ]
build/torch-cuda/_ops.py CHANGED
@@ -1,8 +1,38 @@
1
  import torch
2
- ops = torch.ops._sonic_moe_2b49d3f
3
 
4
- def add_op_namespace_prefix(op_name: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_sonic_moe_2b49d3f::{op_name}"
 
1
  import torch
 
2
 
3
+ def get_backend() -> str:
4
+ """Detect the backend by inspecting torch."""
5
+ import torch
6
+
7
+ if hasattr(torch, "neuron"):
8
+ # Needs to be sorted before specific Torch builds, since Neuron
9
+ # extension can be loaded into e.g. CUDA Torch builds.
10
+ return "neuron"
11
+ elif torch.version.cuda is not None:
12
+ return "cuda"
13
+ elif torch.version.hip is not None:
14
+ return "rocm"
15
+ elif torch.backends.mps.is_available():
16
+ return "metal"
17
+ elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
18
+ return "xpu"
19
+ else:
20
+ return "cpu"
21
+
22
+
23
+ def _find_ops_name() -> str:
24
+ kernel_name = "sonic_moe"
25
+ unique_id = "a8c39a2"
26
+ backend = get_backend()
27
+ return f"_{kernel_name}_{backend}_{unique_id}"
28
+
29
+
30
+ _OPS_NAME = _find_ops_name()
31
+
32
+ ops = getattr(torch.ops, _OPS_NAME)
33
+
34
+ def add_op_namespace_prefix(op_name: str) -> str:
35
  """
36
  Prefix op by namespace.
37
  """
38
+ return f"{_OPS_NAME}::{op_name}"
build/torch-cuda/functional/__init__.py CHANGED
@@ -6,50 +6,72 @@ import os
6
 
7
  import torch
8
  import torch.nn.functional as F
9
- from ..quack.gemm_interface import gemm
10
 
11
  from ..enums import ActivationType, is_glu
12
- from ..quack_utils import gemm_dgated, gemm_gated
13
  from .backward import (
14
  _down_projection_backward_act,
15
  _down_projection_backward_weight,
16
- _softmax_topk_bwd,
17
  _token_broadcast_backward,
 
18
  _up_projection_backward_act,
19
  _up_projection_backward_weight,
20
  )
21
- from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
22
  from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
23
- from .utils import enable_quack_gemm, is_using_quack_gemm
24
 
25
 
26
  class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
27
  @staticmethod
28
- def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
 
 
29
  T = router_logits.size(0)
30
 
31
- # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy
32
  topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
33
  topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
34
 
35
- _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)
 
 
 
 
 
 
 
 
36
 
37
- ctx.save_for_backward(topk_router_score, topk_router_indices)
 
 
38
  ctx.E = E
39
  ctx.dtype = router_logits.dtype
 
 
40
 
41
  return topk_router_score, topk_router_indices
42
 
43
  @staticmethod
44
- def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
45
  T, K = dtopk_score.size()
46
-
47
- topk_router_score, topk_router_indices = ctx.saved_tensors
48
  dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
49
 
50
- _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- return dlogits, None, None
53
 
54
 
55
  class _UpProjection(torch.autograd.Function):
@@ -62,14 +84,14 @@ class _UpProjection(torch.autograd.Function):
62
  expert_frequency_offset: torch.Tensor,
63
  total_expert_freq: int,
64
  K: int,
65
- stream_id: int,
66
  x_gather_idx: torch.Tensor,
67
  s_scatter_idx: torch.Tensor,
68
  s_reverse_scatter_idx: torch.Tensor,
69
  num_activated_expert_per_token_offset: torch.Tensor,
70
- is_varlen_K: bool,
71
  activation_type: ActivationType,
72
  is_inference_mode_enabled: bool,
 
73
  ) -> torch.Tensor:
74
  T, H = x.shape
75
  I, H, E = w1.shape
@@ -78,34 +100,25 @@ class _UpProjection(torch.autograd.Function):
78
  I //= 2
79
  TK = total_expert_freq
80
 
81
- if is_using_quack_gemm():
82
- assert not torch.compiler.is_compiling()
83
- assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet"
84
- z, y1 = gemm_gated(
85
- x,
86
- w1.permute(2, 1, 0),
87
- activation="swiglu",
88
- cu_seqlens_m=expert_frequency_offset,
89
- A_idx=x_gather_idx,
90
- dynamic_scheduler=False,
91
- )
92
- else:
93
- z = torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device)
94
- y1 = torch.empty(TK, I, dtype=x.dtype, device=x.device)
95
- _up_projection_forward(
96
- x=x,
97
- w1=w1,
98
- z=z,
99
- y1=y1,
100
- b1=b1,
101
- expert_frequency_offset=expert_frequency_offset,
102
- expert_schedule_order=None,
103
- x_gather_idx=x_gather_idx,
104
- stream_id=stream_id,
105
- activation_type=activation_type.value,
106
- is_glu_activation=is_glu_activation,
107
- is_inference_mode_enabled=is_inference_mode_enabled,
108
- )
109
 
110
  ctx.T = T
111
  ctx.TK = TK
@@ -113,9 +126,9 @@ class _UpProjection(torch.autograd.Function):
113
  ctx.K = K
114
  ctx.H = H
115
  ctx.I = I
116
- ctx.is_varlen_K = is_varlen_K
117
  ctx.is_glu_activation = is_glu_activation
118
- ctx.stream_id = stream_id
119
 
120
  ctx.save_for_backward(
121
  x,
@@ -128,26 +141,21 @@ class _UpProjection(torch.autograd.Function):
128
  num_activated_expert_per_token_offset,
129
  )
130
 
131
- ctx.mark_non_differentiable(y1)
132
  ctx.set_materialize_grads(False)
133
 
134
- return y1, z
135
 
136
  @staticmethod
137
- def backward(ctx, _: None, dz: torch.Tensor):
138
- is_compiling = torch.compiler.is_compiling()
139
-
140
- if not is_compiling:
141
- assert _ is None
142
-
143
  T = ctx.T
144
  TK = ctx.TK
145
  E = ctx.E
146
  K = ctx.K
147
  H = ctx.H
148
  is_glu_activation = ctx.is_glu_activation
149
- is_varlen_K = ctx.is_varlen_K
150
- stream_id = ctx.stream_id
151
 
152
  (
153
  x,
@@ -160,77 +168,57 @@ class _UpProjection(torch.autograd.Function):
160
  num_activated_expert_per_token_offset,
161
  ) = ctx.saved_tensors
162
 
 
163
  dw1 = torch.empty_like(w1)
164
  db1 = None if b1 is None else torch.empty_like(b1)
165
 
166
- if is_using_quack_gemm():
167
- assert not is_compiling
168
-
169
- gemm(
170
- x.T,
171
- dz,
172
- out=dw1.permute(2, 1, 0),
173
- cu_seqlens_k=expert_frequency_offset,
174
- A_idx=x_gather_idx,
175
- batch_idx_permute=None,
176
- dynamic_scheduler=False,
177
- )
178
- dx_expanded = gemm(dz, w1.permute(2, 0, 1), cu_seqlens_m=expert_frequency_offset, dynamic_scheduler=False)
179
- else:
180
- dx_expanded = torch.empty(TK, H, dtype=dz.dtype, device=dz.device)
181
-
182
- _up_projection_backward_act(
183
- w1=w1,
184
- dx_expanded=dx_expanded,
185
- dz=dz,
186
- db1=db1,
187
- expert_frequency_offset=expert_frequency_offset,
188
- expert_schedule_order=None,
189
- x_gather_idx=x_gather_idx,
190
- s_scatter_idx=s_scatter_idx,
191
- is_glu_activation=is_glu_activation,
192
- stream_id=stream_id,
193
- )
194
-
195
- _up_projection_backward_weight(
196
- x=x,
197
- dw1=dw1,
198
- dz=dz,
199
- expert_frequency_offset=expert_frequency_offset,
200
- expert_schedule_order=None,
201
- x_gather_idx=x_gather_idx,
202
- is_glu_activation=is_glu_activation,
203
- stream_id=stream_id,
204
- )
205
-
206
- dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
207
 
208
  _token_broadcast_backward(
209
  dx_reduced=dx_reduced,
210
  dx_expanded=dx_expanded,
211
  s_reverse_scatter_idx=s_reverse_scatter_idx,
212
  num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
213
- varlen_K_max=(E if is_varlen_K else K),
214
  H=H,
215
- is_varlen_K=is_varlen_K,
216
  )
217
 
218
- return dx_reduced, dw1, db1, *[None] * 12
219
 
220
 
221
  class _DownProjection(torch.autograd.Function):
222
  @staticmethod
223
  def forward(
224
  ctx,
225
- y1: torch.Tensor,
226
- z: torch.Tensor,
227
  w2: torch.Tensor,
228
  b2: torch.Tensor | None,
229
  topk_scores: torch.Tensor,
230
  expert_frequency_offset: torch.Tensor,
231
  T: int,
232
  K: int,
233
- stream_id: int,
234
  x_gather_idx: torch.Tensor,
235
  s_scatter_idx: torch.Tensor,
236
  s_reverse_scatter_idx: torch.Tensor,
@@ -238,32 +226,24 @@ class _DownProjection(torch.autograd.Function):
238
  is_varlen_K: bool,
239
  activation_type: ActivationType,
240
  ) -> torch.Tensor:
241
- TK = y1.size(0)
242
  H, I, E = w2.shape
243
 
244
- if is_using_quack_gemm():
245
- assert not torch.compiler.is_compiling()
246
-
247
- assert b2 is None
248
- y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset)
249
- else:
250
- y2 = torch.empty(TK, H, dtype=y1.dtype, device=y1.device)
251
- _down_projection_forward(
252
- w2=w2,
253
- y1=y1,
254
- y2=y2,
255
- b2=b2,
256
- expert_frequency_offset=expert_frequency_offset,
257
- expert_schedule_order=None,
258
- x_gather_idx=x_gather_idx,
259
- stream_id=stream_id,
260
- )
261
-
262
- o = torch.empty(T, H, device=z.device, dtype=z.dtype)
263
- topk_scores = topk_scores.flatten()
264
 
265
  _router_forward(
266
- y2=y2,
267
  o=o,
268
  topk_scores=topk_scores,
269
  s_reverse_scatter_idx=s_reverse_scatter_idx,
@@ -277,17 +257,15 @@ class _DownProjection(torch.autograd.Function):
277
  ctx.K = K
278
  ctx.is_varlen_K = is_varlen_K
279
  ctx.activation_type = activation_type
280
- ctx.stream_id = stream_id
281
 
282
  ctx.save_for_backward(
283
- z,
284
  w2,
285
  b2,
286
  topk_scores,
287
  expert_frequency_offset,
288
  x_gather_idx,
289
  s_scatter_idx,
290
- s_reverse_scatter_idx,
291
  )
292
 
293
  return o
@@ -296,96 +274,58 @@ class _DownProjection(torch.autograd.Function):
296
  def backward(ctx, dout: torch.Tensor):
297
  T = ctx.T
298
  K = ctx.K
299
- stream_id = ctx.stream_id
300
  is_varlen_K = ctx.is_varlen_K
301
  activation_type = ctx.activation_type
302
 
303
  (
304
- z,
305
  w2,
306
  b2,
307
  topk_scores,
308
  expert_frequency_offset,
309
  x_gather_idx,
310
  s_scatter_idx,
311
- s_reverse_scatter_idx,
312
  ) = ctx.saved_tensors
313
 
314
  dw2 = torch.empty_like(w2)
315
  db2 = None if b2 is None else torch.empty_like(b2)
316
- dz = torch.empty_like(z)
317
-
318
- if is_using_quack_gemm():
319
- assert not torch.compiler.is_compiling()
320
- assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
321
-
322
- s = topk_scores[s_scatter_idx]
323
- _, y1s, ds = gemm_dgated(
324
- dout,
325
- w2.permute(2, 0, 1),
326
- PreAct=z,
327
- activation="swiglu",
328
- dx_out=dz,
329
- colvec_scale=s,
330
- colvec_reduce=True,
331
- cu_seqlens_m=expert_frequency_offset,
332
- A_idx=x_gather_idx,
333
- dynamic_scheduler=False,
334
- )
335
- gemm(
336
- dout.T,
337
- y1s,
338
- out=dw2.permute(2, 0, 1),
339
- cu_seqlens_k=expert_frequency_offset,
340
- A_idx=x_gather_idx,
341
- batch_idx_permute=None,
342
- dynamic_scheduler=False,
343
- )
344
-
345
- ds = ds[s_reverse_scatter_idx]
346
- else:
347
- ds = torch.empty_like(topk_scores)
348
-
349
- I = w2.size(1)
350
- TK = x_gather_idx.size(0)
351
-
352
- y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device)
353
- is_glu_activation = is_glu(activation_type)
354
-
355
- _down_projection_backward_act(
356
- dout=dout,
357
- z=z,
358
- w2=w2,
359
- dz=dz,
360
- ds=ds,
361
- b2=b2,
362
- db2=db2,
363
- y1s=y1s,
364
- topk_scores=topk_scores,
365
- expert_frequency_offset=expert_frequency_offset,
366
- expert_schedule_order=None,
367
- x_gather_idx=x_gather_idx,
368
- s_scatter_idx=s_scatter_idx,
369
- is_glu_activation=is_glu_activation,
370
- activation_type=activation_type.value,
371
- stream_id=stream_id,
372
- )
373
-
374
- _down_projection_backward_weight(
375
- dout=dout,
376
- y1s=y1s,
377
- dw2=dw2,
378
- expert_frequency_offset=expert_frequency_offset,
379
- expert_schedule_order=None,
380
- x_gather_idx=x_gather_idx,
381
- stream_id=stream_id,
382
- )
383
 
384
  # TC top-K routing
385
  if not is_varlen_K:
386
  ds = ds.view(T, K)
387
 
388
- return None, dz, dw2, db2, ds, *[None] * 10
389
 
390
 
391
  def moe_TC_softmax_topk_layer(
@@ -399,13 +339,18 @@ def moe_TC_softmax_topk_layer(
399
  stream_id: int,
400
  activation_type: ActivationType | str = ActivationType.SWIGLU,
401
  is_inference_mode_enabled: bool = False,
 
 
 
402
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
403
  assert ((b1 is None) and (b2 is None)) or (
404
  (b1 is not None) and (b2 is not None)
405
  ), "b1 and b2 has to be None or not None at the same time!"
406
  E = router_w.size(0)
407
  router_logits = F.linear(x, router_w)
408
- topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, E, K)
 
 
409
 
410
  T, K = topk_indices.size()
411
  TK = T * K
@@ -421,43 +366,43 @@ def moe_TC_softmax_topk_layer(
421
  topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
422
  )
423
 
424
- T = x.size(0)
425
-
426
  if type(activation_type) == str:
427
  activation_type = ActivationType(activation_type)
428
 
429
- y1, z = _UpProjection.apply(
 
 
 
430
  x,
431
  w1,
432
  b1,
433
  expert_frequency_offset,
434
- T * K,
435
  K,
436
- stream_id,
437
  x_gather_idx,
438
  s_scatter_idx,
439
  s_reverse_scatter_idx,
440
  None,
441
- False, # is_varlen_K
442
  activation_type,
443
  is_inference_mode_enabled,
 
444
  )
445
 
446
  o = _DownProjection.apply(
447
- y1,
448
- z,
449
  w2,
450
  b2,
451
  topk_scores,
452
  expert_frequency_offset,
453
  T,
454
  K,
455
- stream_id,
456
  x_gather_idx,
457
  s_scatter_idx,
458
  s_reverse_scatter_idx,
459
  None,
460
- False, # is_varlen_K
461
  activation_type,
462
  )
463
 
@@ -466,7 +411,9 @@ def moe_TC_softmax_topk_layer(
466
 
467
  # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
468
  # Weight format requirements:
469
- # - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1), must be interleaved [gate_row0, up_row0, gate_row1, up_row1, ...]
 
 
470
  # - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
471
 
472
 
@@ -486,6 +433,7 @@ def moe_general_routing_inputs(
486
  stream_id: int,
487
  activation_type: ActivationType,
488
  is_inference_mode_enabled: bool = False,
 
489
  ) -> tuple[torch.Tensor, torch.Tensor]:
490
  assert ((b1 is None) and (b2 is None)) or (
491
  (b1 is not None) and (b2 is not None)
@@ -496,6 +444,9 @@ def moe_general_routing_inputs(
496
  E = w2.size(-1)
497
  device = router_scores.device
498
 
 
 
 
499
  s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
500
  s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
501
  expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
@@ -516,38 +467,40 @@ def moe_general_routing_inputs(
516
  num_activated_expert_per_token_offset,
517
  )
518
 
519
- y1, z = _UpProjection.apply(
 
 
 
520
  x,
521
  w1,
522
  b1,
523
  expert_frequency_offset,
524
  TK,
525
  None, # K, not needed
526
- stream_id,
527
  x_gather_idx,
528
  s_scatter_idx,
529
  s_reverse_scatter_idx,
530
  num_activated_expert_per_token_offset,
531
- True, # is_varlen_K
532
  activation_type,
533
  is_inference_mode_enabled,
 
534
  )
535
 
536
  o = _DownProjection.apply(
537
- y1,
538
- z,
539
  w2,
540
  b2,
541
  router_scores,
542
  expert_frequency_offset,
543
  T,
544
  None, # K, not needed
545
- stream_id,
546
  x_gather_idx,
547
  s_scatter_idx,
548
  s_reverse_scatter_idx,
549
  num_activated_expert_per_token_offset,
550
- True, # is_varlen_K
551
  activation_type,
552
  )
553
 
 
6
 
7
  import torch
8
  import torch.nn.functional as F
9
+ from ..quack.gemm_interface import gemm, gemm_dgated, gemm_gated
10
 
11
  from ..enums import ActivationType, is_glu
 
12
  from .backward import (
13
  _down_projection_backward_act,
14
  _down_projection_backward_weight,
 
15
  _token_broadcast_backward,
16
+ _topk_softmax_bwd,
17
  _up_projection_backward_act,
18
  _up_projection_backward_weight,
19
  )
20
+ from .forward import _down_projection_forward, _router_forward, _topk_softmax_fwd, _up_projection_forward
21
  from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton
 
22
 
23
 
24
  class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
25
  @staticmethod
26
+ def forward(
27
+ ctx, router_logits: torch.Tensor, E: int, K: int, is_softmax_over_topk: bool, norm_topk_probs: bool
28
+ ) -> tuple[torch.Tensor, torch.Tensor]:
29
  T = router_logits.size(0)
30
 
 
31
  topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
32
  topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
33
 
34
+ _topk_softmax_fwd(
35
+ router_logits,
36
+ topk_router_score,
37
+ topk_router_indices,
38
+ E,
39
+ K,
40
+ is_softmax_over_topk=is_softmax_over_topk,
41
+ norm_topk_probs=norm_topk_probs,
42
+ )
43
 
44
+ # Save router_logits for topk(softmax()) backward (recompute full softmax).
45
+ # For softmax(topk()) it's unused but save unconditionally for simplicity.
46
+ ctx.save_for_backward(topk_router_score, topk_router_indices, router_logits)
47
  ctx.E = E
48
  ctx.dtype = router_logits.dtype
49
+ ctx.is_softmax_over_topk = is_softmax_over_topk
50
+ ctx.norm_topk_probs = norm_topk_probs
51
 
52
  return topk_router_score, topk_router_indices
53
 
54
  @staticmethod
55
+ def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor):
56
  T, K = dtopk_score.size()
57
+ E = ctx.E
58
+ topk_router_score, topk_router_indices, router_logits = ctx.saved_tensors
59
  dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
60
 
61
+ _topk_softmax_bwd(
62
+ router_logits,
63
+ dlogits,
64
+ None,
65
+ dtopk_score,
66
+ topk_router_score,
67
+ topk_router_indices,
68
+ E,
69
+ K,
70
+ is_softmax_over_topk=ctx.is_softmax_over_topk,
71
+ norm_topk_probs=ctx.norm_topk_probs,
72
+ )
73
 
74
+ return dlogits, None, None, None, None
75
 
76
 
77
  class _UpProjection(torch.autograd.Function):
 
84
  expert_frequency_offset: torch.Tensor,
85
  total_expert_freq: int,
86
  K: int,
 
87
  x_gather_idx: torch.Tensor,
88
  s_scatter_idx: torch.Tensor,
89
  s_reverse_scatter_idx: torch.Tensor,
90
  num_activated_expert_per_token_offset: torch.Tensor,
91
+ is_each_token_has_variable_activated_experts: bool,
92
  activation_type: ActivationType,
93
  is_inference_mode_enabled: bool,
94
+ concat_layout: bool = False,
95
  ) -> torch.Tensor:
96
  T, H = x.shape
97
  I, H, E = w1.shape
 
100
  I //= 2
101
  TK = total_expert_freq
102
 
103
+ a = torch.empty(TK, I, dtype=x.dtype, device=x.device)
104
+ h = (
105
+ torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device)
106
+ if (not is_inference_mode_enabled)
107
+ else None
108
+ )
109
+
110
+ _up_projection_forward(
111
+ x=x,
112
+ w1=w1,
113
+ h=h,
114
+ a=a,
115
+ b1=b1,
116
+ expert_frequency_offset=expert_frequency_offset,
117
+ x_gather_idx=x_gather_idx,
118
+ activation_type=activation_type.value,
119
+ is_inference_mode_enabled=is_inference_mode_enabled,
120
+ concat_layout=concat_layout,
121
+ )
 
 
 
 
 
 
 
 
 
122
 
123
  ctx.T = T
124
  ctx.TK = TK
 
126
  ctx.K = K
127
  ctx.H = H
128
  ctx.I = I
129
+ ctx.is_each_token_has_variable_activated_experts = is_each_token_has_variable_activated_experts
130
  ctx.is_glu_activation = is_glu_activation
131
+ ctx.concat_layout = concat_layout
132
 
133
  ctx.save_for_backward(
134
  x,
 
141
  num_activated_expert_per_token_offset,
142
  )
143
 
144
+ ctx.mark_non_differentiable(a)
145
  ctx.set_materialize_grads(False)
146
 
147
+ return a, h
148
 
149
  @staticmethod
150
+ def backward(ctx, _: None, dh: torch.Tensor):
 
 
 
 
 
151
  T = ctx.T
152
  TK = ctx.TK
153
  E = ctx.E
154
  K = ctx.K
155
  H = ctx.H
156
  is_glu_activation = ctx.is_glu_activation
157
+ is_each_token_has_variable_activated_experts = ctx.is_each_token_has_variable_activated_experts
158
+ concat_layout = ctx.concat_layout
159
 
160
  (
161
  x,
 
168
  num_activated_expert_per_token_offset,
169
  ) = ctx.saved_tensors
170
 
171
+ dx_expanded = torch.empty(TK, H, dtype=dh.dtype, device=dh.device)
172
  dw1 = torch.empty_like(w1)
173
  db1 = None if b1 is None else torch.empty_like(b1)
174
 
175
+ _up_projection_backward_act(
176
+ w1=w1,
177
+ dx_expanded=dx_expanded,
178
+ dh=dh,
179
+ db1=db1,
180
+ expert_frequency_offset=expert_frequency_offset,
181
+ is_glu_activation=is_glu_activation,
182
+ concat_layout=concat_layout,
183
+ )
184
+
185
+ _up_projection_backward_weight(
186
+ x=x,
187
+ dw1=dw1,
188
+ dh=dh,
189
+ expert_frequency_offset=expert_frequency_offset,
190
+ x_gather_idx=x_gather_idx,
191
+ is_glu_activation=is_glu_activation,
192
+ concat_layout=concat_layout,
193
+ )
194
+
195
+ dx_reduced = torch.empty(T, H, dtype=dh.dtype, device=dh.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  _token_broadcast_backward(
198
  dx_reduced=dx_reduced,
199
  dx_expanded=dx_expanded,
200
  s_reverse_scatter_idx=s_reverse_scatter_idx,
201
  num_activated_expert_per_token_offset=num_activated_expert_per_token_offset,
202
+ varlen_K_max=(E if is_each_token_has_variable_activated_experts else K),
203
  H=H,
204
+ is_varlen_K=is_each_token_has_variable_activated_experts,
205
  )
206
 
207
+ return dx_reduced, dw1, db1, *[None] * 13
208
 
209
 
210
  class _DownProjection(torch.autograd.Function):
211
  @staticmethod
212
  def forward(
213
  ctx,
214
+ a: torch.Tensor,
215
+ h: torch.Tensor,
216
  w2: torch.Tensor,
217
  b2: torch.Tensor | None,
218
  topk_scores: torch.Tensor,
219
  expert_frequency_offset: torch.Tensor,
220
  T: int,
221
  K: int,
 
222
  x_gather_idx: torch.Tensor,
223
  s_scatter_idx: torch.Tensor,
224
  s_reverse_scatter_idx: torch.Tensor,
 
226
  is_varlen_K: bool,
227
  activation_type: ActivationType,
228
  ) -> torch.Tensor:
229
+ TK = a.size(0)
230
  H, I, E = w2.shape
231
 
232
+ y = torch.empty(TK, H, dtype=a.dtype, device=a.device)
233
+
234
+ _down_projection_forward(
235
+ w2=w2,
236
+ a=a,
237
+ y=y,
238
+ b2=b2,
239
+ expert_frequency_offset=expert_frequency_offset,
240
+ )
241
+
242
+ o = torch.empty(T, H, device=a.device, dtype=a.dtype)
243
+ topk_scores = topk_scores.view(-1)
 
 
 
 
 
 
 
 
244
 
245
  _router_forward(
246
+ y=y,
247
  o=o,
248
  topk_scores=topk_scores,
249
  s_reverse_scatter_idx=s_reverse_scatter_idx,
 
257
  ctx.K = K
258
  ctx.is_varlen_K = is_varlen_K
259
  ctx.activation_type = activation_type
 
260
 
261
  ctx.save_for_backward(
262
+ h,
263
  w2,
264
  b2,
265
  topk_scores,
266
  expert_frequency_offset,
267
  x_gather_idx,
268
  s_scatter_idx,
 
269
  )
270
 
271
  return o
 
274
  def backward(ctx, dout: torch.Tensor):
275
  T = ctx.T
276
  K = ctx.K
 
277
  is_varlen_K = ctx.is_varlen_K
278
  activation_type = ctx.activation_type
279
 
280
  (
281
+ h,
282
  w2,
283
  b2,
284
  topk_scores,
285
  expert_frequency_offset,
286
  x_gather_idx,
287
  s_scatter_idx,
 
288
  ) = ctx.saved_tensors
289
 
290
  dw2 = torch.empty_like(w2)
291
  db2 = None if b2 is None else torch.empty_like(b2)
292
+ dh = torch.empty_like(h)
293
+
294
+ I = w2.size(1)
295
+ TK = x_gather_idx.size(0)
296
+
297
+ a_prime = torch.empty(TK, I, dtype=h.dtype, device=h.device)
298
+ ds = torch.empty_like(topk_scores)
299
+
300
+ _down_projection_backward_act(
301
+ dout=dout,
302
+ h=h,
303
+ w2=w2,
304
+ dh=dh,
305
+ ds=ds,
306
+ b2=b2,
307
+ db2=db2,
308
+ a_prime=a_prime,
309
+ topk_scores=topk_scores,
310
+ expert_frequency_offset=expert_frequency_offset,
311
+ x_gather_idx=x_gather_idx,
312
+ s_scatter_idx=s_scatter_idx,
313
+ activation_type=activation_type.value,
314
+ )
315
+
316
+ _down_projection_backward_weight(
317
+ dout=dout,
318
+ a_prime=a_prime,
319
+ dw2=dw2,
320
+ expert_frequency_offset=expert_frequency_offset,
321
+ x_gather_idx=x_gather_idx,
322
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
  # TC top-K routing
325
  if not is_varlen_K:
326
  ds = ds.view(T, K)
327
 
328
+ return None, dh, dw2, db2, ds, *[None] * 10
329
 
330
 
331
  def moe_TC_softmax_topk_layer(
 
339
  stream_id: int,
340
  activation_type: ActivationType | str = ActivationType.SWIGLU,
341
  is_inference_mode_enabled: bool = False,
342
+ is_softmax_over_topk: bool = True,
343
+ norm_topk_probs: bool = False,
344
+ concat_layout: bool = False,
345
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
346
  assert ((b1 is None) and (b2 is None)) or (
347
  (b1 is not None) and (b2 is not None)
348
  ), "b1 and b2 has to be None or not None at the same time!"
349
  E = router_w.size(0)
350
  router_logits = F.linear(x, router_w)
351
+ topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(
352
+ router_logits, E, K, is_softmax_over_topk, norm_topk_probs
353
+ )
354
 
355
  T, K = topk_indices.size()
356
  TK = T * K
 
366
  topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx
367
  )
368
 
 
 
369
  if type(activation_type) == str:
370
  activation_type = ActivationType(activation_type)
371
 
372
+ assert not torch.compiler.is_compiling()
373
+ assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
374
+
375
+ a, h = _UpProjection.apply(
376
  x,
377
  w1,
378
  b1,
379
  expert_frequency_offset,
380
+ TK,
381
  K,
 
382
  x_gather_idx,
383
  s_scatter_idx,
384
  s_reverse_scatter_idx,
385
  None,
386
+ False, # is_each_token_has_variable_activated_expert
387
  activation_type,
388
  is_inference_mode_enabled,
389
+ concat_layout,
390
  )
391
 
392
  o = _DownProjection.apply(
393
+ a,
394
+ h,
395
  w2,
396
  b2,
397
  topk_scores,
398
  expert_frequency_offset,
399
  T,
400
  K,
 
401
  x_gather_idx,
402
  s_scatter_idx,
403
  s_reverse_scatter_idx,
404
  None,
405
+ False, # is_each_token_has_variable_activated_expert
406
  activation_type,
407
  )
408
 
 
411
 
412
  # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
413
  # Weight format requirements:
414
+ # - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1)
415
+ # concat_layout=False (default): interleaved [gate_row0, up_row0, gate_row1, up_row1, ...]
416
+ # concat_layout=True: concatenated [gate_row0, ..., gate_row_{I-1}, up_row0, ..., up_row_{I-1}]
417
  # - w2_weight: Shape (H, I, E), stride order (2, 0, 1)
418
 
419
 
 
433
  stream_id: int,
434
  activation_type: ActivationType,
435
  is_inference_mode_enabled: bool = False,
436
+ concat_layout: bool = False,
437
  ) -> tuple[torch.Tensor, torch.Tensor]:
438
  assert ((b1 is None) and (b2 is None)) or (
439
  (b1 is not None) and (b2 is not None)
 
444
  E = w2.size(-1)
445
  device = router_scores.device
446
 
447
+ if router_scores.dtype != torch.float32:
448
+ router_scores = router_scores.float()
449
+
450
  s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
451
  s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device)
452
  expert_frequency = torch.empty(E, dtype=torch.int32, device=device)
 
467
  num_activated_expert_per_token_offset,
468
  )
469
 
470
+ assert not torch.compiler.is_compiling()
471
+ assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"
472
+
473
+ a, h = _UpProjection.apply(
474
  x,
475
  w1,
476
  b1,
477
  expert_frequency_offset,
478
  TK,
479
  None, # K, not needed
 
480
  x_gather_idx,
481
  s_scatter_idx,
482
  s_reverse_scatter_idx,
483
  num_activated_expert_per_token_offset,
484
+ True, # is_each_token_has_variable_activated_expert
485
  activation_type,
486
  is_inference_mode_enabled,
487
+ concat_layout,
488
  )
489
 
490
  o = _DownProjection.apply(
491
+ a,
492
+ h,
493
  w2,
494
  b2,
495
  router_scores,
496
  expert_frequency_offset,
497
  T,
498
  None, # K, not needed
 
499
  x_gather_idx,
500
  s_scatter_idx,
501
  s_reverse_scatter_idx,
502
  num_activated_expert_per_token_offset,
503
+ True, # is_each_token_has_variable_activated_expert
504
  activation_type,
505
  )
506
 
build/torch-cuda/functional/backward.py CHANGED
@@ -9,16 +9,10 @@ import cutlass.cute as cute
9
  import torch
10
  import triton
11
  import triton.language as tl
 
12
 
13
  from .._ops_compat import add_op_namespace_prefix
14
- from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
15
- from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2
16
- from .moe_config import (
17
- HopperWgmma_MoE_Down_proj_ActGrad_Bwd,
18
- HopperWgmma_MoE_Down_proj_WeightGrad_Bwd,
19
- HopperWgmma_MoE_Up_proj_ActGrad_Bwd,
20
- HopperWgmma_MoE_Up_proj_WeightGrad_Bwd,
21
- )
22
  from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
23
 
24
 
@@ -132,28 +126,29 @@ def _prune_triton_autotune_config(configs, nargs, **kw):
132
  )
133
  @triton.jit
134
  def db1_kernel(
135
- dz_ptr, # (T, H)
136
- db1_ptr, # (E, H),
137
- expert_offset_ptr, # (E+1,), offsets in grouped layout
138
  I: tl.constexpr,
139
  E: tl.constexpr,
140
- BLOCK_I: tl.constexpr, # Block size for H dimension
141
- BLOCK_TK: tl.constexpr, # Block size for token dimension
 
142
  ):
143
- Eidx = tl.program_id(0) # expert id
144
 
145
  E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
146
  E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
147
  n_tokens = E_count_end - E_count_start
148
 
149
  NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
 
150
  for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
151
  i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
152
  i_mask = i_offsets < I
153
 
154
  db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)
155
 
156
- # Process tokens in blocks of BLOCK_TK
157
  for block_start in tl.range(0, n_tokens, BLOCK_TK):
158
  # Token offsets within this block
159
  tk_offsets = block_start + tl.arange(0, BLOCK_TK)
@@ -162,102 +157,52 @@ def db1_kernel(
162
 
163
  dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
164
  dz_mask = tk_mask[:, None] & i_mask[None, :]
165
- dz = tl.load(dz_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32)
166
 
167
- db1_acc += tl.sum(dz, axis=0) # Sum over BLOCK_TK dimension
168
 
169
- db1_offsets = Eidx.to(tl.int64) * I + i_offsets
 
 
 
 
 
170
  tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
171
 
172
 
173
- @triton.jit
174
- def _colsum_smallN_kernel(
175
- y_ptr, # *mut T, shape [M]
176
- x_ptr, # *const T, shape [M, N]
177
- stride_xm: tl.constexpr,
178
- stride_xn: tl.constexpr, # strides of X
179
- stride_y: tl.constexpr, # stride of Y (usually 1)
180
- N: tl.constexpr, # sizes
181
- BLOCK_N: tl.constexpr, # tile size along N
182
- ):
183
- row = tl.program_id(0)
184
-
185
- # assume BLOCK_N >= N
186
- offs = tl.arange(0, BLOCK_N)
187
- mask = offs < N
188
- # Load a tile from the row; cast to fp32 for the reduction
189
- x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32)
190
- # Reduce this tile to a scalar and add
191
- acc = tl.sum(x, axis=0)
192
-
193
- # Store the row-sum (cast back to y dtype)
194
- tl.store(y_ptr + row * stride_y, acc)
195
-
196
-
197
  @torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
198
  def _up_projection_backward_act(
199
  w1: torch.Tensor,
200
  dx_expanded: torch.Tensor,
201
- dz: torch.Tensor,
202
  db1: torch.Tensor | None,
203
  expert_frequency_offset: torch.Tensor,
204
- expert_schedule_order: torch.Tensor | None,
205
- x_gather_idx: torch.Tensor,
206
- s_scatter_idx: torch.Tensor,
207
  is_glu_activation: bool,
208
- stream_id: int,
209
  ) -> None:
210
  I, H, E = w1.size()
211
  if is_glu_activation:
212
  I //= 2
213
 
 
 
 
 
 
 
 
 
 
214
  # db1 computation
215
  if db1 is not None:
216
- db1_kernel[(E,)](dz, db1, expert_frequency_offset, (2 * I if is_glu_activation else I), E)
217
-
218
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
219
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
220
- mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
221
- mDz = convert_torch_tensor_to_cute_tensor(dz, (0, 1), 1, 16, 8, stream=stream_id)
222
- mDx_expanded = convert_torch_tensor_to_cute_tensor(dx_expanded, (0, 1), 1, 16, 8, stream=stream_id)
223
- mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
224
-
225
- if expert_schedule_order is None:
226
- mE_permute_order = None
227
- else:
228
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
229
- current_stream = cuda.CUstream(stream_id)
230
-
231
- compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
232
- if compile_dx_key not in _up_projection_backward_act.compile_cache:
233
- dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
234
- tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
235
- _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
236
- dx_module,
237
- mDz,
238
- mW1_trans,
239
- mDx_expanded,
240
- mE_offset,
241
- mX_gather,
242
- mS_scatter,
243
- tensormaps,
244
- mE_permute_order,
245
- current_stream,
246
  )
247
- _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps
248
-
249
- dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"]
250
- _up_projection_backward_act.compile_cache[compile_dx_key](
251
- mDz,
252
- mW1_trans,
253
- mDx_expanded,
254
- mE_offset,
255
- mX_gather,
256
- mS_scatter,
257
- dx_tensormaps,
258
- mE_permute_order,
259
- current_stream,
260
- )
261
 
262
 
263
  _up_projection_backward_act.compile_cache = {}
@@ -267,199 +212,87 @@ _up_projection_backward_act.compile_cache = {}
267
  def _up_projection_backward_weight(
268
  x: torch.Tensor,
269
  dw1: torch.Tensor,
270
- dz: torch.Tensor,
271
  expert_frequency_offset: torch.Tensor,
272
- expert_schedule_order: torch.Tensor | None,
273
  x_gather_idx: torch.Tensor,
274
  is_glu_activation: bool,
275
- stream_id: int,
276
  ) -> None:
277
  I, H, E = dw1.size()
278
  if is_glu_activation:
279
  I //= 2
280
 
281
- x = x.detach()
282
-
283
- mDz_trans = convert_torch_tensor_to_cute_tensor(dz.T, (1, 0), 0, 16, 8, stream=stream_id)
284
- mDw1_trans = convert_torch_tensor_to_cute_tensor(dw1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
285
-
286
- mX_trans = convert_torch_tensor_to_cute_tensor(x.T, (1, 0), 0, 16, 8, stream=stream_id)
287
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
288
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
289
-
290
- if expert_schedule_order is None:
291
- mE_permute_order = None
292
- else:
293
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
294
- current_stream = cuda.CUstream(stream_id)
295
-
296
- compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
297
- if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
298
- dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
299
- tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
300
- _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
301
- dw1_module,
302
- mX_trans,
303
- mDz_trans,
304
- mDw1_trans,
305
- mE_offset,
306
- mX_gather,
307
- tensormaps,
308
- mE_permute_order,
309
- current_stream,
310
- )
311
- _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps
312
-
313
- dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"]
314
- _up_projection_backward_weight.compile_cache[compile_dw1_key](
315
- mX_trans,
316
- mDz_trans,
317
- mDw1_trans,
318
- mE_offset,
319
- mX_gather,
320
- dw1_tensormaps,
321
- mE_permute_order,
322
- current_stream,
323
  )
324
 
325
 
326
  _up_projection_backward_weight.compile_cache = {}
327
 
328
 
329
- @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dz", "ds", "db2", "y1s"})
330
  def _down_projection_backward_act(
331
  dout: torch.Tensor,
332
- z: torch.Tensor,
333
  w2: torch.Tensor,
334
- dz: torch.Tensor,
335
  ds: torch.Tensor,
336
  b2: torch.Tensor | None,
337
- db2: torch.Tensor | None,
338
- y1s: torch.Tensor,
339
  topk_scores: torch.Tensor,
340
  expert_frequency_offset: torch.Tensor,
341
- expert_schedule_order: torch.Tensor | None,
342
  x_gather_idx: torch.Tensor,
343
  s_scatter_idx: torch.Tensor,
344
- is_glu_activation: bool,
345
  activation_type: str,
346
- stream_id: int,
347
  ) -> None:
348
- H, I, E = w2.size()
349
- TK = x_gather_idx.size(0)
350
-
351
- dout = dout.detach()
352
- w2 = w2.detach()
353
- topk_scores = topk_scores.detach()
354
-
355
- mDout = convert_torch_tensor_to_cute_tensor(dout, (0, 1), 1, 16, 8, stream=stream_id)
356
- mW2_trans = convert_torch_tensor_to_cute_tensor(w2.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id)
357
- mS = convert_torch_tensor_to_cute_tensor(topk_scores, (0,), 0, 4, 1, stream=stream_id)
358
- if is_glu_activation:
359
- mDz_kernel_input = convert_torch_tensor_to_cute_tensor(
360
- dz.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
361
- )
362
- mZ_kernel_input = convert_torch_tensor_to_cute_tensor(
363
- z.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id
364
- )
365
- else:
366
- mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id)
367
- mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id)
368
-
369
- mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id)
370
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
371
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
372
- mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id)
373
-
374
- if expert_schedule_order is None:
375
- mE_permute_order = None
376
- else:
377
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
378
- current_stream = cuda.CUstream(stream_id)
379
- ds_partial = None
380
-
381
- compile_dz_key = ("dz", E, H, I, z.dtype, activation_type)
382
- if compile_dz_key not in _down_projection_backward_act.compile_cache:
383
- # I don't know why but this sync appears to fix a mysterious initialization bug??
384
- torch.cuda.synchronize()
385
- dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type))
386
- tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)]
387
-
388
- ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1)
389
- ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
390
- mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
391
-
392
- _down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N
393
- _down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile(
394
- dz_module,
395
- mDout,
396
- mW2_trans,
397
- mZ_kernel_input,
398
- mDz_kernel_input,
399
- mY1S,
400
- mS,
401
- mDS_partial,
402
- mE_offset,
403
- mX_gather,
404
- mS_scatter,
405
- tensormaps,
406
- mE_permute_order,
407
- current_stream,
408
- )
409
- _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps
410
-
411
- if ds_partial is None:
412
- ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
413
- ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
414
- mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
415
-
416
- dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"]
417
- _down_projection_backward_act.compile_cache[compile_dz_key](
418
- mDout,
419
- mW2_trans,
420
- mZ_kernel_input,
421
- mDz_kernel_input,
422
- mY1S,
423
- mS,
424
- mDS_partial,
425
- mE_offset,
426
- mX_gather,
427
- mS_scatter,
428
- dz_tensormaps,
429
- mE_permute_order,
430
- current_stream,
431
  )
 
432
 
433
  if db2 is None:
434
- # we don't need to update ds
435
- if ds_partial.size(1) == 1:
436
- ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype))
437
- elif ds_partial.size(1) <= 32:
438
- ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype))
439
- else:
440
- M, N = ds_partial.size()
441
-
442
- _colsum_smallN_kernel[M,](
443
- y_ptr=ds,
444
- x_ptr=ds_partial,
445
- stride_xm=ds_partial.stride(0),
446
- stride_xn=ds_partial.stride(1),
447
- stride_y=1,
448
- N=N,
449
- BLOCK_N=triton.next_power_of_2(N),
450
- )
451
  else:
452
- # db2 and ds update
 
 
 
 
 
 
453
  BLOCK_H = min(triton.next_power_of_2(H), 2048)
454
  NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)
455
-
456
- new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32)
457
 
458
  db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
459
  dout,
460
  topk_scores,
461
  new_ds_partial,
462
- ds_partial,
463
  b2,
464
  db2,
465
  x_gather_idx,
@@ -467,9 +300,9 @@ def _down_projection_backward_act(
467
  expert_frequency_offset,
468
  H,
469
  E,
470
- ds_partial_N,
471
  BLOCK_H=BLOCK_H,
472
- BLOCK_OLD_DS_PARTIAL_N=triton.next_power_of_2(ds_partial_N),
473
  )
474
 
475
  if NUM_H_BLOCKS == 1:
@@ -484,47 +317,19 @@ _down_projection_backward_act.compile_cache = {}
484
  @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
485
  def _down_projection_backward_weight(
486
  dout: torch.Tensor,
487
- y1s: torch.Tensor,
488
  dw2: torch.Tensor,
489
  expert_frequency_offset: torch.Tensor,
490
- expert_schedule_order: torch.Tensor | None,
491
  x_gather_idx: torch.Tensor,
492
- stream_id: int,
493
  ) -> None:
494
- H, I, E = dw2.size()
495
-
496
- mDout_trans = convert_torch_tensor_to_cute_tensor(dout.T, (1, 0), 0, 16, 8, stream=stream_id)
497
- mDw2 = convert_torch_tensor_to_cute_tensor(dw2, (2, 0, 1), 1, 16, 8, stream=stream_id)
498
- mY1S_trans = convert_torch_tensor_to_cute_tensor(y1s.T, (1, 0), 0, 16, 8, stream=stream_id)
499
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
500
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
501
-
502
- if expert_schedule_order is None:
503
- mE_permute_order = None
504
- else:
505
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
506
- current_stream = cuda.CUstream(stream_id)
507
-
508
- compile_dw2_key = ("dw2", E, H, I, dw2.dtype)
509
- if compile_dw2_key not in _down_projection_backward_weight.compile_cache:
510
- dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I)
511
- tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
512
- _down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile(
513
- dw2_module,
514
- mDout_trans,
515
- mY1S_trans,
516
- mDw2,
517
- mE_offset,
518
- mX_gather,
519
- tensormaps,
520
- mE_permute_order,
521
- current_stream,
522
- )
523
- _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps
524
-
525
- dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"]
526
- _down_projection_backward_weight.compile_cache[compile_dw2_key](
527
- mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
528
  )
529
 
530
 
@@ -557,7 +362,7 @@ def _token_broadcast_backward(
557
 
558
 
559
  @triton.jit
560
- def _softmax_bwd_scatter_small_kernel(
561
  dlogits_ptr,
562
  dlogits_full_ptr,
563
  score_ptr,
@@ -597,35 +402,171 @@ def _softmax_bwd_scatter_small_kernel(
597
  tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
598
 
599
 
600
- @torch.library.custom_op(add_op_namespace_prefix("_softmax_topk_bwd"), mutates_args={"dlogits_full"})
601
- def _softmax_topk_bwd(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  dlogits_full: torch.Tensor,
603
  dlogits: Optional[torch.Tensor],
604
  dtopk_score: torch.Tensor,
605
  topk_router_score: torch.Tensor,
606
  topk_router_indices: torch.Tensor,
 
607
  K: int,
 
 
608
  ) -> None:
609
  T = dtopk_score.shape[0]
610
 
611
- _softmax_bwd_scatter_small_kernel[T,](
612
- dlogits,
613
- dlogits_full,
614
- topk_router_score,
615
- dtopk_score,
616
- topk_router_indices,
617
- dlogits_full.stride(0),
618
- dlogits_full.stride(1),
619
- topk_router_score.stride(0),
620
- topk_router_score.stride(1),
621
- dtopk_score.stride(0),
622
- dtopk_score.stride(1),
623
- topk_router_indices.stride(0),
624
- topk_router_indices.stride(1),
625
- K,
626
- triton.next_power_of_2(K),
627
- (dlogits is None),
628
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
 
630
 
631
  @triton.jit
 
9
  import torch
10
  import triton
11
  import triton.language as tl
12
+ from ..quack.gemm_interface import gemm, gemm_dgated
13
 
14
  from .._ops_compat import add_op_namespace_prefix
15
+ from ..utils import get_powers_of_2
 
 
 
 
 
 
 
16
  from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
17
 
18
 
 
126
  )
127
  @triton.jit
128
  def db1_kernel(
129
+ dh_ptr, # (TK, I) — always interleaved
130
+ db1_ptr, # (E, I)
131
+ expert_offset_ptr, # (E+1,)
132
  I: tl.constexpr,
133
  E: tl.constexpr,
134
+ BLOCK_I: tl.constexpr,
135
+ BLOCK_TK: tl.constexpr,
136
+ CONCAT_LAYOUT: tl.constexpr = False,
137
  ):
138
+ Eidx = tl.program_id(0)
139
 
140
  E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64)
141
  E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64)
142
  n_tokens = E_count_end - E_count_start
143
 
144
  NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I)
145
+ I_HALF: tl.constexpr = I // 2
146
  for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1):
147
  i_offsets = Iidx * BLOCK_I + tl.arange(0, BLOCK_I)
148
  i_mask = i_offsets < I
149
 
150
  db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32)
151
 
 
152
  for block_start in tl.range(0, n_tokens, BLOCK_TK):
153
  # Token offsets within this block
154
  tk_offsets = block_start + tl.arange(0, BLOCK_TK)
 
157
 
158
  dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :]
159
  dz_mask = tk_mask[:, None] & i_mask[None, :]
160
+ dz = tl.load(dh_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32)
161
 
162
+ db1_acc += tl.sum(dz, axis=0)
163
 
164
+ # Write: remap interleaved concat if needed
165
+ if CONCAT_LAYOUT:
166
+ out_offsets = i_offsets // 2 + (i_offsets % 2) * I_HALF
167
+ else:
168
+ out_offsets = i_offsets
169
+ db1_offsets = Eidx.to(tl.int64) * I + out_offsets
170
  tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask)
171
 
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  @torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"})
174
  def _up_projection_backward_act(
175
  w1: torch.Tensor,
176
  dx_expanded: torch.Tensor,
177
+ dh: torch.Tensor,
178
  db1: torch.Tensor | None,
179
  expert_frequency_offset: torch.Tensor,
 
 
 
180
  is_glu_activation: bool,
181
+ concat_layout: bool = False,
182
  ) -> None:
183
  I, H, E = w1.size()
184
  if is_glu_activation:
185
  I //= 2
186
 
187
+ gemm(
188
+ dh,
189
+ w1.permute(2, 0, 1),
190
+ cu_seqlens_m=expert_frequency_offset,
191
+ dynamic_scheduler=False,
192
+ out=dx_expanded,
193
+ concat_layout=(("B",) if concat_layout else None),
194
+ )
195
+
196
  # db1 computation
197
  if db1 is not None:
198
+ db1_kernel[(E,)](
199
+ dh,
200
+ db1,
201
+ expert_frequency_offset,
202
+ (2 * I if is_glu_activation else I),
203
+ E,
204
+ CONCAT_LAYOUT=concat_layout and is_glu_activation,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  _up_projection_backward_act.compile_cache = {}
 
212
  def _up_projection_backward_weight(
213
  x: torch.Tensor,
214
  dw1: torch.Tensor,
215
+ dh: torch.Tensor,
216
  expert_frequency_offset: torch.Tensor,
 
217
  x_gather_idx: torch.Tensor,
218
  is_glu_activation: bool,
219
+ concat_layout: bool = False,
220
  ) -> None:
221
  I, H, E = dw1.size()
222
  if is_glu_activation:
223
  I //= 2
224
 
225
+ gemm(
226
+ x.T,
227
+ dh,
228
+ out=dw1.permute(2, 1, 0),
229
+ cu_seqlens_k=expert_frequency_offset,
230
+ A_idx=x_gather_idx,
231
+ batch_idx_permute=None,
232
+ dynamic_scheduler=False,
233
+ concat_layout=(("out",) if concat_layout else None),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
 
236
 
237
  _up_projection_backward_weight.compile_cache = {}
238
 
239
 
240
+ @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dh", "ds", "db2", "a_prime"})
241
  def _down_projection_backward_act(
242
  dout: torch.Tensor,
243
+ h: torch.Tensor,
244
  w2: torch.Tensor,
245
+ dh: torch.Tensor,
246
  ds: torch.Tensor,
247
  b2: torch.Tensor | None,
248
+ db2: torch.Tensor | None, # add impl later
249
+ a_prime: torch.Tensor,
250
  topk_scores: torch.Tensor,
251
  expert_frequency_offset: torch.Tensor,
 
252
  x_gather_idx: torch.Tensor,
253
  s_scatter_idx: torch.Tensor,
 
254
  activation_type: str,
 
255
  ) -> None:
256
+ assert activation_type in (
257
+ "swiglu",
258
+ "geglu",
259
+ ), f"QuACK gemm_gated only supports glu activations, got {activation_type}"
260
+
261
+ s = topk_scores[s_scatter_idx]
262
+ _, _, ds_scattered = gemm_dgated(
263
+ dout,
264
+ w2.permute(2, 0, 1),
265
+ PreAct=h,
266
+ activation=activation_type,
267
+ dx_out=dh,
268
+ postact_out=a_prime,
269
+ colvec_scale=s,
270
+ colvec_reduce=True,
271
+ cu_seqlens_m=expert_frequency_offset,
272
+ A_idx=x_gather_idx,
273
+ dynamic_scheduler=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  )
275
+ ds[s_scatter_idx] = ds_scattered
276
 
277
  if db2 is None:
278
+ ds[s_scatter_idx] = ds_scattered
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  else:
280
+ H = w2.size(0)
281
+ E = expert_frequency_offset.size(0) - 1
282
+ TK = x_gather_idx.size(0)
283
+
284
+ old_ds_partial = torch.empty(TK, 1, device=ds_scattered.device, dtype=ds_scattered.dtype)
285
+ old_ds_partial[s_scatter_idx, 0] = ds_scattered
286
+
287
  BLOCK_H = min(triton.next_power_of_2(H), 2048)
288
  NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H)
289
+ new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, dtype=torch.float32, device=ds.device)
 
290
 
291
  db2_and_ds_kernel[(E, NUM_H_BLOCKS)](
292
  dout,
293
  topk_scores,
294
  new_ds_partial,
295
+ old_ds_partial,
296
  b2,
297
  db2,
298
  x_gather_idx,
 
300
  expert_frequency_offset,
301
  H,
302
  E,
303
+ 1, # OLD_DS_PARTIAL_N = 1
304
  BLOCK_H=BLOCK_H,
305
+ BLOCK_OLD_DS_PARTIAL_N=1,
306
  )
307
 
308
  if NUM_H_BLOCKS == 1:
 
317
  @torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"})
318
  def _down_projection_backward_weight(
319
  dout: torch.Tensor,
320
+ a_prime: torch.Tensor,
321
  dw2: torch.Tensor,
322
  expert_frequency_offset: torch.Tensor,
 
323
  x_gather_idx: torch.Tensor,
 
324
  ) -> None:
325
+ gemm(
326
+ dout.T,
327
+ a_prime,
328
+ out=dw2.permute(2, 0, 1),
329
+ cu_seqlens_k=expert_frequency_offset,
330
+ A_idx=x_gather_idx,
331
+ batch_idx_permute=None,
332
+ dynamic_scheduler=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  )
334
 
335
 
 
362
 
363
 
364
  @triton.jit
365
+ def _softmax_over_topk_bwd_kernel(
366
  dlogits_ptr,
367
  dlogits_full_ptr,
368
  score_ptr,
 
402
  tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
403
 
404
 
405
+ @triton.jit
406
+ def _topk_over_softmax_bwd_kernel(
407
+ logits_ptr, # (T, N) saved router logits
408
+ dlogits_ptr, # (T, N) output gradient
409
+ dscore_ptr, # (T, K) upstream gradient
410
+ idx_ptr, # (T, K) selected indices (int32)
411
+ score_ptr, # (T, K) forward scores (only used for renorm)
412
+ stride_lm: tl.constexpr,
413
+ stride_le: tl.constexpr,
414
+ stride_dm: tl.constexpr,
415
+ stride_dn: tl.constexpr,
416
+ stride_sm: tl.constexpr,
417
+ stride_sn: tl.constexpr,
418
+ stride_im: tl.constexpr,
419
+ stride_ik: tl.constexpr,
420
+ stride_scm: tl.constexpr,
421
+ stride_scn: tl.constexpr,
422
+ E: tl.constexpr,
423
+ K: tl.constexpr,
424
+ BLOCK_E: tl.constexpr,
425
+ BLOCK_K: tl.constexpr,
426
+ norm_topk_probs: tl.constexpr,
427
+ ):
428
+ """
429
+ Full topk(softmax()) backward over ALL E indices.
430
+
431
+ Forward: logits → p = softmax(logits) → [raw, idx] = topk(p, K)
432
+ → scores = raw / sum(raw) (if norm_topk_probs)
433
+
434
+ Backward:
435
+ 1. Recompute p = softmax(logits) over all E
436
+ 2. If renorm: dp_sel = (dscore - dot_s) / S
437
+ Else: dp_sel = dscore
438
+ 3. dot = Σ dp_sel_j * p_sel_j
439
+ 4. Scatter dp_sel into E-wide dp (zero at non-selected)
440
+ 5. dlogits = p * (dp - dot) for all E
441
+ """
442
+ row = tl.program_id(axis=0)
443
+
444
+ e_offs = tl.arange(0, BLOCK_E)
445
+ e_mask = e_offs < E
446
+ logits = tl.load(logits_ptr + row * stride_lm + e_offs * stride_le, mask=e_mask, other=-float("inf")).to(
447
+ tl.float32
448
+ )
449
+ row_max = tl.max(logits, axis=0)
450
+ exp_vals = tl.exp(logits - row_max)
451
+ row_sum = tl.sum(exp_vals, axis=0)
452
+ p = exp_vals / row_sum # (BLOCK_E,)
453
+
454
+ # --- Load K selected indices and upstream gradient ---
455
+ k_offs = tl.arange(0, BLOCK_K)
456
+ k_mask = k_offs < K
457
+ idx = tl.load(
458
+ idx_ptr + row * stride_im + k_offs * stride_ik,
459
+ mask=k_mask,
460
+ other=0,
461
+ ).to(tl.int32)
462
+ g_sel = tl.load(
463
+ dscore_ptr + row * stride_sm + k_offs * stride_sn,
464
+ mask=k_mask,
465
+ other=0,
466
+ ).to(tl.float32)
467
+
468
+ # p at selected indices (gather from global mem; can't index register tensor)
469
+ sel_logits = tl.load(
470
+ logits_ptr + row * stride_lm + idx * stride_le,
471
+ mask=k_mask,
472
+ other=-float("inf"),
473
+ ).to(tl.float32)
474
+ p_sel = tl.exp(sel_logits - row_max) / row_sum # (BLOCK_K,)
475
+
476
+ # --- Backward through optional renormalization ---
477
+ if norm_topk_probs:
478
+ scores = tl.load(
479
+ score_ptr + row * stride_scm + k_offs * stride_scn,
480
+ mask=k_mask,
481
+ other=0,
482
+ ).to(tl.float32)
483
+ dot_s = tl.sum(g_sel * scores, axis=0)
484
+ S = tl.sum(p_sel, axis=0)
485
+ dp_sel = (g_sel - dot_s) / S
486
+ else:
487
+ dp_sel = g_sel
488
+
489
+ # dot = Σ dp_sel_j * p_sel_j
490
+ dot = tl.sum(dp_sel * p_sel, axis=0)
491
+
492
+ # --- Scatter dp_sel into N-wide dp ---
493
+ # dp[i] = dp_sel[k] if i == idx[k], else 0
494
+ # Loop over K (unrolled at compile time since K is constexpr)
495
+ dp = tl.zeros([BLOCK_E], dtype=tl.float32)
496
+ for k_iter in tl.static_range(K):
497
+ cur_dp = tl.sum(tl.where(k_offs == k_iter, dp_sel, 0.0))
498
+ cur_idx = tl.sum(tl.where(k_offs == k_iter, idx, 0))
499
+ dp = tl.where(e_offs == cur_idx, cur_dp, dp)
500
+
501
+ # --- dlogits = p * (dp - dot) for all E ---
502
+ dlogits = p * (dp - dot)
503
+ tl.store(
504
+ dlogits_ptr + row * stride_dm + e_offs * stride_dn,
505
+ dlogits,
506
+ mask=e_mask,
507
+ )
508
+
509
+
510
+ @torch.library.custom_op(add_op_namespace_prefix("_topk_softmax_bwd"), mutates_args={"dlogits_full"})
511
+ def _topk_softmax_bwd(
512
+ router_logits: torch.Tensor,
513
  dlogits_full: torch.Tensor,
514
  dlogits: Optional[torch.Tensor],
515
  dtopk_score: torch.Tensor,
516
  topk_router_score: torch.Tensor,
517
  topk_router_indices: torch.Tensor,
518
+ E: int,
519
  K: int,
520
+ is_softmax_over_topk: bool = True,
521
+ norm_topk_probs: bool = False,
522
  ) -> None:
523
  T = dtopk_score.shape[0]
524
 
525
+ if is_softmax_over_topk:
526
+ # non-selected gradient is zero.
527
+ _softmax_over_topk_bwd_kernel[T,](
528
+ dlogits,
529
+ dlogits_full,
530
+ topk_router_score,
531
+ dtopk_score,
532
+ topk_router_indices,
533
+ dlogits_full.stride(0),
534
+ dlogits_full.stride(1),
535
+ topk_router_score.stride(0),
536
+ topk_router_score.stride(1),
537
+ dtopk_score.stride(0),
538
+ dtopk_score.stride(1),
539
+ topk_router_indices.stride(0),
540
+ topk_router_indices.stride(1),
541
+ K,
542
+ triton.next_power_of_2(K),
543
+ (dlogits is None),
544
+ )
545
+ else:
546
+ # topk(softmax(.)): non-selected gradient is -p_i * dot, NOT zero.
547
+ # must recompute full softmax for the complete Jacobian.
548
+ _topk_over_softmax_bwd_kernel[T,](
549
+ router_logits,
550
+ dlogits_full,
551
+ dtopk_score,
552
+ topk_router_indices,
553
+ topk_router_score,
554
+ router_logits.stride(0),
555
+ router_logits.stride(1),
556
+ dlogits_full.stride(0),
557
+ dlogits_full.stride(1),
558
+ dtopk_score.stride(0),
559
+ dtopk_score.stride(1),
560
+ topk_router_indices.stride(0),
561
+ topk_router_indices.stride(1),
562
+ topk_router_score.stride(0),
563
+ topk_router_score.stride(1),
564
+ E,
565
+ K,
566
+ triton.next_power_of_2(E),
567
+ triton.next_power_of_2(K),
568
+ norm_topk_probs,
569
+ )
570
 
571
 
572
  @triton.jit
build/torch-cuda/functional/forward.py CHANGED
@@ -9,18 +9,21 @@ import triton
9
  import triton.language as tl
10
  from cutlass.cute.runtime import from_dlpack
11
  from ..quack.cute_dsl_utils import torch2cute_dtype_map
 
12
 
13
- from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
14
  from .._ops_compat import add_op_namespace_prefix
15
- from ..utils import convert_torch_tensor_to_cute_tensor
16
- from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd
17
  from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
18
- from .topk_softmax import TopK_Softmax
19
 
20
 
21
  @torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
22
  def _topk_fwd(
23
- x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True
 
 
 
 
 
24
  ) -> None:
25
  """Top-k forward pass.
26
  Args:
@@ -39,9 +42,17 @@ def _topk_fwd(
39
 
40
  x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
41
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
42
- compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion)
 
 
 
 
43
  if compile_key not in _topk_fwd.compile_cache:
44
- topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion)
 
 
 
 
45
  _topk_fwd.compile_cache[compile_key] = cute.compile(
46
  topk_op, x_tensor, values_tensor, indices_tensor, current_stream
47
  )
@@ -51,129 +62,49 @@ def _topk_fwd(
51
  _topk_fwd.compile_cache = {}
52
 
53
 
54
- @torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"z", "y1"})
55
  def _up_projection_forward(
56
  x: torch.Tensor,
57
  w1: torch.Tensor,
58
- z: torch.Tensor,
59
- y1: torch.Tensor,
60
  b1: torch.Tensor | None,
61
  expert_frequency_offset: torch.Tensor,
62
- expert_schedule_order: torch.Tensor,
63
  x_gather_idx: torch.Tensor,
64
- stream_id: int,
65
  activation_type: str,
66
- is_glu_activation: bool,
67
  is_inference_mode_enabled: bool = False,
 
68
  ) -> None:
69
- I, H, E = w1.size()
70
- if is_glu_activation:
71
- I //= 2
72
-
73
- mX = convert_torch_tensor_to_cute_tensor(x.detach(), (0, 1), 1, 16, 8, stream=stream_id)
74
- mW1 = convert_torch_tensor_to_cute_tensor(w1.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
75
- mZ = convert_torch_tensor_to_cute_tensor(z, (0, 1), 1, 16, 8, stream=stream_id)
76
- mY1 = convert_torch_tensor_to_cute_tensor(y1, (0, 1), 1, 16, 8, stream=stream_id)
77
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
78
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
79
-
80
- if expert_schedule_order is None:
81
- mE_permute_order = None
82
- else:
83
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
84
-
85
- if b1 is None:
86
- mB1 = None
87
- else:
88
- mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
89
-
90
- current_stream = cuda.CUstream(stream_id)
91
-
92
- compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
93
- if compile_w1_key not in _up_projection_forward.compile_cache:
94
- w1_module = HopperWgmma_MoE_Up_proj_Fwd(
95
- E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
96
- )
97
- tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
98
- _up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
99
- w1_module,
100
- mX,
101
- mW1,
102
- mZ,
103
- mY1,
104
- mB1,
105
- mE_offset,
106
- mX_gather,
107
- tensormaps[0],
108
- tensormaps[1],
109
- mE_permute_order,
110
- current_stream,
111
- )
112
- _up_projection_forward.compile_cache[TENSORMAP] = tensormaps
113
-
114
- w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP]
115
- _up_projection_forward.compile_cache[compile_w1_key](
116
- mX,
117
- mW1,
118
- mZ,
119
- mY1,
120
- mB1,
121
- mE_offset,
122
- mX_gather,
123
- w1_tensormaps[0],
124
- w1_tensormaps[1],
125
- mE_permute_order,
126
- current_stream,
127
  )
128
 
129
 
130
  _up_projection_forward.compile_cache = {}
131
 
132
 
133
- @torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y2"})
134
  def _down_projection_forward(
135
  w2: torch.Tensor,
136
- y1: torch.Tensor,
137
- y2: torch.Tensor,
138
  b2: torch.Tensor | None,
139
  expert_frequency_offset: torch.Tensor,
140
- expert_schedule_order: torch.Tensor,
141
- x_gather_idx: torch.Tensor,
142
- stream_id: int,
143
  ) -> None:
144
- H, I, E = w2.size()
145
-
146
- mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id)
147
- mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id)
148
- mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id)
149
- mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id)
150
- mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id)
151
-
152
- if expert_schedule_order is None:
153
- mE_permute_order = None
154
- else:
155
- mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
156
-
157
- if b2 is None:
158
- mB2 = None
159
- else:
160
- mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id)
161
-
162
- current_stream = cuda.CUstream(stream_id)
163
-
164
- compile_w2_key = (E, H, I, (b2 is None), w2.dtype)
165
- if compile_w2_key not in _down_projection_forward.compile_cache:
166
- w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I)
167
- tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
168
- _down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
169
- w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
170
- )
171
- _down_projection_forward.compile_cache[TENSORMAP] = tensormaps
172
-
173
- w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP]
174
- _down_projection_forward.compile_cache[compile_w2_key](
175
- mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
176
- )
177
 
178
 
179
  _down_projection_forward.compile_cache = {}
@@ -181,7 +112,7 @@ _down_projection_forward.compile_cache = {}
181
 
182
  @torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
183
  def _router_forward(
184
- y2: torch.Tensor,
185
  o: torch.Tensor,
186
  topk_scores: torch.Tensor,
187
  s_reverse_scatter_idx: torch.Tensor,
@@ -191,7 +122,7 @@ def _router_forward(
191
  is_varlen_K: bool,
192
  ) -> None:
193
  token_gather_and_sum_varlen_K_triton(
194
- y2,
195
  topk_scores,
196
  o,
197
  s_reverse_scatter_idx,
@@ -225,14 +156,35 @@ def _softmax_fwd_small_kernel(
225
  @torch.library.custom_op(
226
  add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
227
  )
228
- def _softmax_topk_fwd(
229
- router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int
 
 
 
 
 
 
230
  ) -> None:
231
- # T = router_logits.shape[0]
232
  if E <= 4096 and K <= 16 and E % 8 == 0:
233
- # fast topk-softmax fusion that covers most common MoE configs
234
- _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True)
 
 
 
 
 
 
235
  else:
236
- topk_results = router_logits.topk(K, dim=-1)
237
- topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype))
238
- topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
 
 
 
 
 
 
 
 
 
 
 
9
  import triton.language as tl
10
  from cutlass.cute.runtime import from_dlpack
11
  from ..quack.cute_dsl_utils import torch2cute_dtype_map
12
+ from ..quack.gemm_interface import gemm, gemm_gated
13
 
 
14
  from .._ops_compat import add_op_namespace_prefix
 
 
15
  from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
16
+ from .topk import Softmax_Over_TopK, TopK_Over_Softmax
17
 
18
 
19
  @torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"})
20
  def _topk_fwd(
21
+ x: torch.Tensor,
22
+ k: int,
23
+ values: torch.Tensor,
24
+ indices: torch.Tensor,
25
+ is_softmax_over_topk: bool,
26
+ norm_topk_probs: bool,
27
  ) -> None:
28
  """Top-k forward pass.
29
  Args:
 
42
 
43
  x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
44
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
45
+ if is_softmax_over_topk:
46
+ compile_key = (input_dtype, output_dtype, N, k, True)
47
+ else:
48
+ compile_key = (input_dtype, output_dtype, N, k, False, norm_topk_probs)
49
+
50
  if compile_key not in _topk_fwd.compile_cache:
51
+ if is_softmax_over_topk:
52
+ topk_op = Softmax_Over_TopK(input_dtype, output_dtype, N, k)
53
+ else:
54
+ topk_op = TopK_Over_Softmax(input_dtype, output_dtype, N, k, norm_topk_probs)
55
+
56
  _topk_fwd.compile_cache[compile_key] = cute.compile(
57
  topk_op, x_tensor, values_tensor, indices_tensor, current_stream
58
  )
 
62
  _topk_fwd.compile_cache = {}
63
 
64
 
65
+ @torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"h", "a"})
66
  def _up_projection_forward(
67
  x: torch.Tensor,
68
  w1: torch.Tensor,
69
+ h: torch.Tensor,
70
+ a: torch.Tensor,
71
  b1: torch.Tensor | None,
72
  expert_frequency_offset: torch.Tensor,
 
73
  x_gather_idx: torch.Tensor,
 
74
  activation_type: str,
 
75
  is_inference_mode_enabled: bool = False,
76
+ concat_layout: bool = False,
77
  ) -> None:
78
+ assert activation_type in (
79
+ "swiglu",
80
+ "geglu",
81
+ ), f"QuACK gemm_gated only supports glu activations, got {activation_type}"
82
+ gemm_gated(
83
+ x,
84
+ w1.permute(2, 1, 0),
85
+ activation=activation_type,
86
+ cu_seqlens_m=expert_frequency_offset,
87
+ A_idx=x_gather_idx,
88
+ preact_out=h,
89
+ postact_out=a,
90
+ store_preact=(not is_inference_mode_enabled),
91
+ bias=b1,
92
+ concat_layout=(("B", "bias") if b1 is not None else ("B",)) if concat_layout else None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
94
 
95
 
96
  _up_projection_forward.compile_cache = {}
97
 
98
 
99
+ @torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y"})
100
  def _down_projection_forward(
101
  w2: torch.Tensor,
102
+ a: torch.Tensor,
103
+ y: torch.Tensor,
104
  b2: torch.Tensor | None,
105
  expert_frequency_offset: torch.Tensor,
 
 
 
106
  ) -> None:
107
+ gemm(a, w2.permute(2, 1, 0), out=y, cu_seqlens_m=expert_frequency_offset, bias=b2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  _down_projection_forward.compile_cache = {}
 
112
 
113
  @torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"})
114
  def _router_forward(
115
+ y: torch.Tensor,
116
  o: torch.Tensor,
117
  topk_scores: torch.Tensor,
118
  s_reverse_scatter_idx: torch.Tensor,
 
122
  is_varlen_K: bool,
123
  ) -> None:
124
  token_gather_and_sum_varlen_K_triton(
125
+ y,
126
  topk_scores,
127
  o,
128
  s_reverse_scatter_idx,
 
156
  @torch.library.custom_op(
157
  add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"}
158
  )
159
+ def _topk_softmax_fwd(
160
+ router_logits: torch.Tensor,
161
+ topk_router_score: torch.Tensor,
162
+ topk_router_indices: torch.Tensor,
163
+ E: int,
164
+ K: int,
165
+ is_softmax_over_topk: bool,
166
+ norm_topk_probs: bool,
167
  ) -> None:
 
168
  if E <= 4096 and K <= 16 and E % 8 == 0:
169
+ _topk_fwd(
170
+ router_logits,
171
+ K,
172
+ topk_router_score,
173
+ topk_router_indices,
174
+ is_softmax_over_topk=is_softmax_over_topk,
175
+ norm_topk_probs=norm_topk_probs,
176
+ )
177
  else:
178
+ if is_softmax_over_topk:
179
+ topk_results = router_logits.topk(K, dim=-1)
180
+ vals = topk_results.values.softmax(dim=-1, dtype=torch.float32)
181
+ topk_router_score.copy_(vals.to(topk_router_score.dtype))
182
+ topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
183
+ else:
184
+ probs = router_logits.softmax(dim=-1, dtype=torch.float32)
185
+ topk_results = probs.topk(K, dim=-1)
186
+ vals = topk_results.values
187
+ if norm_topk_probs:
188
+ vals = vals / vals.sum(dim=-1, keepdim=True)
189
+ topk_router_score.copy_(vals.to(topk_router_score.dtype))
190
+ topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))
build/torch-cuda/functional/grouped_gemm.py DELETED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/functional/moe_config.py DELETED
@@ -1,581 +0,0 @@
1
- # ********************************************************************************
2
- # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
- # ********************************************************************************
4
-
5
- import math
6
- from dataclasses import dataclass
7
-
8
- import cuda.bindings.driver as cuda
9
- import cutlass
10
- import cutlass.cute as cute
11
- import torch
12
- from cutlass import const_expr
13
- from ..quack.tile_scheduler import RasterOrderOption
14
-
15
- from ..enums import ActivationType, is_glu
16
- from .grouped_gemm import HopperWgmma_MoE_kernel
17
-
18
-
19
- LIBRARY_NAME = "cutedsl_kernels"
20
-
21
-
22
- def ceil_div(a: int, b: int):
23
- return int(math.ceil(a / b))
24
-
25
-
26
- @dataclass
27
- class HopperGEMMConfig:
28
- tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64)
29
- cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1)
30
- epi_tile_size: cutlass.Constexpr[int] = 32
31
- ## assume we always use persistent kernel
32
- # is_persistent: cutlass.Constexpr[bool] = True
33
- is_pingpong: cutlass.Constexpr[bool] = False
34
- raster_order: RasterOrderOption = RasterOrderOption.Heuristic
35
- L2_group_size: int = 8
36
- initial_d_epi_stage: cutlass.Constexpr[int] = 4
37
-
38
-
39
- class HopperWgmma_MoE_Up_proj_Fwd:
40
- def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False):
41
- super().__init__()
42
- is_glu_activation = is_glu(activation_type)
43
- if is_glu_activation:
44
- assert (
45
- H % 64 == 0 and H >= 512 and I % 64 == 0
46
- ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
47
- else:
48
- assert (
49
- H % 64 == 0 and H >= 512 and I % 128 == 0
50
- ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
51
- # TODO: this assertion does not mean that the MoE impl prohibits such config.
52
- # Instead, we just do not search for the best configs manually yet for small-shaped MoE
53
- if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
54
- up_config = HopperGEMMConfig(
55
- tile_shape_mnk=(128, 256, 64),
56
- cluster_shape_mnk=(2, 1),
57
- epi_tile_size=(32 if not inference_mode else 64),
58
- is_pingpong=False,
59
- initial_d_epi_stage=2,
60
- raster_order=RasterOrderOption.AlongM,
61
- )
62
- elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
63
- up_config = HopperGEMMConfig(
64
- tile_shape_mnk=(192, 128, 64),
65
- cluster_shape_mnk=(1, 1),
66
- epi_tile_size=(32 if not inference_mode else 64),
67
- is_pingpong=True,
68
- initial_d_epi_stage=8,
69
- raster_order=RasterOrderOption.AlongM,
70
- )
71
- else:
72
- raise NotImplementedError()
73
-
74
- compute_swiglu = False
75
- compute_geglu = False
76
- compute_reglu = False
77
-
78
- compute_relu_sq = False
79
- compute_silu = False
80
- compute_relu = False
81
- compute_gelu = False
82
-
83
- if activation_type == ActivationType.SWIGLU:
84
- compute_swiglu = True
85
- elif activation_type == ActivationType.GEGLU:
86
- compute_geglu = True
87
- elif activation_type == ActivationType.REGLU:
88
- compute_reglu = True
89
-
90
- elif activation_type == ActivationType.RELU_SQ:
91
- compute_relu_sq = True
92
- elif activation_type == ActivationType.RELU:
93
- compute_relu = True
94
- elif activation_type == ActivationType.SILU:
95
- compute_silu = True
96
- elif activation_type == ActivationType.GELU:
97
- compute_gelu = True
98
-
99
- else:
100
- raise NotImplementedError(f"Activation function {activation_type} not supported yet!")
101
-
102
- self.module = HopperWgmma_MoE_kernel(
103
- E,
104
- cutlass.Float32,
105
- up_config.tile_shape_mnk,
106
- (*up_config.cluster_shape_mnk, 1),
107
- pingpong=up_config.is_pingpong,
108
- is_persistent=True,
109
- compute_swiglu=compute_swiglu,
110
- compute_reglu=compute_reglu,
111
- compute_geglu=compute_geglu,
112
- compute_relu_sq=compute_relu_sq,
113
- compute_relu=compute_relu,
114
- compute_silu=compute_silu,
115
- compute_gelu=compute_gelu,
116
- is_A_gather=True,
117
- epi_tile_size=up_config.epi_tile_size,
118
- initial_d_epi_stage=up_config.initial_d_epi_stage,
119
- inference_mode=inference_mode,
120
- )
121
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
122
- up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1]
123
- )
124
- self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
125
-
126
- @cute.jit
127
- def __call__(
128
- self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
129
- ):
130
- return self.module(
131
- mX,
132
- mW1,
133
- None,
134
- mB1,
135
- mZ,
136
- mY1,
137
- None,
138
- None,
139
- mE_offset,
140
- mX_gather,
141
- None,
142
- None,
143
- None,
144
- None,
145
- None,
146
- mD_tensormap,
147
- mY1_tensormap,
148
- None,
149
- mE_permute_order,
150
- const_expr(self.max_active_clusters),
151
- stream,
152
- )
153
-
154
-
155
- class HopperWgmma_MoE_Down_proj_Fwd:
156
- def __init__(self, E: int, H: int, I: int):
157
- super().__init__()
158
- assert (
159
- H % 64 == 0 and H >= 512 and I % 64 == 0
160
- ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
161
- if I >= 1024:
162
- down_config = HopperGEMMConfig(
163
- tile_shape_mnk=(128, 256, 64),
164
- cluster_shape_mnk=(2, 1),
165
- epi_tile_size=32,
166
- is_pingpong=False,
167
- initial_d_epi_stage=4,
168
- raster_order=RasterOrderOption.AlongN,
169
- )
170
- elif I >= 256:
171
- down_config = HopperGEMMConfig(
172
- tile_shape_mnk=(128, 192, 64),
173
- cluster_shape_mnk=(2, 1),
174
- epi_tile_size=(96 if H % 96 == 0 else 64),
175
- is_pingpong=True,
176
- initial_d_epi_stage=5,
177
- raster_order=RasterOrderOption.AlongN,
178
- )
179
- elif I >= 64:
180
- down_config = HopperGEMMConfig(
181
- tile_shape_mnk=(128, 192, 64),
182
- cluster_shape_mnk=(1, 2),
183
- epi_tile_size=64,
184
- is_pingpong=True,
185
- initial_d_epi_stage=8,
186
- raster_order=RasterOrderOption.AlongN,
187
- )
188
- else:
189
- raise NotImplementedError()
190
-
191
- self.module = HopperWgmma_MoE_kernel(
192
- E,
193
- cutlass.Float32,
194
- down_config.tile_shape_mnk,
195
- (*down_config.cluster_shape_mnk, 1),
196
- pingpong=down_config.is_pingpong,
197
- is_persistent=True,
198
- compute_swiglu=False,
199
- is_A_gather=False,
200
- epi_tile_size=down_config.epi_tile_size,
201
- initial_d_epi_stage=down_config.initial_d_epi_stage,
202
- )
203
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
204
- down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1]
205
- )
206
-
207
- @cute.jit
208
- def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream):
209
- # we are not really using mX_gather in the Grouped GEMM,
210
- # but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument
211
- return self.module(
212
- mY1,
213
- mW2,
214
- None,
215
- mB2,
216
- mY2,
217
- None,
218
- None,
219
- None,
220
- mE_offset,
221
- mX_gather,
222
- None,
223
- None,
224
- None,
225
- None,
226
- None,
227
- mD_tensormap,
228
- None,
229
- None,
230
- mE_permute_order,
231
- const_expr(self.max_active_clusters),
232
- stream,
233
- )
234
-
235
-
236
- class HopperWgmma_MoE_Down_proj_ActGrad_Bwd:
237
- def __init__(self, E: int, H: int, I: int, activation_type: ActivationType):
238
- super().__init__()
239
- is_glu_activation = is_glu(activation_type)
240
- if is_glu_activation:
241
- assert (
242
- H % 64 == 0 and H >= 512 and I % 64 == 0
243
- ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
244
- else:
245
- assert (
246
- H % 64 == 0 and H >= 512 and I % 128 == 0
247
- ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
248
-
249
- # heavy register pressure due to pingpong + heavy epilogue
250
- # effectively no alternatives to this config
251
- dz_partial_ds_config = HopperGEMMConfig(
252
- tile_shape_mnk=(128, 128, 64),
253
- cluster_shape_mnk=(2, 1),
254
- epi_tile_size=32,
255
- initial_d_epi_stage=4,
256
- is_pingpong=True,
257
- raster_order=RasterOrderOption.Heuristic,
258
- )
259
-
260
- compute_swiglu = False
261
- compute_geglu = False
262
- compute_reglu = False
263
-
264
- compute_relu_sq = False
265
- compute_silu = False
266
- compute_relu = False
267
- compute_gelu = False
268
-
269
- if activation_type == ActivationType.SWIGLU:
270
- compute_swiglu = True
271
- elif activation_type == ActivationType.GEGLU:
272
- compute_geglu = True
273
- elif activation_type == ActivationType.REGLU:
274
- compute_reglu = True
275
-
276
- elif activation_type == ActivationType.RELU_SQ:
277
- compute_relu_sq = True
278
- elif activation_type == ActivationType.RELU:
279
- compute_relu = True
280
- elif activation_type == ActivationType.SILU:
281
- compute_silu = True
282
- elif activation_type == ActivationType.GELU:
283
- compute_gelu = True
284
-
285
- else:
286
- raise NotImplementedError(f"Activation function {activation_type} not supported yet!")
287
-
288
- self.module = HopperWgmma_MoE_kernel(
289
- E,
290
- cutlass.Float32,
291
- dz_partial_ds_config.tile_shape_mnk,
292
- (*dz_partial_ds_config.cluster_shape_mnk, 1),
293
- pingpong=dz_partial_ds_config.is_pingpong,
294
- is_persistent=True,
295
- compute_swiglu=compute_swiglu,
296
- compute_reglu=compute_reglu,
297
- compute_geglu=compute_geglu,
298
- compute_relu_sq=compute_relu_sq,
299
- compute_relu=compute_relu,
300
- compute_silu=compute_silu,
301
- compute_gelu=compute_gelu,
302
- compute_dz_and_partial_ds_and_y1s=True,
303
- is_A_gather=True,
304
- epi_tile_size=dz_partial_ds_config.epi_tile_size,
305
- initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage,
306
- )
307
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
308
- dz_partial_ds_config.cluster_shape_mnk[0] * dz_partial_ds_config.cluster_shape_mnk[1]
309
- )
310
-
311
- @cute.jit
312
- def __call__(
313
- self,
314
- mDout,
315
- mW2_trans,
316
- mZ_FP32_if_GLU_else_BF16,
317
- mDz_FP32_if_GLU_else_BF16,
318
- mY1S,
319
- mS,
320
- mDS_partial,
321
- mE_offset,
322
- mX_gather,
323
- mS_scatter,
324
- tensormaps,
325
- mE_permute_order,
326
- stream,
327
- ):
328
- return self.module(
329
- mDout,
330
- mW2_trans,
331
- mZ_FP32_if_GLU_else_BF16,
332
- None,
333
- mDz_FP32_if_GLU_else_BF16,
334
- mY1S,
335
- mS,
336
- mDS_partial,
337
- mE_offset,
338
- mX_gather,
339
- None,
340
- mS_scatter,
341
- None,
342
- None,
343
- tensormaps[0],
344
- tensormaps[1],
345
- tensormaps[2],
346
- None,
347
- mE_permute_order,
348
- const_expr(self.max_active_clusters),
349
- stream,
350
- )
351
-
352
-
353
- class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd:
354
- def __init__(self, E: int, H: int, I: int):
355
- super().__init__()
356
- assert (
357
- H % 64 == 0 and H >= 512 and I % 64 == 0
358
- ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
359
-
360
- if I >= 128:
361
- dw2_config = HopperGEMMConfig(
362
- tile_shape_mnk=(128, 256, 64),
363
- cluster_shape_mnk=(2, 1),
364
- epi_tile_size=16,
365
- is_pingpong=False,
366
- initial_d_epi_stage=6,
367
- raster_order=RasterOrderOption.AlongN,
368
- )
369
- elif I == 64:
370
- dw2_config = HopperGEMMConfig(
371
- tile_shape_mnk=(64, 192, 64),
372
- cluster_shape_mnk=(2, 1),
373
- epi_tile_size=32,
374
- is_pingpong=True,
375
- initial_d_epi_stage=6,
376
- raster_order=RasterOrderOption.AlongN,
377
- )
378
- else:
379
- raise NotImplementedError()
380
-
381
- self.module = HopperWgmma_MoE_kernel(
382
- E,
383
- cutlass.Float32,
384
- dw2_config.tile_shape_mnk,
385
- (*dw2_config.cluster_shape_mnk, 1),
386
- pingpong=dw2_config.is_pingpong,
387
- is_persistent=True,
388
- compute_swiglu=False,
389
- compute_weight_gradient=True,
390
- compute_dz_and_partial_ds_and_y1s=False,
391
- is_A_gather=True,
392
- epi_tile_size=dw2_config.epi_tile_size,
393
- initial_d_epi_stage=dw2_config.initial_d_epi_stage,
394
- )
395
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
396
- dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1]
397
- )
398
-
399
- @cute.jit
400
- def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
401
- return self.module(
402
- mDout_trans,
403
- mY1S_trans,
404
- None,
405
- None,
406
- mDw2,
407
- None,
408
- None,
409
- None,
410
- mE_offset,
411
- mX_gather,
412
- None,
413
- None,
414
- None,
415
- tensormaps[0],
416
- None,
417
- None,
418
- None,
419
- None,
420
- mE_permute_order,
421
- const_expr(self.max_active_clusters),
422
- stream,
423
- )
424
-
425
-
426
- class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
427
- def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
428
- super().__init__()
429
- if is_glu_activation:
430
- assert (
431
- H % 64 == 0 and H >= 512 and I % 64 == 0
432
- ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
433
- else:
434
- assert (
435
- H % 64 == 0 and H >= 512 and I % 128 == 0
436
- ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
437
-
438
- if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation):
439
- dx_config = HopperGEMMConfig(
440
- tile_shape_mnk=(128, 256, 64),
441
- cluster_shape_mnk=(2, 1),
442
- epi_tile_size=32,
443
- is_pingpong=False,
444
- initial_d_epi_stage=4,
445
- raster_order=RasterOrderOption.AlongN,
446
- )
447
- elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation):
448
- dx_config = HopperGEMMConfig(
449
- tile_shape_mnk=(128, 192, 64),
450
- cluster_shape_mnk=(2, 1),
451
- epi_tile_size=64,
452
- is_pingpong=True,
453
- initial_d_epi_stage=8,
454
- raster_order=RasterOrderOption.AlongN,
455
- )
456
- else:
457
- raise NotImplementedError()
458
-
459
- self.module = HopperWgmma_MoE_kernel(
460
- E,
461
- cutlass.Float32,
462
- dx_config.tile_shape_mnk,
463
- (*dx_config.cluster_shape_mnk, 1),
464
- pingpong=dx_config.is_pingpong,
465
- is_persistent=True,
466
- compute_swiglu=False,
467
- compute_dz_and_partial_ds_and_y1s=False,
468
- is_A_gather=False,
469
- epi_tile_size=dx_config.epi_tile_size,
470
- )
471
-
472
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
473
- dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1]
474
- )
475
- self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
476
-
477
- @cute.jit
478
- def __call__(
479
- self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
480
- ):
481
- return self.module(
482
- mDz,
483
- mW1_trans,
484
- None,
485
- None,
486
- mDx_expanded,
487
- None,
488
- None,
489
- None,
490
- mE_offset,
491
- mX_gather,
492
- None,
493
- mS_scatter,
494
- None,
495
- None,
496
- None,
497
- tensormaps[0],
498
- tensormaps[1],
499
- None,
500
- mE_permute_order,
501
- const_expr(self.max_active_clusters),
502
- stream,
503
- )
504
-
505
-
506
- class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
507
- def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
508
- super().__init__()
509
- if is_glu_activation:
510
- assert (
511
- H % 64 == 0 and H >= 512 and I % 64 == 0
512
- ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0"
513
- else:
514
- assert (
515
- H % 64 == 0 and H >= 512 and I % 128 == 0
516
- ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0"
517
-
518
- if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation):
519
- dw1_config = HopperGEMMConfig(
520
- tile_shape_mnk=(128, 256, 64),
521
- cluster_shape_mnk=(2, 1),
522
- epi_tile_size=16,
523
- is_pingpong=False,
524
- initial_d_epi_stage=6,
525
- raster_order=RasterOrderOption.Heuristic,
526
- )
527
- elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation):
528
- dw1_config = HopperGEMMConfig(
529
- tile_shape_mnk=(256, 128, 64),
530
- cluster_shape_mnk=(2, 1),
531
- epi_tile_size=16,
532
- is_pingpong=False,
533
- initial_d_epi_stage=6,
534
- raster_order=RasterOrderOption.AlongN,
535
- )
536
- else:
537
- raise NotImplementedError()
538
-
539
- self.module = HopperWgmma_MoE_kernel(
540
- E,
541
- cutlass.Float32,
542
- dw1_config.tile_shape_mnk,
543
- (*dw1_config.cluster_shape_mnk, 1),
544
- pingpong=dw1_config.is_pingpong,
545
- is_persistent=True,
546
- compute_swiglu=False,
547
- compute_weight_gradient=True,
548
- compute_dz_and_partial_ds_and_y1s=False,
549
- is_A_gather=True,
550
- epi_tile_size=dw1_config.epi_tile_size,
551
- )
552
-
553
- self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters(
554
- dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1]
555
- )
556
-
557
- @cute.jit
558
- def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
559
- return self.module(
560
- mX_trans,
561
- mDz_trans,
562
- None,
563
- None,
564
- mDw1_trans,
565
- None,
566
- None,
567
- None,
568
- mE_offset,
569
- mX_gather,
570
- None,
571
- None,
572
- None,
573
- tensormaps[0],
574
- None,
575
- None,
576
- None,
577
- None,
578
- mE_permute_order,
579
- const_expr(self.max_active_clusters),
580
- stream,
581
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/functional/reduction_over_k_gather.py CHANGED
@@ -11,9 +11,6 @@ import triton.language as tl
11
  from ..utils import get_powers_of_2
12
 
13
 
14
- ### This triton impl is equivalent as the cute-dsl impl shown above,
15
- # and also achieves similar memory bandwidth on H100 for large K and H.
16
- # However, for small K and H, this impl is better by autotuning so we use it as the default.
17
  def _get_triton_autotune_configs() -> list[triton.Config]:
18
  configs = []
19
  for BLOCK_H in get_powers_of_2(256, 4096):
 
11
  from ..utils import get_powers_of_2
12
 
13
 
 
 
 
14
  def _get_triton_autotune_configs() -> list[triton.Config]:
15
  configs = []
16
  for BLOCK_H in get_powers_of_2(256, 4096):
build/torch-cuda/functional/{topk_softmax.py → topk.py} RENAMED
@@ -4,12 +4,14 @@
4
 
5
  # this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
6
  import math
 
7
  from typing import Type
8
 
9
  import cuda.bindings.driver as cuda
10
  import cutlass
11
  import cutlass.cute as cute
12
- from ..quack import utils
 
13
  from cutlass import const_expr
14
  from ..quack.sort.bitonic_sort import bitonic_topk
15
  from triton import next_power_of_2
@@ -17,14 +19,23 @@ from triton import next_power_of_2
17
  from ..utils import domain_offset_i64
18
 
19
 
20
- class TopK_Softmax:
 
 
 
 
 
 
 
 
21
  def __init__(
22
  self,
23
  input_dtype: Type[cutlass.Numeric],
24
  output_dtype: Type[cutlass.Numeric],
25
  N: int,
26
  k: int,
27
- require_softmax_fusion: bool = True,
 
28
  ):
29
  self.input_dtype = input_dtype
30
  self.output_dtype = output_dtype
@@ -38,11 +49,13 @@ class TopK_Softmax:
38
  assert N <= 4096 and N % 8 == 0
39
  assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"
40
 
41
- self.require_softmax_fusion = require_softmax_fusion
 
 
 
 
42
 
43
  def _calculate_threads_per_row(self):
44
- # we want num_elems_per_thread >= self.k
45
- # and each thread can handle at most 64 elements
46
  N = self.next_power_of_2_N
47
  num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
48
  return num_threads_per_row
@@ -78,7 +91,7 @@ class TopK_Softmax:
78
  output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)
79
 
80
  num_threads = cute.size(input_tv_layout, mode=[0])
81
- self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout, output_tiler_mn).launch(
82
  grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
83
  block=[num_threads, 1, 1],
84
  stream=stream,
@@ -93,7 +106,6 @@ class TopK_Softmax:
93
  input_tv_layout: cute.Layout,
94
  input_tiler_mn: cute.Shape,
95
  output_tv_layout: cute.Layout,
96
- output_tiler_mn: cute.Shape,
97
  ):
98
  tidx, _, _ = cute.arch.thread_idx()
99
  bidx, _, _ = cute.arch.block_idx()
@@ -106,7 +118,6 @@ class TopK_Softmax:
106
  gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
107
  cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))
108
 
109
- # declare the atoms which will be used later for memory copy
110
  copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
111
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
112
  tXgX = thr_copy_X.partition_S(gX)
@@ -117,7 +128,7 @@ class TopK_Softmax:
117
 
118
  is_even_N = const_expr(shape[1] == input_tiler_mn[1])
119
  tXpX = (
120
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
121
  if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
122
  else None
123
  )
@@ -126,7 +137,67 @@ class TopK_Softmax:
126
  tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
127
  tXrX_f32.store(tXrX.load().to(cutlass.Float32))
128
 
129
- # Encode the indices into the bottom bits of values.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  log_N = int(math.log2(self.next_power_of_2_N))
131
  idx_mask = const_expr((1 << log_N) - 1)
132
  input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
@@ -162,7 +233,8 @@ class TopK_Softmax:
162
  col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
163
  topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
164
 
165
- if const_expr(self.require_softmax_fusion):
 
166
  topk_vals_max = -cutlass.Float32.inf
167
  for i in cutlass.range_constexpr(self.k):
168
  topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)
@@ -175,7 +247,18 @@ class TopK_Softmax:
175
  for i in cutlass.range_constexpr(self.k):
176
  topk_vals[i] = topk_vals[i] / topk_exp_sum
177
 
178
- # Convert cleaned values to output type
 
 
 
 
 
 
 
 
 
 
 
179
  topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
180
  for i in cutlass.range_constexpr(self.k):
181
  topk_vals_out[i] = topk_vals[i].to(mValues.element_type)
@@ -193,3 +276,65 @@ class TopK_Softmax:
193
  for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
194
  cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
195
  cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py
6
  import math
7
+ from enum import Enum
8
  from typing import Type
9
 
10
  import cuda.bindings.driver as cuda
11
  import cutlass
12
  import cutlass.cute as cute
13
+ from ..quack import copy_utils as copy_utils
14
+ from ..quack import utils as utils
15
  from cutlass import const_expr
16
  from ..quack.sort.bitonic_sort import bitonic_topk
17
  from triton import next_power_of_2
 
19
  from ..utils import domain_offset_i64
20
 
21
 
22
+ class _TopKMode(Enum):
23
+ SOFTMAX_OVER_TOPK = "softmax_over_topk" # most common choice: softmax(topk(x))
24
+ TOPK_OVER_SOFTMAX = "topk_over_softmax" # Qwen3: topk(softmax(x))
25
+ TOPK_NO_FUSION = "topk"
26
+
27
+
28
+ class _TopK:
29
+ """Private base class. Use TopK_Softmax, Softmax_TopK, or TopK instead."""
30
+
31
  def __init__(
32
  self,
33
  input_dtype: Type[cutlass.Numeric],
34
  output_dtype: Type[cutlass.Numeric],
35
  N: int,
36
  k: int,
37
+ mode: _TopKMode,
38
+ norm_topk_prob: bool = False,
39
  ):
40
  self.input_dtype = input_dtype
41
  self.output_dtype = output_dtype
 
49
  assert N <= 4096 and N % 8 == 0
50
  assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth"
51
 
52
+ self.mode = mode
53
+ if norm_topk_prob:
54
+ assert mode == _TopKMode.TOPK_OVER_SOFTMAX, "`norm_topk_prob` only works with softmax-then-topk"
55
+
56
+ self.norm_topk_prob = norm_topk_prob
57
 
58
  def _calculate_threads_per_row(self):
 
 
59
  N = self.next_power_of_2_N
60
  num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
61
  return num_threads_per_row
 
91
  output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize)
92
 
93
  num_threads = cute.size(input_tv_layout, mode=[0])
94
+ self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout).launch(
95
  grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1],
96
  block=[num_threads, 1, 1],
97
  stream=stream,
 
106
  input_tv_layout: cute.Layout,
107
  input_tiler_mn: cute.Shape,
108
  output_tv_layout: cute.Layout,
 
109
  ):
110
  tidx, _, _ = cute.arch.thread_idx()
111
  bidx, _, _ = cute.arch.block_idx()
 
118
  gX = cute.local_tile(mX, input_tiler_mn, (0, 0))
119
  cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0))
120
 
 
121
  copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
122
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx)
123
  tXgX = thr_copy_X.partition_S(gX)
 
128
 
129
  is_even_N = const_expr(shape[1] == input_tiler_mn[1])
130
  tXpX = (
131
+ copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
132
  if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N))
133
  else None
134
  )
 
137
  tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32)
138
  tXrX_f32.store(tXrX.load().to(cutlass.Float32))
139
 
140
+ # ------------------------------------------------------------------
141
+ # Softmax-then-TopK: full-row softmax → in-place log-prob transform.
142
+ # ------------------------------------------------------------------
143
+ if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX):
144
+ if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)):
145
+ utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
146
+
147
+ threads_per_row_red = const_expr(self._calculate_threads_per_row())
148
+ num_threads_cta = const_expr(128 if self.next_power_of_2_N <= 16384 else 256)
149
+
150
+ # ---- thread-local (max, sum_exp) pair ----
151
+ local_max = -cutlass.Float32.inf
152
+ for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
153
+ local_max = cute.arch.fmax(tXrX_f32[i], local_max)
154
+
155
+ local_sum = cutlass.Float32(0.0)
156
+ for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
157
+ local_sum = local_sum + cute.math.exp(tXrX_f32[i] - local_max)
158
+
159
+ if const_expr(threads_per_row_red == 1):
160
+ row_max = local_max
161
+ row_sum = local_sum
162
+ else:
163
+ smem = cutlass.utils.SmemAllocator()
164
+ smem_layout = cute.make_ordered_layout((num_threads_cta,), order=(0,))
165
+ smem_max = smem.allocate_tensor(
166
+ cutlass.Float32,
167
+ smem_layout,
168
+ byte_alignment=16,
169
+ )
170
+ smem_sum = smem.allocate_tensor(
171
+ cutlass.Float32,
172
+ smem_layout,
173
+ byte_alignment=16,
174
+ )
175
+ row_in_blk = tidx // threads_per_row_red
176
+
177
+ smem_max[tidx] = local_max
178
+ smem_sum[tidx] = local_sum
179
+ cute.arch.barrier()
180
+
181
+ # Peel first partner: no exp needed
182
+ base = row_in_blk * threads_per_row_red
183
+ row_max = smem_max[base]
184
+ row_sum = smem_sum[base]
185
+
186
+ for p in cutlass.range_constexpr(1, self._calculate_threads_per_row()):
187
+ p_max = smem_max[base + p]
188
+ p_sum = smem_sum[base + p]
189
+ if p_max > row_max:
190
+ row_sum = row_sum * cute.math.exp(row_max - p_max) + p_sum
191
+ row_max = p_max
192
+ else:
193
+ row_sum = row_sum + p_sum * cute.math.exp(p_max - row_max)
194
+
195
+ # In-place logit → log-probability
196
+ log_normalizer = row_max + cute.math.log(row_sum)
197
+ for i in cutlass.range_constexpr(cute.size(tXrX_f32)):
198
+ tXrX_f32[i] = tXrX_f32[i] - log_normalizer
199
+
200
+ # Encode indices into mantissa low bits.
201
  log_N = int(math.log2(self.next_power_of_2_N))
202
  idx_mask = const_expr((1 << log_N) - 1)
203
  input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0])
 
233
  col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
234
  topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
235
 
236
+ # TopK-then-Softmax
237
+ if const_expr(self.mode == _TopKMode.SOFTMAX_OVER_TOPK):
238
  topk_vals_max = -cutlass.Float32.inf
239
  for i in cutlass.range_constexpr(self.k):
240
  topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max)
 
247
  for i in cutlass.range_constexpr(self.k):
248
  topk_vals[i] = topk_vals[i] / topk_exp_sum
249
 
250
+ # Softmax-then-TopK: recover probabilities from log-probs.
251
+ if const_expr(self.mode == _TopKMode.TOPK_OVER_SOFTMAX):
252
+ for i in cutlass.range_constexpr(self.k):
253
+ topk_vals[i] = cute.math.exp(topk_vals[i])
254
+
255
+ if const_expr(self.norm_topk_prob):
256
+ topk_sum = cutlass.Float32(0.0)
257
+ for i in cutlass.range_constexpr(self.k):
258
+ topk_sum = topk_sum + topk_vals[i]
259
+ for i in cutlass.range_constexpr(self.k):
260
+ topk_vals[i] = topk_vals[i] / topk_sum
261
+
262
  topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type)
263
  for i in cutlass.range_constexpr(self.k):
264
  topk_vals_out[i] = topk_vals[i].to(mValues.element_type)
 
276
  for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])):
277
  cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
278
  cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
279
+
280
+
281
+ class Softmax_Over_TopK(_TopK):
282
+ """softmax(topk(x))"""
283
+
284
+ def __init__(
285
+ self,
286
+ input_dtype: Type[cutlass.Numeric],
287
+ output_dtype: Type[cutlass.Numeric],
288
+ N: int,
289
+ k: int,
290
+ ):
291
+ mode = _TopKMode.SOFTMAX_OVER_TOPK
292
+ super().__init__(
293
+ input_dtype=input_dtype,
294
+ output_dtype=output_dtype,
295
+ N=N,
296
+ k=k,
297
+ mode=mode,
298
+ )
299
+
300
+
301
+ class TopK_Over_Softmax(_TopK):
302
+ """Qwen3: topk(softmax(x))
303
+ When norm_topk_prob=True, renormalizes the K selected probabilities to sum to 1.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ input_dtype: Type[cutlass.Numeric],
309
+ output_dtype: Type[cutlass.Numeric],
310
+ N: int,
311
+ k: int,
312
+ norm_topk_prob: bool = True,
313
+ ):
314
+ super().__init__(
315
+ input_dtype=input_dtype,
316
+ output_dtype=output_dtype,
317
+ N=N,
318
+ k=k,
319
+ mode=_TopKMode.TOPK_OVER_SOFTMAX,
320
+ norm_topk_prob=norm_topk_prob,
321
+ )
322
+
323
+
324
+ class TopK(_TopK):
325
+ """Raw topk — no softmax."""
326
+
327
+ def __init__(
328
+ self,
329
+ input_dtype: Type[cutlass.Numeric],
330
+ output_dtype: Type[cutlass.Numeric],
331
+ N: int,
332
+ k: int,
333
+ ):
334
+ super().__init__(
335
+ input_dtype=input_dtype,
336
+ output_dtype=output_dtype,
337
+ N=N,
338
+ k=k,
339
+ mode=_TopKMode.TOPK_NO_FUSION,
340
+ )
build/torch-cuda/functional/utils.py DELETED
@@ -1,25 +0,0 @@
1
- # ********************************************************************************
2
- # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
3
- # ********************************************************************************
4
-
5
- import os
6
- from contextlib import contextmanager
7
-
8
-
9
- _IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1"
10
-
11
-
12
- @contextmanager
13
- def enable_quack_gemm(enable: bool = True):
14
- global _IS_USING_QUACK_GEMM
15
-
16
- previous_value = _IS_USING_QUACK_GEMM
17
- _IS_USING_QUACK_GEMM = enable
18
-
19
- yield
20
-
21
- _IS_USING_QUACK_GEMM = previous_value
22
-
23
-
24
- def is_using_quack_gemm() -> bool:
25
- return _IS_USING_QUACK_GEMM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/metadata.json CHANGED
@@ -1,7 +1,9 @@
1
  {
 
2
  "version": 1,
3
  "license": "Apache-2.0",
4
  "python-depends": [
 
5
  "nvidia-cutlass-dsl"
6
  ],
7
  "backend": {
 
1
  {
2
+ "id": "_sonic_moe_cuda_a8c39a2",
3
  "version": 1,
4
  "license": "Apache-2.0",
5
  "python-depends": [
6
+ "tvm-ffi",
7
  "nvidia-cutlass-dsl"
8
  ],
9
  "backend": {
build/torch-cuda/quack/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- __version__ = "0.2.5"
2
 
3
  import os
4
 
5
  if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
6
- from . import cute_dsl_ptxas
7
 
8
  cute_dsl_ptxas.patch()
 
1
+ __version__ = "0.3.11"
2
 
3
  import os
4
 
5
  if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
6
+ from . import cute_dsl_ptxas # noqa: F401
7
 
8
  cute_dsl_ptxas.patch()
build/torch-cuda/quack/_compile_worker.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ # Persistent subprocess worker for parallel autotuning pre-compilation.
3
+ # Receives length-prefixed pickled tasks on stdin, creates FakeTensors
4
+ # matching the parent's tensor metadata, and compiles with COMPILE_ONLY=True.
5
+ # Stays alive to process multiple configs (amortizes import overhead).
6
+
7
+ import importlib
8
+ import pickle
9
+ import struct
10
+ import sys
11
+
12
+ import torch
13
+ from torch._subclasses.fake_tensor import FakeTensorMode
14
+
15
+ from . import cache_utils
16
+
17
+ cache_utils.COMPILE_ONLY = True
18
+
19
+ _dtype_map = {
20
+ "torch.float16": torch.float16,
21
+ "torch.bfloat16": torch.bfloat16,
22
+ "torch.float32": torch.float32,
23
+ "torch.float64": torch.float64,
24
+ "torch.int32": torch.int32,
25
+ "torch.int64": torch.int64,
26
+ "torch.int8": torch.int8,
27
+ "torch.uint8": torch.uint8,
28
+ "torch.bool": torch.bool,
29
+ }
30
+
31
+
32
+ def _make_fake_tensor(meta):
33
+ shape = meta["shape"]
34
+ stride = meta["stride"]
35
+ dtype = _dtype_map[meta["dtype"]]
36
+ return torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
37
+
38
+
39
+ def _recv(stream):
40
+ """Read a length-prefixed pickled message. Returns None on EOF."""
41
+ header = stream.read(4)
42
+ if len(header) < 4:
43
+ return None
44
+ length = struct.unpack("<I", header)[0]
45
+ if length == 0:
46
+ return None
47
+ data = stream.read(length)
48
+ return pickle.loads(data)
49
+
50
+
51
+ def _send(stream, msg):
52
+ """Write a length-prefixed pickled message."""
53
+ data = pickle.dumps(msg)
54
+ stream.write(struct.pack("<I", len(data)))
55
+ stream.write(data)
56
+ stream.flush()
57
+
58
+
59
+ def main():
60
+ stdin = sys.stdin.buffer
61
+ stdout = sys.stdout.buffer
62
+
63
+ # Signal ready
64
+ _send(stdout, "READY")
65
+
66
+ fn_cache = {}
67
+ while True:
68
+ payload = _recv(stdin)
69
+ if payload is None:
70
+ break
71
+
72
+ fn_module = payload["fn_module"]
73
+ fn_qualname = payload["fn_qualname"]
74
+ fn_key = (fn_module, fn_qualname)
75
+ if fn_key not in fn_cache:
76
+ mod = importlib.import_module(fn_module)
77
+ obj = mod
78
+ for part in fn_qualname.split("."):
79
+ obj = getattr(obj, part)
80
+ fn_cache[fn_key] = getattr(obj, "fn", obj)
81
+ fn = fn_cache[fn_key]
82
+
83
+ tensor_meta = payload["tensor_meta"]
84
+ kwargs = payload["kwargs"]
85
+ config_kwargs = payload["config_kwargs"]
86
+
87
+ with FakeTensorMode():
88
+ fake_args = []
89
+ for meta in tensor_meta:
90
+ if isinstance(meta, dict) and "shape" in meta:
91
+ fake_args.append(_make_fake_tensor(meta))
92
+ else:
93
+ fake_args.append(meta)
94
+ try:
95
+ fn(*fake_args, **kwargs, **config_kwargs)
96
+ _send(stdout, "OK")
97
+ except Exception as e:
98
+ _send(stdout, f"ERR:{e}")
99
+
100
+
101
+ if __name__ == "__main__":
102
+ main()
build/torch-cuda/quack/activation.py CHANGED
@@ -2,18 +2,24 @@
2
 
3
  import math
4
  from typing import Tuple
 
5
 
6
  import cutlass.cute as cute
7
  from cutlass import Float32, Boolean, const_expr
8
  from cutlass.cutlass_dsl import T, dsl_user_op
9
- from cutlass._mlir.dialects import llvm
10
-
11
- from . import utils as utils
12
 
13
 
14
  F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
15
 
16
 
 
 
 
 
 
 
 
17
  @dsl_user_op
18
  def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
19
  return Float32(
@@ -24,7 +30,6 @@ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
24
  "=f,f",
25
  has_side_effects=False,
26
  is_align_stack=False,
27
- asm_dialect=llvm.AsmDialect.AD_ATT,
28
  )
29
  )
30
 
@@ -35,9 +40,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
35
  # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
36
  return 0.5 + 0.5 * tanh(0.5 * x)
37
  else:
38
- x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
39
  tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
40
- return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
41
 
42
 
43
  @dsl_user_op
@@ -75,7 +80,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
75
  return cute.arch.fmax(x, Float32(0.0)) * x
76
  else:
77
  relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
78
- return utils.mul_packed_f32x2(relu_x, x)
79
 
80
 
81
  @dsl_user_op
@@ -98,8 +103,8 @@ def drelu_sq(
98
  return dx, relu_sq_out
99
  else:
100
  relu_x = relu(x)
101
- relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
102
- dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
103
  return dx, relu_sq_out
104
 
105
 
@@ -119,14 +124,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
119
  * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
120
  )
121
  else:
122
- x_sq = utils.mul_packed_f32x2(x, x)
123
- x_sq_scaled = utils.fma_packed_f32x2(
124
  x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
125
  )
126
- z = utils.mul_packed_f32x2(x, x_sq_scaled)
127
  tanh_z = (tanh(z[0]), tanh(z[1]))
128
- x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
129
- return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
130
 
131
 
132
  @dsl_user_op
@@ -167,28 +172,28 @@ def dgelu_tanh_approx(
167
  return dx, gelu_out
168
  else:
169
  # Compute z = x * (c1 + c2 * x^2)
170
- x_sq = utils.mul_packed_f32x2(x, x)
171
- x_sq_scaled = utils.fma_packed_f32x2(
172
  x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
173
  )
174
- z = utils.mul_packed_f32x2(x, x_sq_scaled)
175
  tanh_z = (tanh(z[0]), tanh(z[1]))
176
- half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
177
- gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
178
 
179
  # Compute gradient
180
  # sech^2(z) = 1 - tanh^2(z)
181
- sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
182
  # dz/dx = c1 + 3 * c2 * x^2
183
- dz_dx = utils.fma_packed_f32x2(
184
  x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
185
  )
186
  # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
187
- sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
188
- x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
189
- dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
190
 
191
- dx = utils.mul_packed_f32x2(dout, dgelu)
192
  return dx, gelu_out
193
 
194
 
@@ -204,15 +209,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
204
  )
205
  else:
206
  log2_e = math.log2(math.e)
207
- x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
208
  x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
209
- x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
210
  log_x_exp_p1 = (
211
  cute.math.log2(x_exp_p1[0], fastmath=True),
212
  cute.math.log2(x_exp_p1[1], fastmath=True),
213
  )
214
  ln2 = math.log(2.0)
215
- softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
216
  use_linear_0 = Boolean(x[0] > 20.0)
217
  use_linear_1 = Boolean(x[1] > 20.0)
218
  return (
@@ -241,9 +246,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
241
  # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
242
  return x_half * tanh(x_half) + x_half
243
  else:
244
- x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
245
  tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
246
- return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
247
 
248
 
249
  @dsl_user_op
@@ -251,7 +256,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
251
  if const_expr(not isinstance(x, tuple)):
252
  return silu(x) * y
253
  else:
254
- return utils.mul_packed_f32x2(silu(x), y)
255
 
256
 
257
  @dsl_user_op
@@ -301,20 +306,22 @@ def dswiglu(
301
  # Compute sigmoid(x) and silu(x)
302
  if const_expr(not already_halved):
303
  sigmoid_x = sigmoid(x)
304
- silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
305
  else:
306
  tanh_x = (tanh(x[0]), tanh(x[1]))
307
- sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
308
- silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
309
- silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
310
  # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
311
- sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
312
  sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
313
  )
314
- d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
315
- dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
 
 
316
  dy = silu_x_dout
317
- swiglu_out = utils.mul_packed_f32x2(silu_x, y)
318
  return dx, dy, swiglu_out
319
 
320
 
@@ -334,11 +341,11 @@ def swiglu_oai(
334
  silu_x = x_half * tanh(alpha * x_half) + x_half
335
  return silu_x * y + silu_x
336
  else:
337
- x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
338
- alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
339
  tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
340
- silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
341
- return utils.fma_packed_f32x2(silu_x, y, silu_x)
342
 
343
 
344
  @dsl_user_op
@@ -370,22 +377,22 @@ def dswiglu_oai(
370
  return dx, dy, swiglu_out
371
  else:
372
  # Compute sigmoid(alpha * x)
373
- alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
374
  tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
375
- sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
376
- silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
377
- silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
378
  # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
379
- silu_x_minus_product = utils.fma_packed_f32x2(
380
  silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
381
  )
382
- sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
383
  (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
384
  )
385
- d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
386
- dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
387
  dy = silu_x_dout
388
- swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
389
  return dx, dy, swiglu_out
390
 
391
 
@@ -400,7 +407,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
400
  return sigmoid_x * y # FMUL
401
  else:
402
  sigmoid_x = sigmoid(x)
403
- return utils.mul_packed_f32x2(sigmoid_x, y)
404
 
405
 
406
  @dsl_user_op
@@ -430,11 +437,11 @@ def dglu(
430
  return dx, dy, glu_out
431
  else:
432
  sigmoid_x = sigmoid(x)
433
- sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
434
- glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
435
  # dx = (y - glu_out) * sigmoid_x_dout
436
- y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
437
- dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
438
  dy = sigmoid_x_dout
439
  return dx, dy, glu_out
440
 
@@ -448,7 +455,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
448
  return cute.arch.fmax(x, Float32(0.0)) * y
449
  else:
450
  relu_x = relu(x)
451
- return utils.mul_packed_f32x2(relu_x, y)
452
 
453
 
454
  @dsl_user_op
@@ -475,10 +482,10 @@ def dreglu(
475
  x0_pos = Boolean(x[0] > 0)
476
  x1_pos = Boolean(x[1] > 0)
477
  relu_x = relu(x)
478
- dout_y = utils.mul_packed_f32x2(dout, y)
479
  dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
480
- dy = utils.mul_packed_f32x2(dout, relu_x)
481
- reglu_out = utils.mul_packed_f32x2(relu_x, y)
482
  return dx, dy, reglu_out
483
 
484
 
@@ -491,7 +498,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
491
  if const_expr(not isinstance(x, tuple)):
492
  return gelu_tanh_approx(x) * y
493
  else:
494
- return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
495
 
496
 
497
  @dsl_user_op
@@ -518,7 +525,43 @@ def dgeglu(
518
  # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
519
  dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
520
  # Compute gradients for geglu
521
- dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
522
- dy = utils.mul_packed_f32x2(gelu_x, dout)
523
- geglu_out = utils.mul_packed_f32x2(gelu_x, y)
524
  return dx, dy, geglu_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import math
4
  from typing import Tuple
5
+ from functools import partial
6
 
7
  import cutlass.cute as cute
8
  from cutlass import Float32, Boolean, const_expr
9
  from cutlass.cutlass_dsl import T, dsl_user_op
10
+ from cutlass._mlir.dialects import llvm, nvvm
 
 
11
 
12
 
13
  F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
14
 
15
 
16
+ sub_packed_f32x2 = partial(
17
+ cute.arch.calc_packed_f32x2_op,
18
+ src_c=None,
19
+ calc_func=nvvm.sub_packed_f32x2,
20
+ )
21
+
22
+
23
  @dsl_user_op
24
  def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
25
  return Float32(
 
30
  "=f,f",
31
  has_side_effects=False,
32
  is_align_stack=False,
 
33
  )
34
  )
35
 
 
40
  # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
41
  return 0.5 + 0.5 * tanh(0.5 * x)
42
  else:
43
+ x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
44
  tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
45
+ return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
46
 
47
 
48
  @dsl_user_op
 
80
  return cute.arch.fmax(x, Float32(0.0)) * x
81
  else:
82
  relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
83
+ return cute.arch.mul_packed_f32x2(relu_x, x)
84
 
85
 
86
  @dsl_user_op
 
103
  return dx, relu_sq_out
104
  else:
105
  relu_x = relu(x)
106
+ relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
107
+ dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
108
  return dx, relu_sq_out
109
 
110
 
 
124
  * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
125
  )
126
  else:
127
+ x_sq = cute.arch.mul_packed_f32x2(x, x)
128
+ x_sq_scaled = cute.arch.fma_packed_f32x2(
129
  x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
130
  )
131
+ z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
132
  tanh_z = (tanh(z[0]), tanh(z[1]))
133
+ x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
134
+ return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
135
 
136
 
137
  @dsl_user_op
 
172
  return dx, gelu_out
173
  else:
174
  # Compute z = x * (c1 + c2 * x^2)
175
+ x_sq = cute.arch.mul_packed_f32x2(x, x)
176
+ x_sq_scaled = cute.arch.fma_packed_f32x2(
177
  x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
178
  )
179
+ z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
180
  tanh_z = (tanh(z[0]), tanh(z[1]))
181
+ half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
182
+ gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
183
 
184
  # Compute gradient
185
  # sech^2(z) = 1 - tanh^2(z)
186
+ sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
187
  # dz/dx = c1 + 3 * c2 * x^2
188
+ dz_dx = cute.arch.fma_packed_f32x2(
189
  x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
190
  )
191
  # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
192
+ sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
193
+ x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
194
+ dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
195
 
196
+ dx = cute.arch.mul_packed_f32x2(dout, dgelu)
197
  return dx, gelu_out
198
 
199
 
 
209
  )
210
  else:
211
  log2_e = math.log2(math.e)
212
+ x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
213
  x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
214
+ x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
215
  log_x_exp_p1 = (
216
  cute.math.log2(x_exp_p1[0], fastmath=True),
217
  cute.math.log2(x_exp_p1[1], fastmath=True),
218
  )
219
  ln2 = math.log(2.0)
220
+ softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
221
  use_linear_0 = Boolean(x[0] > 20.0)
222
  use_linear_1 = Boolean(x[1] > 20.0)
223
  return (
 
246
  # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
247
  return x_half * tanh(x_half) + x_half
248
  else:
249
+ x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
250
  tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
251
+ return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
252
 
253
 
254
  @dsl_user_op
 
256
  if const_expr(not isinstance(x, tuple)):
257
  return silu(x) * y
258
  else:
259
+ return cute.arch.mul_packed_f32x2(silu(x), y)
260
 
261
 
262
  @dsl_user_op
 
306
  # Compute sigmoid(x) and silu(x)
307
  if const_expr(not already_halved):
308
  sigmoid_x = sigmoid(x)
309
+ silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
310
  else:
311
  tanh_x = (tanh(x[0]), tanh(x[1]))
312
+ sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
313
+ silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
314
+ silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
315
  # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
316
+ sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
317
  sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
318
  )
319
+ d_silu_x_dout = cute.arch.fma_packed_f32x2(
320
+ sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
321
+ )
322
+ dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
323
  dy = silu_x_dout
324
+ swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
325
  return dx, dy, swiglu_out
326
 
327
 
 
341
  silu_x = x_half * tanh(alpha * x_half) + x_half
342
  return silu_x * y + silu_x
343
  else:
344
+ x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
345
+ alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
346
  tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
347
+ silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
348
+ return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
349
 
350
 
351
  @dsl_user_op
 
377
  return dx, dy, swiglu_out
378
  else:
379
  # Compute sigmoid(alpha * x)
380
+ alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
381
  tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
382
+ sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
383
+ silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
384
+ silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
385
  # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
386
+ silu_x_minus_product = cute.arch.fma_packed_f32x2(
387
  silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
388
  )
389
+ sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
390
  (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
391
  )
392
+ d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
393
+ dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
394
  dy = silu_x_dout
395
+ swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
396
  return dx, dy, swiglu_out
397
 
398
 
 
407
  return sigmoid_x * y # FMUL
408
  else:
409
  sigmoid_x = sigmoid(x)
410
+ return cute.arch.mul_packed_f32x2(sigmoid_x, y)
411
 
412
 
413
  @dsl_user_op
 
437
  return dx, dy, glu_out
438
  else:
439
  sigmoid_x = sigmoid(x)
440
+ sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
441
+ glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
442
  # dx = (y - glu_out) * sigmoid_x_dout
443
+ y_minus_glu_out = sub_packed_f32x2(y, glu_out)
444
+ dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
445
  dy = sigmoid_x_dout
446
  return dx, dy, glu_out
447
 
 
455
  return cute.arch.fmax(x, Float32(0.0)) * y
456
  else:
457
  relu_x = relu(x)
458
+ return cute.arch.mul_packed_f32x2(relu_x, y)
459
 
460
 
461
  @dsl_user_op
 
482
  x0_pos = Boolean(x[0] > 0)
483
  x1_pos = Boolean(x[1] > 0)
484
  relu_x = relu(x)
485
+ dout_y = cute.arch.mul_packed_f32x2(dout, y)
486
  dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
487
+ dy = cute.arch.mul_packed_f32x2(dout, relu_x)
488
+ reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
489
  return dx, dy, reglu_out
490
 
491
 
 
498
  if const_expr(not isinstance(x, tuple)):
499
  return gelu_tanh_approx(x) * y
500
  else:
501
+ return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
502
 
503
 
504
  @dsl_user_op
 
525
  # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
526
  dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
527
  # Compute gradients for geglu
528
+ dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
529
+ dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
530
+ geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
531
  return dx, dy, geglu_out
532
+
533
+
534
+ # ============================================================================
535
+ # Activation name -> function maps
536
+ # ============================================================================
537
+
538
+ act_fn_map = {
539
+ None: None,
540
+ "silu": silu,
541
+ "relu": relu,
542
+ "relu_sq": relu_sq,
543
+ "gelu_tanh_approx": gelu_tanh_approx,
544
+ }
545
+
546
+ dact_fn_map = {
547
+ None: None,
548
+ "relu": drelu,
549
+ "relu_sq": drelu_sq,
550
+ "gelu_tanh_approx": dgelu_tanh_approx,
551
+ }
552
+
553
+ gate_fn_map = {
554
+ "swiglu": swiglu,
555
+ "swiglu_oai": swiglu_oai,
556
+ "reglu": reglu,
557
+ "geglu": geglu,
558
+ "glu": glu,
559
+ }
560
+
561
+ dgate_fn_map = {
562
+ "swiglu": dswiglu,
563
+ "swiglu_oai": dswiglu_oai,
564
+ "reglu": dreglu,
565
+ "geglu": dgeglu,
566
+ "glu": dglu,
567
+ }
build/torch-cuda/quack/autotuner.py CHANGED
@@ -25,6 +25,29 @@ PACKAGE_NAME = "quack"
25
  VERSION = __version__
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def get_home_dir():
29
  return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
30
 
@@ -52,6 +75,22 @@ def _base32(key):
52
  return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  class Autotuner:
56
  def __init__(
57
  self,
@@ -124,6 +163,146 @@ class Autotuner:
124
  return partial(triton.testing.do_bench, warmup=5, rep=25)
125
  return self._do_bench
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def _bench(self, *args, config, **meta):
128
  verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
129
  if verbose:
@@ -227,6 +406,8 @@ class Autotuner:
227
 
228
  @torch.compiler.disable # Don't want any tracing here
229
  def benchmark():
 
 
230
  bench_start = time.time()
231
  timings = {
232
  config: self._bench(*args, config=config, **kwargs)
@@ -316,11 +497,11 @@ class AutotuneConfig:
316
  return ", ".join(res)
317
 
318
  def __hash__(self):
319
- return hash(tuple(*self.all_kwargs().items()))
320
 
321
  def __eq__(self, other):
322
- self_tuple = tuple(*self.all_kwargs().items())
323
- other_tuple = tuple(*other.all_kwargs().items())
324
  return self_tuple == other_tuple
325
 
326
 
 
25
  VERSION = __version__
26
 
27
 
28
+ def _get_current_cuda_device() -> str | None:
29
+ """Return the physical CUDA device identifier for the current process.
30
+
31
+ Maps the logical ``torch.cuda.current_device()`` index through
32
+ ``CUDA_VISIBLE_DEVICES`` (if set) so the result is valid as a
33
+ standalone ``CUDA_VISIBLE_DEVICES`` value (handles integer IDs,
34
+ GPU UUIDs, and MIG IDs).
35
+
36
+ Returns ``None`` if CUDA is not initialized or the device cannot
37
+ be determined.
38
+ """
39
+ if not (torch.cuda.is_available() and torch.cuda.is_initialized()):
40
+ return None
41
+ logical_device = torch.cuda.current_device()
42
+ parent_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
43
+ if parent_visible is not None:
44
+ visible_devices = [d.strip() for d in parent_visible.split(",")]
45
+ if logical_device < len(visible_devices):
46
+ return visible_devices[logical_device]
47
+ return None
48
+ return str(logical_device)
49
+
50
+
51
  def get_home_dir():
52
  return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
53
 
 
75
  return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
76
 
77
 
78
+ def _gpu_warmup(duration_ms=200):
79
+ """Saturate the GPU to reach thermal steady-state before benchmarking.
80
+
81
+ Without this, the first autotuning config gets artificially good numbers
82
+ because the GPU hasn't been power-throttled yet.
83
+ """
84
+ a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
85
+ torch.cuda.synchronize()
86
+ target = duration_ms / 1000
87
+ t0 = time.time()
88
+ while time.time() - t0 < target:
89
+ for _ in range(100):
90
+ a = a @ a
91
+ torch.cuda.synchronize()
92
+
93
+
94
  class Autotuner:
95
  def __init__(
96
  self,
 
163
  return partial(triton.testing.do_bench, warmup=5, rep=25)
164
  return self._do_bench
165
 
166
+ def _precompile(self, *args, configs, **kwargs):
167
+ """Pre-compile all configs in parallel subprocesses to populate .o cache.
168
+
169
+ cute.compile() is not thread-safe (MLIR thread-local state) and fork after
170
+ CUDA init causes segfaults. So we spawn persistent subprocess workers: each
171
+ has its own CUDA context, creates FakeTensors matching the parent's tensor
172
+ metadata, and compiles with COMPILE_ONLY=True. Workers stay alive to amortize
173
+ import overhead across multiple configs. The parent then loads instantly from
174
+ the .o cache during benchmarking.
175
+ """
176
+ from .cache_utils import CACHE_ENABLED
177
+
178
+ if not CACHE_ENABLED:
179
+ return
180
+
181
+ max_workers = min(len(configs), int(os.getenv("QUACK_COMPILE_WORKERS", "8")))
182
+ if max_workers <= 1:
183
+ return
184
+
185
+ # Quick check: compile first config in-process. If it loads from .o cache
186
+ # (<0.5s), the rest are likely cached too — skip spawning workers.
187
+ t_check = time.time()
188
+ try:
189
+ current = dict(kwargs, **configs[0].all_kwargs())
190
+ self.fn(*args, **current)
191
+ except Exception:
192
+ pass
193
+ if time.time() - t_check < 0.5:
194
+ return
195
+
196
+ verbose = os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
197
+ if verbose:
198
+ print(f"Pre-compiling {len(configs)} configs with {max_workers} workers")
199
+ t0 = time.time()
200
+
201
+ import pickle
202
+ import struct
203
+ import subprocess
204
+ import sys
205
+
206
+ def _send(stream, msg):
207
+ data = pickle.dumps(msg)
208
+ stream.write(struct.pack("<I", len(data)))
209
+ stream.write(data)
210
+ stream.flush()
211
+
212
+ def _recv(stream):
213
+ header = stream.read(4)
214
+ if len(header) < 4:
215
+ return None
216
+ length = struct.unpack("<I", header)[0]
217
+ return pickle.loads(stream.read(length)) if length else None
218
+
219
+ # Serialize tensor metadata
220
+ tensor_meta = []
221
+ for arg in args:
222
+ if isinstance(arg, Tensor):
223
+ tensor_meta.append(
224
+ {
225
+ "shape": list(arg.shape),
226
+ "stride": list(arg.stride()),
227
+ "dtype": str(arg.dtype),
228
+ }
229
+ )
230
+ else:
231
+ tensor_meta.append(arg)
232
+
233
+ fn_module = self.fn.__module__
234
+ fn_qualname = self.fn.__qualname__
235
+
236
+ # Restrict worker subprocesses to the parent's current CUDA device.
237
+ # Without this, all workers default to cuda:0 and their CUDA context
238
+ # initialization can OOM when many ranks share a node.
239
+ worker_env = os.environ.copy()
240
+ current_device = _get_current_cuda_device()
241
+ if current_device is not None:
242
+ worker_env["CUDA_VISIBLE_DEVICES"] = current_device
243
+
244
+ # Launch persistent worker pool. When vendored under sonic_moe (loaded
245
+ # via kernels.get_kernel), the quack package isn't importable as a
246
+ # top-level module, so invoke the worker via its fully-qualified dotted
247
+ # path and inject PYTHONPATH so the subprocess can import it.
248
+ worker_module = __package__ + "._compile_worker" if __package__ else "quack._compile_worker"
249
+ if __package__:
250
+ import importlib.util
251
+ spec = importlib.util.find_spec(__package__.split(".")[0])
252
+ if spec is not None and spec.submodule_search_locations:
253
+ pkg_parent = os.path.dirname(list(spec.submodule_search_locations)[0])
254
+ existing_pp = worker_env.get("PYTHONPATH", "")
255
+ worker_env["PYTHONPATH"] = (
256
+ f"{pkg_parent}{os.pathsep}{existing_pp}" if existing_pp else pkg_parent
257
+ )
258
+
259
+ workers = []
260
+ for _ in range(max_workers):
261
+ p = subprocess.Popen(
262
+ [sys.executable, "-m", worker_module],
263
+ stdin=subprocess.PIPE,
264
+ stdout=subprocess.PIPE,
265
+ stderr=subprocess.DEVNULL if not verbose else None,
266
+ env=worker_env,
267
+ )
268
+ ready = _recv(p.stdout)
269
+ if ready != "READY":
270
+ p.kill()
271
+ continue
272
+ workers.append(p)
273
+
274
+ if not workers:
275
+ return
276
+
277
+ # Round-robin dispatch configs to workers
278
+ pending = [0] * len(workers)
279
+ for i, config in enumerate(configs):
280
+ w = workers[i % len(workers)]
281
+ _send(
282
+ w.stdin,
283
+ {
284
+ "fn_module": fn_module,
285
+ "fn_qualname": fn_qualname,
286
+ "tensor_meta": tensor_meta,
287
+ "kwargs": kwargs,
288
+ "config_kwargs": config.all_kwargs(),
289
+ },
290
+ )
291
+ pending[i % len(workers)] += 1
292
+
293
+ # Collect all results
294
+ for wi, w in enumerate(workers):
295
+ for _ in range(pending[wi]):
296
+ _recv(w.stdout)
297
+
298
+ # Shutdown workers (close stdin → worker exits)
299
+ for w in workers:
300
+ w.stdin.close()
301
+ w.wait()
302
+
303
+ if verbose:
304
+ print(f"Pre-compilation done in {time.time() - t0:.1f}s")
305
+
306
  def _bench(self, *args, config, **meta):
307
  verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
308
  if verbose:
 
406
 
407
  @torch.compiler.disable # Don't want any tracing here
408
  def benchmark():
409
+ self._precompile(*args, configs=pruned_configs, **kwargs)
410
+ _gpu_warmup()
411
  bench_start = time.time()
412
  timings = {
413
  config: self._bench(*args, config=config, **kwargs)
 
497
  return ", ".join(res)
498
 
499
  def __hash__(self):
500
+ return hash(tuple(self.all_kwargs().items()))
501
 
502
  def __eq__(self, other):
503
+ self_tuple = tuple(self.all_kwargs().items())
504
+ other_tuple = tuple(other.all_kwargs().items())
505
  return self_tuple == other_tuple
506
 
507
 
build/torch-cuda/quack/blockscaled_gemm_utils.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026, Tri Dao.
2
+
3
+ import itertools
4
+ from functools import partial
5
+ from typing import Callable, Optional, Type, Tuple
6
+
7
+ import torch
8
+
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+
12
+ from .compile_utils import make_fake_tensor as fake_tensor
13
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
14
+ from .gemm_default_epi import GemmDefaultSm100
15
+ from .gemm_tvm_ffi_utils import div_for_dtype, make_scheduler_args
16
+ from .mx_utils import (
17
+ to_mx_compiled,
18
+ to_mxfp4_compiled,
19
+ to_nvfp4_compiled,
20
+ )
21
+ from .varlen_utils import VarlenArguments
22
+
23
+
24
+ TORCH_DTYPE_MAP = {
25
+ cutlass.Float4E2M1FN: torch.float4_e2m1fn_x2,
26
+ cutlass.Float16: torch.float16,
27
+ cutlass.BFloat16: torch.bfloat16,
28
+ cutlass.Float32: torch.float32,
29
+ cutlass.Float8E4M3FN: torch.float8_e4m3fn,
30
+ cutlass.Float8E5M2: torch.float8_e5m2,
31
+ cutlass.Float8E8M0FNU: torch.float8_e8m0fnu,
32
+ }
33
+
34
+ FLOAT8_DTYPES = {
35
+ torch.float8_e4m3fn,
36
+ torch.float8_e5m2,
37
+ torch.float8_e8m0fnu,
38
+ }
39
+
40
+
41
+ FP4_E2M1FN_VALUES = (
42
+ 0.0,
43
+ 0.5,
44
+ 1.0,
45
+ 1.5,
46
+ 2.0,
47
+ 3.0,
48
+ 4.0,
49
+ 6.0,
50
+ -0.0,
51
+ -0.5,
52
+ -1.0,
53
+ -1.5,
54
+ -2.0,
55
+ -3.0,
56
+ -4.0,
57
+ -6.0,
58
+ )
59
+
60
+
61
+ def ceil_div(a: int, b: int) -> int:
62
+ return (a + b - 1) // b
63
+
64
+
65
+ def torch_dtype_for_cutlass(dtype: Type[cutlass.Numeric]) -> torch.dtype:
66
+ if dtype not in TORCH_DTYPE_MAP:
67
+ raise TypeError(f"Unsupported dtype: {dtype}")
68
+ return TORCH_DTYPE_MAP[dtype]
69
+
70
+
71
+ def _make_fake_tensor_like(tensor: torch.Tensor, dtype: Type[cutlass.Numeric]) -> cute.Tensor:
72
+ return cute.runtime.make_fake_tensor(
73
+ dtype,
74
+ tensor.shape,
75
+ stride=tensor.stride(),
76
+ assumed_align=16,
77
+ )
78
+
79
+
80
+ def _leading_dim_from_stride(tensor: torch.Tensor) -> int:
81
+ for i, stride in enumerate(tensor.stride()):
82
+ if stride == 1:
83
+ return i
84
+ raise ValueError(
85
+ f"Tensor has no unit stride dimension: shape={tensor.shape}, stride={tensor.stride()}"
86
+ )
87
+
88
+
89
+ def _make_compile_tensor_like(
90
+ tensor: torch.Tensor, dtype: Type[cutlass.Numeric], dynamic_layout: bool = False
91
+ ) -> cute.Tensor:
92
+ compile_tensor = cute.runtime.from_dlpack(tensor)
93
+ compile_tensor.element_type = dtype
94
+ if dynamic_layout:
95
+ marked = compile_tensor.mark_layout_dynamic(leading_dim=_leading_dim_from_stride(tensor))
96
+ if marked is not None:
97
+ compile_tensor = marked
98
+ return compile_tensor
99
+
100
+
101
+ def _make_fake_compact_tensor(
102
+ shape: Tuple[int, ...], dtype: Type[cutlass.Numeric], leading_dim: int
103
+ ) -> cute.Tensor:
104
+ logical_shape = list(shape)
105
+ if dtype == cutlass.Float4E2M1FN:
106
+ logical_shape[leading_dim] *= 2
107
+ return fake_tensor(
108
+ dtype,
109
+ tuple(logical_shape),
110
+ leading_dim=leading_dim,
111
+ divisibility=div_for_dtype(dtype),
112
+ )
113
+
114
+
115
+ def _fp4_e2m1fn_value_table(device: torch.device) -> torch.Tensor:
116
+ return torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32, device=device)
117
+
118
+
119
+ def _pack_fp4_e2m1fn_codes(codes: torch.Tensor) -> torch.Tensor:
120
+ """Pack logical FP4 codes into torch.float4_e2m1fn_x2 storage."""
121
+ if codes.dtype != torch.uint8:
122
+ raise TypeError(f"Expected uint8 FP4 codes, got {codes.dtype}")
123
+ packed_shape = (codes.shape[0], ceil_div(codes.shape[1], 2), codes.shape[2])
124
+ packed = torch.empty(packed_shape, dtype=torch.float4_e2m1fn_x2, device=codes.device)
125
+ packed_u8 = packed.view(torch.uint8)
126
+ low = codes[:, 0::2, :]
127
+ high = torch.zeros_like(low)
128
+ high[:, : codes[:, 1::2, :].shape[1], :] = codes[:, 1::2, :]
129
+ packed_u8.copy_(low | (high << 4))
130
+ return packed
131
+
132
+
133
+ def _create_fp4_operand_tensor(
134
+ l: int,
135
+ mode0: int,
136
+ mode1: int,
137
+ is_mode0_major: bool,
138
+ *,
139
+ init: str,
140
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
141
+ if is_mode0_major:
142
+ raise ValueError("Float4E2M1FN blockscaled operands must be K-major")
143
+ tensor = torch.empty(
144
+ (mode0, ceil_div(mode1, 2), l), dtype=torch.float4_e2m1fn_x2, device="cuda"
145
+ )
146
+ tensor.view(torch.uint8).zero_()
147
+ if init == "empty":
148
+ return None, tensor
149
+ if init != "normal":
150
+ raise ValueError(f"Unsupported init: {init}")
151
+
152
+ magnitudes = torch.randint(0, 8, (mode0, mode1, l), device="cuda", dtype=torch.uint8)
153
+ signs = torch.randint(0, 2, (mode0, mode1, l), device="cuda", dtype=torch.uint8)
154
+ signs = torch.where(magnitudes == 0, torch.zeros_like(signs), signs << 3)
155
+ codes = magnitudes | signs
156
+ tensor.copy_(_pack_fp4_e2m1fn_codes(codes))
157
+ ref = _fp4_e2m1fn_value_table(tensor.device)[codes.long()]
158
+ return ref, tensor
159
+
160
+
161
+ def create_blockscaled_operand_tensor(
162
+ l: int,
163
+ mode0: int,
164
+ mode1: int,
165
+ is_mode0_major: bool,
166
+ dtype: Type[cutlass.Numeric],
167
+ *,
168
+ init: str = "normal",
169
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
170
+ if dtype == cutlass.Float4E2M1FN:
171
+ return _create_fp4_operand_tensor(l, mode0, mode1, is_mode0_major, init=init)
172
+ shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1)
173
+ permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0)
174
+ torch_dtype = torch_dtype_for_cutlass(dtype)
175
+ gen_dtype = torch.bfloat16 if torch_dtype in FLOAT8_DTYPES else torch_dtype
176
+ tensor = torch.empty(shape, dtype=gen_dtype, device="cuda")
177
+ if init == "normal":
178
+ tensor.normal_(std=mode1 ** (-0.5))
179
+ elif init != "empty":
180
+ raise ValueError(f"Unsupported init: {init}")
181
+ # Do NOT .contiguous() after .permute() — that would re-materialize with wrong
182
+ # strides (L innermost) and break K-majorness / N-majorness for l > 1.
183
+ # The original (l, mode0/1, mode1/0) is contiguous, and the permuted view has
184
+ # the correct per-mode strides: stride=1 on the intended contiguous dim.
185
+ tensor = tensor.to(torch_dtype).permute(permute_order)
186
+ ref = tensor.float() if init != "empty" else None
187
+ return ref, tensor
188
+
189
+
190
+ def _pack_blockscaled_scales(ref_blocks: torch.Tensor) -> torch.Tensor:
191
+ """Rearrange (mn, sf_k, l) scales into the (l, rm, rk, 512) blocked layout."""
192
+ mn, sf_k, l = ref_blocks.shape
193
+ rm = ceil_div(mn, 128)
194
+ rk = ceil_div(sf_k, 4)
195
+ packed_6d = torch.zeros((l, rm, rk, 32, 4, 4), dtype=torch.float32, device=ref_blocks.device)
196
+ packed_view = packed_6d.permute(3, 4, 1, 5, 2, 0) # (32, 4, rm, 4, rk, l)
197
+ m_idx = torch.arange(mn, device=ref_blocks.device)
198
+ k_idx = torch.arange(sf_k, device=ref_blocks.device)
199
+ l_idx = torch.arange(l, device=ref_blocks.device)
200
+ packed_view[
201
+ m_idx[:, None, None] % 32,
202
+ (m_idx[:, None, None] // 32) % 4,
203
+ m_idx[:, None, None] // 128,
204
+ k_idx[None, :, None] % 4,
205
+ k_idx[None, :, None] // 4,
206
+ l_idx[None, None, :],
207
+ ] = ref_blocks
208
+ return packed_6d.view(l, rm, rk, 512)
209
+
210
+
211
+ def create_blockscaled_scale_tensor(
212
+ l: int,
213
+ mn: int,
214
+ k: int,
215
+ sf_vec_size: int,
216
+ dtype: Type[cutlass.Numeric],
217
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
218
+ sf_k = ceil_div(k, sf_vec_size)
219
+ if dtype == cutlass.Float8E8M0FNU:
220
+ exponents = torch.randint(0, 2, (mn, sf_k, l), device="cuda", dtype=torch.int32)
221
+ ref_blocks = torch.pow(2.0, exponents.float())
222
+ else:
223
+ ref_blocks = torch.randint(1, 4, (mn, sf_k, l), device="cuda", dtype=torch.int32).float()
224
+
225
+ packed_f32 = _pack_blockscaled_scales(ref_blocks)
226
+ packed = torch.empty_like(packed_f32, dtype=torch_dtype_for_cutlass(dtype))
227
+ packed.copy_(packed_f32)
228
+ ref = (
229
+ ref_blocks.permute(2, 0, 1)
230
+ .unsqueeze(-1)
231
+ .expand(l, mn, sf_k, sf_vec_size)
232
+ .reshape(l, mn, sf_k * sf_vec_size)
233
+ .permute(1, 2, 0)
234
+ )[:, :k, :]
235
+ return ref, packed
236
+
237
+
238
+ def pack_scale_2d_to_blocked_contig(scale_2d: torch.Tensor) -> torch.Tensor:
239
+ """Rearrange a (l, mn, sf_k) or (mn, sf_k) e8m0 scale tensor into the
240
+ contiguous (l, rm, rk, 512) blocked layout shared by the quack kernel and
241
+ cuBLAS's block-scaling. Each 512 B inner block holds one 128 MN × 4 K
242
+ swizzled tile. Pads `mn` to a multiple of 128 and `sf_k` to a multiple of
243
+ 4 with zeros."""
244
+ if scale_2d.dim() == 2:
245
+ scale_2d = scale_2d.unsqueeze(0)
246
+ assert scale_2d.dim() == 3, f"expected (l, mn, sf_k), got shape {tuple(scale_2d.shape)}"
247
+ orig_dtype = scale_2d.dtype
248
+ l, mn, sf_k = scale_2d.shape
249
+ rm = ceil_div(mn, 128)
250
+ rk = ceil_div(sf_k, 4)
251
+ mn_pad = rm * 128
252
+ sf_k_pad = rk * 4
253
+ u8 = scale_2d.contiguous().view(torch.uint8)
254
+ if mn_pad != mn or sf_k_pad != sf_k:
255
+ padded = torch.zeros(l, mn_pad, sf_k_pad, device=scale_2d.device, dtype=torch.uint8)
256
+ padded[:, :mn, :sf_k] = u8
257
+ else:
258
+ padded = u8
259
+ # (l, mn_pad, sf_k_pad) -> (l, rm, 128, rk, 4) -> (l, rm, rk, 128, 4)
260
+ blocks = padded.view(l, rm, 128, rk, 4).permute(0, 1, 3, 2, 4)
261
+ # split 128 into (4 outer, 32 inner), then swap to (32, 4)
262
+ blocks = blocks.reshape(l, rm, rk, 4, 32, 4).transpose(3, 4).contiguous()
263
+ return blocks.view(l, rm, rk, 512).view(orig_dtype)
264
+
265
+
266
+ def scale_view_for_kernel(scale_contig: torch.Tensor, mn: int, sf_k: int, l: int) -> torch.Tensor:
267
+ """Validate a (l, rm, rk, 512) scale tensor and return it unchanged.
268
+ Only the innermost 512-B tile must be contiguous (stride 1, size 512);
269
+ outer (L, rm, rk) strides are free — the kernel reads them from the
270
+ passed tensor. This lets callers pass a slice/view of a larger buffer
271
+ with no extra copy. Works for both E8M0 (MX) and E4M3 (NVFP4)."""
272
+ rm = ceil_div(mn, 128)
273
+ rk = ceil_div(sf_k, 4)
274
+ assert scale_contig.shape == (l, rm, rk, 512), (
275
+ f"expected (l, rm, rk, 512) = ({l}, {rm}, {rk}, 512), got {tuple(scale_contig.shape)}"
276
+ )
277
+ assert scale_contig.stride(-1) == 1, (
278
+ f"innermost 512-B dim must be unit-stride, got stride {scale_contig.stride(-1)}"
279
+ )
280
+ return scale_contig
281
+
282
+
283
+ def scale_blocked_for_cublas(
284
+ scale_contig: torch.Tensor, mn: int, sf_k: int, l_idx: int = 0
285
+ ) -> torch.Tensor:
286
+ """Flatten a (l, rm, rk, 512) scale tensor to the 1D swizzled layout
287
+ torch._scaled_mm expects. Uses a single l slice."""
288
+ assert scale_contig.is_contiguous() and scale_contig.dim() == 4
289
+ return scale_contig[l_idx].reshape(-1)
290
+
291
+
292
+ _FP4_E2M1_CODE_TO_VALUE = torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32)
293
+
294
+
295
+ def _fp4_unpacked_to_value(codes_u8: torch.Tensor) -> torch.Tensor:
296
+ """Convert FP4 E2M1 codes in [0,16) to signed float values via table lookup.
297
+ Code layout: bit 3 = sign, bits 0-2 = magnitude index into {0,.5,1,1.5,2,3,4,6}."""
298
+ table = _FP4_E2M1_CODE_TO_VALUE.to(codes_u8.device)
299
+ return table[codes_u8.long()]
300
+
301
+
302
+ def _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size) -> str:
303
+ """Identify which blockscaled format the (ab, sf, vec) tuple corresponds to."""
304
+ if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
305
+ return "mxfp8"
306
+ if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
307
+ return "mxfp4"
308
+ if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 16:
309
+ return "nvfp4"
310
+ raise ValueError(
311
+ f"init=quant does not support (ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). "
312
+ f"Supported: MXFP8 (e4m3+e8m0+32), MXFP4 (e2m1+e8m0+32), NVFP4 (e2m1+e4m3+16)."
313
+ )
314
+
315
+
316
+ def create_blockscaled_operand_quantized(
317
+ l: int,
318
+ mn: int,
319
+ k: int,
320
+ is_mn_major: bool,
321
+ sf_vec_size: int = 32,
322
+ ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
323
+ sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
324
+ *,
325
+ randn_std: Optional[float] = None,
326
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
327
+ """Generate bf16 randn, quantize to MXFP8/MXFP4/NVFP4 and produce:
328
+ ref: (mn, k, l) float32 dequantized reference
329
+ q_mkl: (mn, k, l) operand tensor in the layout the quack kernel consumes
330
+ (float8_e4m3fn for fp8 formats; int8 with packed nibbles for fp4)
331
+ scale_contig: (l, rm, rk, 512) contiguous scale storage. Each 512 B
332
+ inner block is one 128 MN × 4 K swizzled tile. Byte layout matches
333
+ cuBLAS `to_blocked`. Pass directly to the quack kernel, or use
334
+ `scale_blocked_for_cublas` for cuBLAS.
335
+ """
336
+ fmt = _blockscaled_format_of(ab_dtype, sf_dtype, sf_vec_size)
337
+ if is_mn_major and fmt != "mxfp8":
338
+ raise NotImplementedError(
339
+ f"is_mn_major=True is only supported for MXFP8 (tcgen05 MMA requires "
340
+ f"K-major for MXFP4/NVFP4 operands); got fmt={fmt}"
341
+ )
342
+ assert k % sf_vec_size == 0, f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size})"
343
+ sf_k = k // sf_vec_size
344
+ std = randn_std if randn_std is not None else k**-0.5
345
+
346
+ x_hp = (torch.randn(l, mn, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
347
+ x_flat = x_hp.view(l * mn, k)
348
+
349
+ if fmt == "mxfp8":
350
+ q_flat, scale_2d = to_mx_compiled(x_flat, sf_vec_size) # (l*mn, k), (l*mn, sf_k)
351
+ if is_mn_major:
352
+ # Operand: (mn, k, l) MN-major. Start from (l, mn, k) contig, transpose
353
+ # to (l, k, mn) contig, then permute to (mn, k, l) with strides (1, mn, mn*k).
354
+ q_mkl = (
355
+ q_flat.view(l, mn, k).transpose(1, 2).contiguous().permute(2, 1, 0)
356
+ ) # strides (1, mn, mn*k)
357
+ else:
358
+ # Operand: (mn, k, l) K-major VIEW of contiguous (l, mn, k).
359
+ # Do NOT call .contiguous() here — that would materialize as (mn, k, l) row-major,
360
+ # making L the innermost stride=1 dim and BREAKING K-majorness for l > 1.
361
+ q_mkl = q_flat.view(l, mn, k).contiguous().permute(1, 2, 0) # strides (k, 1, mn*k)
362
+ q_vals = q_flat.float().view(l, mn, k)
363
+ scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1)
364
+ ref_mkl = (q_vals * scale_vals).permute(1, 2, 0).contiguous()
365
+ scale_2d = scale_2d.view(l, mn, sf_k)
366
+ elif fmt in ("mxfp4", "nvfp4"):
367
+ if fmt == "mxfp4":
368
+ q_packed, scale_2d = to_mxfp4_compiled(x_flat, sf_vec_size) # (l*mn, k/2), (l*mn, sf_k)
369
+ else:
370
+ q_packed, scale_2d, _pts = to_nvfp4_compiled(x_flat, sf_vec_size, None)
371
+ # q_packed is uint8, two 4-bit codes per byte (low nibble=even K, high=odd K).
372
+ # Decode for ref: code -> {0,.5,1,1.5,2,3,4,6,-0,-.5,...} via lookup.
373
+ codes_lo = (q_packed & 0x0F).view(l, mn, k // 2)
374
+ codes_hi = ((q_packed >> 4) & 0x0F).view(l, mn, k // 2)
375
+ vals_lo = _fp4_unpacked_to_value(codes_lo) # (l, mn, k/2)
376
+ vals_hi = _fp4_unpacked_to_value(codes_hi)
377
+ q_values = torch.stack([vals_lo, vals_hi], dim=-1).reshape(l, mn, k) # interleave back
378
+ scale_vals = scale_2d.float().view(l, mn, sf_k).repeat_interleave(sf_vec_size, dim=-1)
379
+ ref_mkl = (q_values * scale_vals).permute(1, 2, 0).contiguous()
380
+ # Kernel operand: (mn, k/2, l) K-major view (no post-contiguous!)
381
+ q_mkl = (
382
+ q_packed.view(l, mn, k // 2).contiguous().permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
383
+ )
384
+ scale_2d = scale_2d.view(l, mn, sf_k)
385
+
386
+ scale_contig = pack_scale_2d_to_blocked_contig(scale_2d)
387
+ return ref_mkl, q_mkl, scale_contig
388
+
389
+
390
+ def create_blockscaled_varlen_m_operands(
391
+ num_experts: int,
392
+ m_per: int,
393
+ n: int,
394
+ k: int,
395
+ sf_vec_size: int,
396
+ ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
397
+ sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
398
+ *,
399
+ randn_std: Optional[float] = None,
400
+ seqlens_m: Optional[list] = None,
401
+ b_major: str = "k",
402
+ ):
403
+ """Generate bf16 randn + quantize for a varlen_m blockscaled GEMM.
404
+
405
+ Per-expert seqlens may be arbitrary (not required to be multiples of 128).
406
+ SF is stored in dQaccum-style padded format: each expert `i`'s scales
407
+ occupy `ceildiv(m_i, 128) * 128` rows at offset
408
+ `(cu_seqlens_m[i] + i * 128) // 128 * 128` in the padded scale buffer.
409
+ The kernel decodes via `VarlenManager.offset_batch_SFA` which applies the
410
+ same formula.
411
+
412
+ Returns (a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m):
413
+ a_ref: (total_m, k) fp32 dequantized
414
+ b_ref: (num_experts, n, k) fp32 dequantized
415
+ qa: (total_m, k) 2D K-major quantized operand (fp8) or (total_m, k/2) (fp4)
416
+ qb: (n, k, num_experts) 3D K-major quantized operand (fp8) or (n, k/2, num_experts) (fp4)
417
+ a_sc_contig: (1, total_padded_rm, rk, 512) — dQaccum-padded SFA.
418
+ total_padded_rm = ((total_m + num_experts * 128) // 128).
419
+ b_sc_contig: (num_experts, rn, rk, 512) — regular per-expert SFB.
420
+ cu_seqlens_m: (num_experts+1,) int32
421
+ """
422
+ assert k % sf_vec_size == 0
423
+ if seqlens_m is None:
424
+ seqlens_m = [m_per] * num_experts
425
+ assert len(seqlens_m) == num_experts, (
426
+ f"seqlens_m length {len(seqlens_m)} != num_experts {num_experts}"
427
+ )
428
+ total_m = int(sum(seqlens_m))
429
+ std = randn_std if randn_std is not None else k**-0.5
430
+ sf_k = k // sf_vec_size
431
+
432
+ if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32:
433
+ from .mx_utils import to_mx_compiled
434
+
435
+ to_fn = to_mx_compiled
436
+ else:
437
+ raise NotImplementedError(
438
+ f"varlen_m currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, vec={sf_vec_size}). "
439
+ "FP4 support pending."
440
+ )
441
+
442
+ # Quantize A: (total_m, k) bf16 -> (total_m, k) fp8 K-major.
443
+ # A data itself is stored packed (no per-expert padding); only SFA is padded.
444
+ a_hp = (torch.randn(total_m, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
445
+ qa, sa_2d = to_fn(a_hp, sf_vec_size) # (total_m, k), (total_m, sf_k)
446
+ a_ref = qa.float() * sa_2d.float().repeat_interleave(sf_vec_size, dim=-1)
447
+
448
+ # Build padded SFA storage (dQaccum format). Each expert's m_i rows of
449
+ # scales are written at padded tile offset `cu_seqlens[i] // 128 + i`.
450
+ # Allocation: `ceildiv(total_m, 128) + (L - 1)` tiles — proven sufficient
451
+ # in AI/varlen_blockscaled_sf_layout.md (proof 2's "tighter alternative").
452
+ # Matches `total_m // 128 + L` when total_m % 128 > 0; 1 tile smaller
453
+ # when total_m is an exact multiple of 128.
454
+ tile = 128
455
+ total_padded_rm = (total_m + tile - 1) // tile + (num_experts - 1)
456
+ total_padded_m = total_padded_rm * tile
457
+ sa_2d_padded = torch.zeros(total_padded_m, sf_k, dtype=sa_2d.dtype, device=sa_2d.device)
458
+ offset = 0
459
+ for i, m_i in enumerate(seqlens_m):
460
+ offset_padded = (offset // tile + i) * tile
461
+ sa_2d_padded[offset_padded : offset_padded + m_i] = sa_2d[offset : offset + m_i]
462
+ offset += m_i
463
+ a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, total_padded_m, sf_k))
464
+
465
+ # Quantize B: (num_experts, n, k) bf16 -> (n, k, num_experts). b_major selects
466
+ # k-major (stride (k, 1, n*k)) or n-major (stride (1, n, n*k)).
467
+ assert b_major in ("k", "n"), f"b_major must be 'k' or 'n', got {b_major!r}"
468
+ b_hp = (torch.randn(num_experts, n, k, dtype=torch.bfloat16, device="cuda") * std).contiguous()
469
+ qb_flat, sb_2d = to_fn(b_hp.view(num_experts * n, k), sf_vec_size)
470
+ if b_major == "k":
471
+ qb = (
472
+ qb_flat.view(num_experts, n, k).contiguous().permute(1, 2, 0)
473
+ ) # (n, k, l) stride (k, 1, n*k)
474
+ else:
475
+ qb = (
476
+ qb_flat.view(num_experts, n, k).transpose(1, 2).contiguous().permute(2, 1, 0)
477
+ ) # (n, k, l) stride (1, n, n*k)
478
+ sb_2d = sb_2d.view(num_experts, n, sf_k)
479
+ b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d)
480
+ b_ref = qb_flat.float().view(num_experts, n, k) * sb_2d.float().repeat_interleave(
481
+ sf_vec_size, dim=-1
482
+ )
483
+
484
+ cu_seqlens_m = torch.tensor(
485
+ [0] + list(itertools.accumulate(seqlens_m)), dtype=torch.int32, device="cuda"
486
+ )
487
+ return a_ref, b_ref, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_m
488
+
489
+
490
+ def create_blockscaled_varlen_k_operands(
491
+ num_experts: int,
492
+ k_per: int,
493
+ m: int,
494
+ n: int,
495
+ sf_vec_size: int,
496
+ ab_dtype: Type[cutlass.Numeric] = cutlass.Float8E4M3FN,
497
+ sf_dtype: Type[cutlass.Numeric] = cutlass.Float8E8M0FNU,
498
+ *,
499
+ randn_std: Optional[float] = None,
500
+ seqlens_k: Optional[list] = None,
501
+ ):
502
+ """Generate bf16 randn + quantize for a varlen_k blockscaled GEMM.
503
+
504
+ Per-expert `k_i` must be a multiple of `sf_vec_size` (quantization chunk)
505
+ but NOT necessarily a multiple of `sf_vec_size * 4` (= 128 for MXFP8).
506
+ The SF buffer uses dQaccum-style K padding: each expert `i`'s scales occupy
507
+ `ceildiv(k_i, 128) * 128` bytes worth of K at offset
508
+ `(cu_seqlens_k[i] + i * 128) // 128 * 128` (in source-K units). A and B
509
+ operand data stay packed and unpadded along K — only their SF buffers pad.
510
+
511
+ Returns (a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k):
512
+ a_ref_list: list of per-expert (m, k_i) fp32 dequantized A.
513
+ b_ref_list: list of per-expert (n, k_i) fp32 dequantized B.
514
+ qa: (m, total_k) K-major fp8 (stride (total_k, 1)).
515
+ qb: (n, total_k) K-major fp8 (stride (total_k, 1)).
516
+ a_sc_contig: (1, rm, total_padded_rk, 512) dQaccum-padded SFA.
517
+ b_sc_contig: (1, rn, total_padded_rk, 512) dQaccum-padded SFB.
518
+ cu_seqlens_k: (num_experts+1,) int32.
519
+ """
520
+ if not (
521
+ ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 32
522
+ ):
523
+ raise NotImplementedError(
524
+ f"varlen_k currently only supports MXFP8 (got ab={ab_dtype}, sf={sf_dtype}, "
525
+ f"vec={sf_vec_size}). FP4 is k-major-only and not wired up."
526
+ )
527
+ if seqlens_k is None:
528
+ seqlens_k = [k_per] * num_experts
529
+ assert len(seqlens_k) == num_experts, (
530
+ f"seqlens_k length {len(seqlens_k)} != num_experts {num_experts}"
531
+ )
532
+ for i, k_i in enumerate(seqlens_k):
533
+ assert k_i % sf_vec_size == 0, (
534
+ f"seqlens_k[{i}]={k_i} must be divisible by sf_vec_size={sf_vec_size}"
535
+ )
536
+ total_k = int(sum(seqlens_k))
537
+ std = randn_std if randn_std is not None else (max(seqlens_k)) ** -0.5
538
+ sf_k_total = total_k // sf_vec_size
539
+
540
+ from .mx_utils import to_mx_compiled
541
+
542
+ a_q_list, a_sc_list, a_ref_list = [], [], []
543
+ b_q_list, b_sc_list, b_ref_list = [], [], []
544
+ for k_i in seqlens_k:
545
+ # A slice: (m, k_i) bf16 -> fp8, scales (m, k_i // sf_vec_size).
546
+ a_hp = (torch.randn(m, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous()
547
+ a_q, a_sc = to_mx_compiled(a_hp, sf_vec_size)
548
+ a_q_list.append(a_q)
549
+ a_sc_list.append(a_sc)
550
+ a_ref_list.append(a_q.float() * a_sc.float().repeat_interleave(sf_vec_size, dim=-1))
551
+
552
+ b_hp = (torch.randn(n, k_i, dtype=torch.bfloat16, device="cuda") * std).contiguous()
553
+ b_q, b_sc = to_mx_compiled(b_hp, sf_vec_size)
554
+ b_q_list.append(b_q)
555
+ b_sc_list.append(b_sc)
556
+ b_ref_list.append(b_q.float() * b_sc.float().repeat_interleave(sf_vec_size, dim=-1))
557
+
558
+ # Pack operand data along K: (m, total_k), (n, total_k). varlen_k's
559
+ # ragged TMA descriptors are built for MN-major operands (stride 1 on
560
+ # M/N), so store M-major A and N-major B.
561
+ # cat gives K-major; transpose → contiguous → transpose to get M-major.
562
+ qa = torch.cat(a_q_list, dim=1).t().contiguous().t() # (m, total_k) stride (1, m)
563
+ qb = torch.cat(b_q_list, dim=1).t().contiguous().t() # (n, total_k) stride (1, n)
564
+ assert qa.stride() == (1, qa.shape[0])
565
+ assert qb.stride() == (1, qb.shape[0])
566
+
567
+ # Pad SFA/SFB per-expert to multiples of 128 source-K (= 4 scales).
568
+ # offset_tile = cu_seqlens[i] // 128 + i (same formula the kernel uses).
569
+ # Allocation = ceildiv(total_k, 128) + (L - 1) tiles (tighter than
570
+ # total_k//128 + L when total_k is a multiple of 128; same otherwise).
571
+ tile = 128 # sf_vec_size * 4
572
+ total_padded_rk = (total_k + tile - 1) // tile + (num_experts - 1)
573
+ total_padded_k = total_padded_rk * tile
574
+ total_padded_sf_k = total_padded_k // sf_vec_size
575
+ sa_2d_padded = torch.zeros(m, total_padded_sf_k, dtype=a_sc_list[0].dtype, device="cuda")
576
+ sb_2d_padded = torch.zeros(n, total_padded_sf_k, dtype=b_sc_list[0].dtype, device="cuda")
577
+ k_offset = 0
578
+ for i, k_i in enumerate(seqlens_k):
579
+ sf_k_i = k_i // sf_vec_size
580
+ k_offset_padded = (k_offset // tile + i) * tile
581
+ sf_k_offset_padded = k_offset_padded // sf_vec_size
582
+ sa_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = a_sc_list[i]
583
+ sb_2d_padded[:, sf_k_offset_padded : sf_k_offset_padded + sf_k_i] = b_sc_list[i]
584
+ k_offset += k_i
585
+
586
+ a_sc_contig = pack_scale_2d_to_blocked_contig(sa_2d_padded.view(1, m, total_padded_sf_k))
587
+ b_sc_contig = pack_scale_2d_to_blocked_contig(sb_2d_padded.view(1, n, total_padded_sf_k))
588
+
589
+ cu_seqlens_k = torch.tensor(
590
+ [0] + list(itertools.accumulate(seqlens_k)), dtype=torch.int32, device="cuda"
591
+ )
592
+ return a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k
593
+
594
+
595
+ def compile_blockscaled_gemm_tvm_ffi(
596
+ ab_dtype: Type[cutlass.Numeric],
597
+ sf_dtype: Type[cutlass.Numeric],
598
+ sf_vec_size: int,
599
+ d_dtype: Type[cutlass.Numeric],
600
+ mma_tiler_mn: Tuple[int, int],
601
+ cluster_shape_mn: Tuple[int, int],
602
+ mA: torch.Tensor,
603
+ mB: torch.Tensor,
604
+ mD: torch.Tensor,
605
+ mSFA: torch.Tensor,
606
+ mSFB: torch.Tensor,
607
+ *,
608
+ use_clc_persistence: bool = True,
609
+ varlen_m: bool = False,
610
+ varlen_k: bool = False,
611
+ ) -> Callable:
612
+ """Compile the SM100 blockscaled GEMM.
613
+
614
+ When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major,
615
+ mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor.
616
+ When varlen_k: mA is (m, total_k), mB is (n, total_k), mD is (m, n, l);
617
+ run(...) takes an extra cu_seqlens_k tensor.
618
+ """
619
+ device_capacity = get_device_capacity(mA.device)
620
+ if device_capacity[0] not in (10, 11):
621
+ raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110")
622
+ assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k"
623
+
624
+ gemm = partial(
625
+ GemmDefaultSm100,
626
+ sf_vec_size=sf_vec_size,
627
+ use_clc_persistence=use_clc_persistence,
628
+ )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1))
629
+ compile_epi_args = gemm.EpilogueArguments()
630
+ scheduler_args = make_scheduler_args(
631
+ get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
632
+ max_swizzle_size=8,
633
+ tile_count_semaphore=None,
634
+ batch_idx_permute=None,
635
+ )
636
+ stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
637
+
638
+ from .gemm_tvm_ffi_utils import make_fake_varlen_args
639
+
640
+ varlen_args_fake = make_fake_varlen_args(varlen_m, varlen_k, False, None) or VarlenArguments()
641
+
642
+ # Fake operand tensors with sym_ints (varlen-aware shapes).
643
+ if varlen_m:
644
+ total_m_sym = cute.sym_int()
645
+ n_sym, k_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int()
646
+ # Detect each operand's leading (stride-1) dim so m-major A / n-major B
647
+ # are accepted for varlen_m (MXFP8 only — fp4 is rejected upstream).
648
+ fake_mA = fake_tensor(
649
+ ab_dtype,
650
+ (total_m_sym, k_sym),
651
+ leading_dim=_leading_dim_from_stride(mA),
652
+ divisibility=div_for_dtype(ab_dtype),
653
+ )
654
+ fake_mB = fake_tensor(
655
+ ab_dtype,
656
+ (n_sym, k_sym, l_sym),
657
+ leading_dim=_leading_dim_from_stride(mB),
658
+ divisibility=div_for_dtype(ab_dtype),
659
+ )
660
+ fake_mD = fake_tensor(
661
+ d_dtype,
662
+ (total_m_sym, n_sym),
663
+ leading_dim=_leading_dim_from_stride(mD),
664
+ divisibility=div_for_dtype(d_dtype),
665
+ )
666
+ elif varlen_k:
667
+ total_k_sym = cute.sym_int()
668
+ m_sym, n_sym, l_sym = cute.sym_int(), cute.sym_int(), cute.sym_int()
669
+ # varlen_k uses MN-major A/B convention (stride 1 on M/N axis), but
670
+ # detect from the actual tensor so either layout works.
671
+ fake_mA = fake_tensor(
672
+ ab_dtype,
673
+ (m_sym, total_k_sym),
674
+ leading_dim=_leading_dim_from_stride(mA),
675
+ divisibility=div_for_dtype(ab_dtype),
676
+ )
677
+ fake_mB = fake_tensor(
678
+ ab_dtype,
679
+ (n_sym, total_k_sym),
680
+ leading_dim=_leading_dim_from_stride(mB),
681
+ divisibility=div_for_dtype(ab_dtype),
682
+ )
683
+ fake_mD = fake_tensor(
684
+ d_dtype,
685
+ (m_sym, n_sym, l_sym),
686
+ leading_dim=_leading_dim_from_stride(mD),
687
+ divisibility=div_for_dtype(d_dtype),
688
+ )
689
+ else:
690
+ # Detect each operand's leading (stride-1) dim so m-major A / n-major B
691
+ # are accepted along with the default k-major.
692
+ fake_mA = _make_fake_compact_tensor(
693
+ mA.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mA)
694
+ )
695
+ fake_mB = _make_fake_compact_tensor(
696
+ mB.shape, ab_dtype, leading_dim=_leading_dim_from_stride(mB)
697
+ )
698
+ fake_mD = _make_fake_compact_tensor(
699
+ mD.shape, d_dtype, leading_dim=_leading_dim_from_stride(mD)
700
+ )
701
+
702
+ @cute.jit
703
+ def runner(
704
+ a: cute.Tensor,
705
+ b: cute.Tensor,
706
+ d: cute.Tensor,
707
+ sfa: cute.Tensor,
708
+ sfb: cute.Tensor,
709
+ varlen_args,
710
+ stream,
711
+ ):
712
+ gemm(a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None)
713
+
714
+ compiled = cute.compile(
715
+ runner,
716
+ fake_mA,
717
+ fake_mB,
718
+ fake_mD,
719
+ _make_compile_tensor_like(mSFA, sf_dtype, dynamic_layout=True),
720
+ _make_compile_tensor_like(mSFB, sf_dtype, dynamic_layout=True),
721
+ varlen_args_fake,
722
+ stream,
723
+ options="--enable-tvm-ffi",
724
+ )
725
+
726
+ if varlen_m or varlen_k:
727
+
728
+ def run(a, b, d, sfa, sfb, cu_seqlens):
729
+ varlen_args = VarlenArguments(
730
+ mCuSeqlensM=cu_seqlens if varlen_m else None,
731
+ mCuSeqlensK=cu_seqlens if varlen_k else None,
732
+ )
733
+ compiled(a, b, d, sfa, sfb, varlen_args)
734
+ else:
735
+
736
+ def run(a, b, d, sfa, sfb):
737
+ compiled(a, b, d, sfa, sfb, VarlenArguments())
738
+
739
+ return run
740
+
741
+
742
+ def blockscaled_gemm_reference(
743
+ a_ref: torch.Tensor,
744
+ b_ref: torch.Tensor,
745
+ sfa_ref: torch.Tensor,
746
+ sfb_ref: torch.Tensor,
747
+ ) -> torch.Tensor:
748
+ return torch.einsum(
749
+ "mkl,nkl->mnl",
750
+ torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref),
751
+ torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref),
752
+ )
build/torch-cuda/quack/broadcast_utils.py CHANGED
@@ -11,7 +11,7 @@ from .layout_utils import make_acc_tensor_mn_view
11
  @cute.jit
12
  def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
13
  if const_expr(tCrC.element_type != Float32): # Convert to f32
14
- tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
15
  tCrC_f32.store(tCrC.load().to(Float32))
16
  else:
17
  tCrC_f32 = tCrC
 
11
  @cute.jit
12
  def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
13
  if const_expr(tCrC.element_type != Float32): # Convert to f32
14
+ tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
15
  tCrC_f32.store(tCrC.load().to(Float32))
16
  else:
17
  tCrC_f32 = tCrC
build/torch-cuda/quack/cache_utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+ """Persistent .o cache for CuTe DSL compiled kernels.
3
+
4
+ Compiled kernels are exported as object files (.o) via export_to_c.
5
+ On subsequent runs the .o is loaded via tvm_ffi (~1ms) instead of
6
+ re-generating IR + re-JIT'ing (~100ms per kernel).
7
+
8
+ Controls:
9
+ QUACK_CACHE_ENABLED=0 — disable persistent .o cache (default: enabled)
10
+ QUACK_CACHE_DIR=path — override default cache directory
11
+ """
12
+
13
+ import fcntl
14
+ import functools
15
+ import hashlib
16
+ import os
17
+ import pickle
18
+ import sys
19
+ import tempfile
20
+ import time
21
+ from collections import namedtuple
22
+ from getpass import getuser
23
+ from pathlib import Path
24
+
25
+ import cutlass
26
+ import cutlass.cute as cute
27
+ import tvm_ffi
28
+
29
+ CACHE_ENABLED: bool = os.getenv("QUACK_CACHE_ENABLED", "1") == "1"
30
+ CACHE_DIR: str | None = os.getenv("QUACK_CACHE_DIR", None)
31
+ COMPILE_ONLY: bool = False
32
+
33
+ # Downstream projects can append directories here to include their sources
34
+ # in the cache fingerprint. Must be set before the first jit_cache call.
35
+ EXTRA_SOURCE_DIRS: list[Path] = []
36
+
37
+ EXPORT_FUNC_NAME = "func"
38
+ LOCK_TIMEOUT = 60
39
+ CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
40
+
41
+
42
+ def _noop_kernel(*args, **kwargs):
43
+ pass
44
+
45
+
46
+ def get_cache_path() -> Path:
47
+ if CACHE_DIR is not None:
48
+ cache_dir = Path(CACHE_DIR)
49
+ else:
50
+ cache_dir = Path(tempfile.gettempdir()) / getuser() / "quack_cache"
51
+ cache_dir.mkdir(parents=True, exist_ok=True)
52
+ return cache_dir
53
+
54
+
55
+ def _hash_source_dir(h, root: Path) -> None:
56
+ """Hash all Python sources under *root* into *h*."""
57
+ for src in sorted(root.rglob("*.py")):
58
+ if not src.is_file():
59
+ continue
60
+ h.update(src.relative_to(root).as_posix().encode())
61
+ content = src.read_bytes()
62
+ h.update(len(content).to_bytes(8, "little"))
63
+ h.update(content)
64
+
65
+
66
+ @functools.lru_cache(maxsize=1)
67
+ def _compute_source_fingerprint() -> str:
68
+ """Hash quack + extra source dirs plus runtime ABI stamps into a fingerprint."""
69
+ h = hashlib.sha256()
70
+ h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode())
71
+ h.update(f"cutlass={cutlass.__version__}".encode())
72
+ h.update(f"tvm_ffi={tvm_ffi.__version__}".encode())
73
+ _hash_source_dir(h, Path(__file__).resolve().parent)
74
+ for extra_dir in EXTRA_SOURCE_DIRS:
75
+ _hash_source_dir(h, Path(extra_dir).resolve())
76
+ return h.hexdigest()
77
+
78
+
79
+ def _key_to_hash(key: tuple) -> str:
80
+ return hashlib.sha256(pickle.dumps(key)).hexdigest()
81
+
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # File locking
85
+ # ---------------------------------------------------------------------------
86
+
87
+
88
+ class FileLock:
89
+ """Advisory file lock using fcntl.flock with timeout."""
90
+
91
+ def __init__(self, lock_path: Path, exclusive: bool, timeout: float = 15):
92
+ self.lock_path = lock_path
93
+ self.exclusive = exclusive
94
+ self.timeout = timeout
95
+ self._fd: int = -1
96
+
97
+ def __enter__(self) -> "FileLock":
98
+ flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
99
+ lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH
100
+ self._fd = os.open(str(self.lock_path), flags)
101
+ deadline = time.monotonic() + self.timeout
102
+ while time.monotonic() < deadline:
103
+ try:
104
+ fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB)
105
+ return self
106
+ except OSError:
107
+ time.sleep(0.1)
108
+ os.close(self._fd)
109
+ self._fd = -1
110
+ raise RuntimeError(f"Timed out waiting for lock: {self.lock_path}")
111
+
112
+ def __exit__(self, *exc) -> None:
113
+ if self._fd >= 0:
114
+ fcntl.flock(self._fd, fcntl.LOCK_UN)
115
+ os.close(self._fd)
116
+ self._fd = -1
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # JIT cache decorator
121
+ # ---------------------------------------------------------------------------
122
+
123
+
124
+ def jit_cache(fn):
125
+ """Decorator that caches compiled CuTe DSL kernels in-memory and on disk.
126
+
127
+ The decorated function should return a compiled kernel (i.e. call cute.compile).
128
+ The disk cache key is (fn.__qualname__, *args, **sorted_kwargs).
129
+ """
130
+ cache = {}
131
+ hits = 0
132
+ misses = 0
133
+
134
+ @functools.wraps(fn)
135
+ def wrapper(*args, **kwargs):
136
+ nonlocal hits, misses
137
+ cache_key = args + tuple(sorted(kwargs.items())) if kwargs else args
138
+
139
+ # 1. In-memory hit
140
+ if cache_key in cache:
141
+ hits += 1
142
+ return _noop_kernel if COMPILE_ONLY else cache[cache_key]
143
+
144
+ # 2. Disk hit
145
+ disk_key = (fn.__qualname__,) + cache_key
146
+ if CACHE_ENABLED:
147
+ sha = _key_to_hash(disk_key)
148
+ cache_path = get_cache_path() / _compute_source_fingerprint()
149
+ cache_path.mkdir(parents=True, exist_ok=True)
150
+ o_path = cache_path / f"{sha}.o"
151
+ lock_path = cache_path / f"{sha}.lock"
152
+ try:
153
+ with FileLock(lock_path, exclusive=False, timeout=LOCK_TIMEOUT):
154
+ if o_path.exists():
155
+ m = cute.runtime.load_module(str(o_path), enable_tvm_ffi=True)
156
+ loaded = m[EXPORT_FUNC_NAME]
157
+ cache[cache_key] = loaded
158
+ hits += 1
159
+ return _noop_kernel if COMPILE_ONLY else loaded
160
+ except RuntimeError:
161
+ pass
162
+
163
+ # 3. Compile
164
+ misses += 1
165
+ compiled_fn = fn(*args, **kwargs)
166
+
167
+ # 4. Store
168
+ cache[cache_key] = compiled_fn
169
+ if CACHE_ENABLED:
170
+ try:
171
+ with FileLock(lock_path, exclusive=True, timeout=LOCK_TIMEOUT):
172
+ if not o_path.exists():
173
+ o_path.parent.mkdir(parents=True, exist_ok=True)
174
+ compiled_fn.export_to_c(
175
+ object_file_path=str(o_path),
176
+ function_name=EXPORT_FUNC_NAME,
177
+ )
178
+ except Exception as e:
179
+ print(f"quack cache: export failed for key {sha}: {e}")
180
+
181
+ return _noop_kernel if COMPILE_ONLY else compiled_fn
182
+
183
+ def cache_clear():
184
+ nonlocal hits, misses
185
+ cache.clear()
186
+ hits = 0
187
+ misses = 0
188
+
189
+ def cache_info():
190
+ return CacheInfo(hits=hits, misses=misses, maxsize=None, currsize=len(cache))
191
+
192
+ wrapper.cache = cache
193
+ wrapper.cache_clear = cache_clear
194
+ wrapper.cache_info = cache_info
195
+ return wrapper
build/torch-cuda/quack/copy_utils.py CHANGED
@@ -1,15 +1,25 @@
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
 
3
- import re
4
- from typing import Optional, Type, Tuple, Callable
5
 
6
  import cutlass
7
  import cutlass.cute as cute
8
 
9
- from cutlass import Int32, Boolean, const_expr
10
- from cutlass.cute.nvgpu import cpasync, warpgroup
 
11
  from cutlass.cutlass_dsl import dsl_user_op
12
  import cutlass.pipeline
 
 
 
 
 
 
 
 
 
13
 
14
 
15
  @dsl_user_op
@@ -26,7 +36,7 @@ def cvt_copy(
26
  ) -> None:
27
  assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
28
  if const_expr(src.element_type != dst.element_type):
29
- src_cvt = cute.make_fragment_like(src, dst.element_type)
30
  src_cvt.store(src.load().to(dst.element_type))
31
  src = src_cvt
32
  if const_expr(retile):
@@ -34,9 +44,33 @@ def cvt_copy(
34
  cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @dsl_user_op
38
  def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
39
- dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
40
  cute.autovec_copy(src, dst, loc=loc, ip=ip)
41
  return dst
42
 
@@ -52,13 +86,23 @@ def load_s2r_retile(
52
  ) -> cute.Tensor:
53
  # Will also accept dst_shape being a tensor, in which case we write into that tensor
54
  if const_expr(not isinstance(dst_shape, cute.Tensor)):
55
- dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
56
  else:
57
  dst = dst_shape
58
  cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
59
  return dst
60
 
61
 
 
 
 
 
 
 
 
 
 
 
62
  @dsl_user_op
63
  def get_copy_atom(
64
  dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
@@ -117,7 +161,7 @@ def tiled_copy_2d(
117
  @cute.jit
118
  def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
119
  # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
120
- tApA = cute.make_fragment(
121
  cute.make_layout(
122
  (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
123
  stride=(cute.size(tAcA, mode=[2]), 0, 1),
@@ -147,28 +191,108 @@ def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
147
  # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
148
 
149
 
150
- def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
151
- """Extract swizzle parameters from a pointer's swizzle_type.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
154
- b, m, s are the swizzle parameters (bits, base, shift).
155
 
156
- Returns:
157
- A cute.Swizzle object constructed from the extracted parameters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- Raises:
160
- ValueError: If the swizzle_type string cannot be parsed
161
- """
162
- # Ideally there should be a better API to get swizzle parameters, but we'll just parse
163
- # the string here.
164
- swizzle_str = str(ptr.type.swizzle_type)
165
- # Extract the inner part "S<b,m,s>"
166
- match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
167
- if match:
168
- b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
169
- return b, m, s
 
 
 
 
 
 
 
 
 
 
 
170
  else:
171
- raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
 
 
 
 
172
 
173
 
174
  def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
@@ -178,15 +302,16 @@ def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
178
 
179
 
180
  def swizzle_ptr(ptr: cute.Pointer):
181
- b, m, s = parse_swizzle_from_pointer(ptr)
182
- ptr_int = swizzle_int(ptr.toint(), b, m, s)
183
  return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
184
 
185
 
186
  def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
187
  outer = tensor.layout
188
  width = tensor.element_type.width
189
- inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
 
190
  # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
191
  # for 16 bits and <3, 2, 3> for 32 bits)
192
  new_layout = cute.recast_layout(
@@ -242,15 +367,16 @@ def sm90_get_smem_load_op(
242
  raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
243
  is_m_major = layout_c.is_m_major_c()
244
  if elem_ty_c.width == 16:
245
- return cute.make_copy_atom(
246
- cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
247
- )
248
  else:
249
  return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
250
 
251
 
252
  def get_smem_store_atom(
253
- arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
 
 
 
254
  ) -> cute.CopyAtom:
255
  if const_expr(arch < 90 or element_type.width != 16):
256
  return cute.make_copy_atom(
@@ -259,14 +385,22 @@ def get_smem_store_atom(
259
  num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
260
  )
261
  else:
 
 
 
 
 
262
  return cute.make_copy_atom(
263
- cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
264
  element_type,
265
  )
266
 
267
 
268
  def get_smem_load_atom(
269
- arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
 
 
 
270
  ) -> cute.CopyAtom:
271
  if const_expr(arch < 90 or element_type.width != 16):
272
  return cute.make_copy_atom(
@@ -275,8 +409,13 @@ def get_smem_load_atom(
275
  num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
276
  )
277
  else:
 
 
 
 
 
278
  return cute.make_copy_atom(
279
- cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
280
  element_type,
281
  )
282
 
@@ -288,9 +427,10 @@ def get_smem_store_C(
288
  arch: int,
289
  transpose: bool = False,
290
  position_independent=False,
 
291
  ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
292
  dtype = sC.element_type
293
- copy_atom = get_smem_store_atom(arch, dtype, transpose)
294
  tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
295
  thr_copy = tiled_copy.get_slice(tidx)
296
  if const_expr(not position_independent):
@@ -298,8 +438,9 @@ def get_smem_store_C(
298
  else:
299
  tRS_sC = partition_D_position_independent(thr_copy, sC)
300
 
301
- def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
302
- cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)
 
303
 
304
  return copy_fn, thr_copy, tRS_sC
305
 
@@ -324,14 +465,55 @@ def get_smem_load_C(
324
  thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
325
  tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
326
 
327
- def copy_fn(src_idx: Int32, **new_kwargs):
328
- return load_s2r_retile(
329
- tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
330
- )
331
 
332
  return copy_fn, thr_copy, tSR_sC
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  def get_smem_store_A(
336
  tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
337
  ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
@@ -368,8 +550,6 @@ def get_smem_load_A(
368
  tSR_sA = thr_copy.partition_S(sA)
369
  else:
370
  tSR_sA = partition_S_position_independent(thr_copy, sA)
371
- copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
372
- thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
373
  tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
374
 
375
  def copy_fn(src_idx: Int32, **new_kwargs):
@@ -383,6 +563,195 @@ def get_smem_load_A(
383
  return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
384
 
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  def tma_get_copy_fn(
387
  atom: cute.CopyAtom,
388
  cta_coord: cute.Coord,
@@ -391,6 +760,9 @@ def tma_get_copy_fn(
391
  dst_tensor: cute.Tensor,
392
  filter_zeros: bool = False,
393
  single_stage: bool = False,
 
 
 
394
  **kwargs,
395
  ) -> Callable:
396
  src_is_smem = const_expr(
@@ -407,17 +779,23 @@ def tma_get_copy_fn(
407
  cta_layout,
408
  cute.group_modes(smem_tensor, 0, group_rank_smem),
409
  cute.group_modes(gmem_tensor, 0, group_rank_gmem),
 
 
410
  )
411
  if const_expr(filter_zeros):
412
  s = cute.filter_zeros(s)
413
  g = cute.filter_zeros(g)
414
  src, dst = (s, g) if src_is_smem else (g, s)
415
 
416
- def copy_tma(src_idx, dst_idx, **new_kwargs):
417
- cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
 
 
 
418
 
419
- def copy_tma_single_stage(**new_kwargs):
420
- cute.copy(atom, src, dst, **new_kwargs, **kwargs)
 
421
 
422
  return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
423
 
@@ -438,22 +816,22 @@ def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsyn
438
  def gather_m_get_copy_fn(
439
  thr_copy_A: cute.ThrCopy,
440
  mA: cute.Tensor, # (whatever, K)
441
- sA: cute.Tensor, # (tile_M, tile_N, STAGE)
442
  gsAIdx: cute.Tensor, # (tile_M), either gmem or smem
443
  limit_m: Int32,
444
  limit_k: Int32,
445
  ) -> Callable:
446
- tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
447
- tAsA = thr_copy_A.partition_D(sA)
448
  # k-major
449
  assert tAsA.shape[2] == 1
450
  tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
451
 
452
- is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
453
  if const_expr(not is_even_m_smem):
454
- limit_m = min(limit_m, tile_shape_mk[0])
455
  elems_per_load = cute.size(tAsA.shape[0][0])
456
- cA = cute.make_identity_tensor(tile_shape_mk)
457
  tAcA = thr_copy_A.partition_S(cA)
458
  t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
459
  # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
@@ -464,10 +842,10 @@ def gather_m_get_copy_fn(
464
  # Read and cache indices for A
465
  rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
466
  cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
467
- tApA_m = cute.make_fragment(rows_per_thread, Boolean)
468
  for m in cutlass.range(rows_per_thread, unroll_full=True):
469
  tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
470
- m_idx = cute.make_fragment(rows_per_thread, Int32)
471
  for m in cutlass.range(rows_per_thread, unroll_full=True):
472
  row_idx = tAcA[0, m, 0][0]
473
  if tApA_m[m]:
@@ -475,13 +853,13 @@ def gather_m_get_copy_fn(
475
  else:
476
  m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
477
 
478
- mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
479
 
480
  def copy_fn(src_idx, dst_idx, pred: bool = False):
481
  tApA_k = None
482
  if const_expr(pred):
483
- tApA_k = cute.make_fragment(cols_per_thread, Boolean)
484
- limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
485
  for k in cutlass.range(cols_per_thread, unroll_full=True):
486
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
487
  mA_cur = mA_k[None, (None, src_idx)]
@@ -506,7 +884,7 @@ def gather_m_get_copy_fn(
506
  def gather_k_get_copy_fn(
507
  thr_copy_A: cute.ThrCopy,
508
  mA: cute.Tensor, # (tile_M, whatever)
509
- sA: cute.Tensor, # (tile_M, tile_N, STAGE)
510
  gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem
511
  limit_m: Int32,
512
  limit_k: Int32,
@@ -538,7 +916,7 @@ def gather_k_get_copy_fn(
538
  # Read and cache indices for A
539
  rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
540
  cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
541
- tApA_m = cute.make_fragment(rows_per_thread, Boolean)
542
  for m in cutlass.range(rows_per_thread, unroll_full=True):
543
  tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
544
  threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
@@ -554,12 +932,12 @@ def gather_k_get_copy_fn(
554
  # Prefetch mAIdx early, even before smem is free
555
  tApA_k = None
556
  if const_expr(pred):
557
- tApA_k = cute.make_fragment(cols_per_thread, Boolean)
558
  limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
559
  for k in cutlass.range(cols_per_thread, unroll_full=True):
560
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
561
  gAIdx_cur = gAIdx[None, src_idx]
562
- k_idx = cute.make_fragment(cols_per_thread, Int32)
563
  for k in cutlass.range(cols_per_thread):
564
  col_idx = tAcA[0, 0, k][1]
565
  if const_expr(not pred):
@@ -576,13 +954,13 @@ def gather_k_get_copy_fn(
576
  ) -> Tuple[cute.Tensor, cute.Tensor]:
577
  tApA_k = None
578
  if const_expr(pred):
579
- tApA_k = cute.make_fragment(cols_per_thread, Boolean)
580
  limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
581
  for k in cutlass.range(cols_per_thread, unroll_full=True):
582
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
583
  a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
584
  sAIdx_cur = sAIdx[None, dst_idx]
585
- k_idx = cute.make_fragment(cols_per_thread, Int32)
586
  for k in cutlass.range(cols_per_thread):
587
  col_idx = tAcA[0, 0, k][1]
588
  k_idx[k] = sAIdx_cur[col_idx]
@@ -612,3 +990,194 @@ def gather_k_get_copy_fn(
612
  return copy_fn, prefetch_from_gmem_fn if const_expr(
613
  gAIdx is not None
614
  ) else prefetch_from_smem_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
 
3
+ from typing import Optional, Type, Tuple, Callable, Sequence
4
+ from functools import partial
5
 
6
  import cutlass
7
  import cutlass.cute as cute
8
 
9
+ from cutlass import Int32, Int16, Boolean, const_expr
10
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
11
+ from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
12
  from cutlass.cutlass_dsl import dsl_user_op
13
  import cutlass.pipeline
14
+ from cutlass._mlir.dialects import llvm
15
+ from cutlass._mlir import ir
16
+ from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
17
+
18
+ from . import layout_utils
19
+ from .utils import make_vector
20
+
21
+
22
+ Sm100MmaPeerBitMask = 0xFEFFFFFF
23
 
24
 
25
  @dsl_user_op
 
36
  ) -> None:
37
  assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
38
  if const_expr(src.element_type != dst.element_type):
39
+ src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
40
  src_cvt.store(src.load().to(dst.element_type))
41
  src = src_cvt
42
  if const_expr(retile):
 
44
  cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
45
 
46
 
47
+ @dsl_user_op
48
+ def sr_cvt_copy(
49
+ tiled_copy: cute.TiledCopy,
50
+ src: cute.Tensor,
51
+ dst: cute.Tensor,
52
+ seed: Int32,
53
+ tidx: Int32,
54
+ *,
55
+ loc=None,
56
+ ip=None,
57
+ ) -> None:
58
+ """Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion."""
59
+ assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
60
+ from .rounding import convert_f32_to_bf16_sr
61
+ from cutlass.cute.tensor import TensorSSA
62
+
63
+ src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
64
+ src_vec = src.load()
65
+ raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip)
66
+ src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type))
67
+ src = src_cvt
68
+ cute.copy(tiled_copy, src, dst, loc=loc, ip=ip)
69
+
70
+
71
  @dsl_user_op
72
  def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
73
+ dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
74
  cute.autovec_copy(src, dst, loc=loc, ip=ip)
75
  return dst
76
 
 
86
  ) -> cute.Tensor:
87
  # Will also accept dst_shape being a tensor, in which case we write into that tensor
88
  if const_expr(not isinstance(dst_shape, cute.Tensor)):
89
+ dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
90
  else:
91
  dst = dst_shape
92
  cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
93
  return dst
94
 
95
 
96
+ @dsl_user_op
97
+ def load_t2r(
98
+ thr_copy: cute.ThrCopy, shape: cute.Shape, src: cute.Tensor, *, loc=None, ip=None
99
+ ) -> cute.Tensor:
100
+ cDst = cute.make_identity_tensor(shape)
101
+ dst = cute.make_rmem_tensor(thr_copy.partition_D(cDst).shape, src.element_type, loc=loc, ip=ip)
102
+ cute.copy(thr_copy, src, dst, loc=loc, ip=ip)
103
+ return dst
104
+
105
+
106
  @dsl_user_op
107
  def get_copy_atom(
108
  dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
 
161
  @cute.jit
162
  def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
163
  # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
164
+ tApA = cute.make_rmem_tensor(
165
  cute.make_layout(
166
  (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
167
  stride=(cute.size(tAcA, mode=[2]), 0, 1),
 
191
  # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
192
 
193
 
194
+ # Ragged tensor trick for TMA: encodes variable-length sequences into a higher-rank
195
+ # tensor so that TMA's out-of-bounds checking handles sequence boundaries.
196
+ #
197
+ # Given a tensor T with a ragged dimension (variable-length across batches), we create
198
+ # a higher-rank tensor where the ragged dim is replaced with a fixed size `big_int`, and
199
+ # extra dim(s) are appended. When indexing into a specific sequence at (offset, length),
200
+ # `offset_ragged_tensor` computes coordinates such that:
201
+ # ragged_coord = big_int - length (OOB check clamps reads past the sequence end)
202
+ # extra_coord(s) = f(offset, length) (selects the correct memory region)
203
+ #
204
+ # ptr_shift=True: 1-extra-dim approach (adds 1 dim, supports up to 4D input):
205
+ # Shape: (*before, big_int, *after, max_int)
206
+ # Stride: (*original_strides, stride_r) where stride_r = T.stride[ragged_dim]
207
+ # Pointer shifted backward by big_int * stride_r elements.
208
+ # Address for coords (big_int - length) in ragged dim, (offset + length) in extra dim:
209
+ # addr = (base - big_int * s_r) + (big_int - length) * s_r + (offset + length) * s_r
210
+ # = base + offset * s_r [correct]
211
+ # Works for epilogue TMA store. Does NOT work for TMA load with large big_int
212
+ # — the shifted pointer must land in physically mapped GPU memory.
213
+ #
214
+ # ptr_shift=False: 2-extra-dim approach (adds 2 dims, supports up to 3D input):
215
+ # Shape: (*before, big_int, *after, max_int, max_int)
216
+ # Stride: (*before_strides, stride_r, *after_strides, 2^34 - stride_r, stride_r)
217
+ # No pointer shift. Uses 64-bit address wraparound to cancel the ragged offset.
218
+ # Let W = 2^34 - stride_r. Address for coords (big_int - length) in ragged dim,
219
+ # big_int in extra dim 0, (offset + length) in extra dim 1:
220
+ # addr = base + (big_int - length) * s_r + big_int * W + (offset + length) * s_r
221
+ # = base + big_int * (s_r + W) - length * s_r + (offset + length) * s_r
222
+ # = base + big_int * 2^34 + offset * s_r
223
+ # Since big_int = 2^30: big_int * 2^34 = 2^64 ≡ 0 (mod 2^64), so:
224
+ # addr = base + offset * s_r [correct]
225
+ # Works for all TMA paths since the base pointer is never shifted.
226
+ #
227
+ # Ragged tensor was adapted from the implementation from Triton, but here we have an option that
228
+ # only needs 1 extra dimension instead of 2.
229
+ # https://github.com/triton-lang/triton/blob/main/python/triton/tools/ragged_tma.py
230
+ BIG_INT = 2**30
231
+ MAX_INT = 2**31 - 1
232
+ BIG_INT_INV = 2**64 // BIG_INT
233
 
 
 
234
 
235
+ @dsl_user_op
236
+ def create_ragged_tensor_for_tma(
237
+ T: cute.Tensor,
238
+ ragged_dim: int = 0,
239
+ ptr_shift: bool = False,
240
+ *,
241
+ loc=None,
242
+ ip=None,
243
+ ) -> cute.Tensor:
244
+ rank = cute.rank(T)
245
+ if ragged_dim < 0:
246
+ ragged_dim += rank
247
+ if ptr_shift:
248
+ assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions"
249
+ new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,)
250
+ new_stride = T.stride + (T.stride[ragged_dim],)
251
+ ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1)
252
+ new_ptr = cute.domain_offset(ptr_offset, T).iterator
253
+ return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride))
254
+ else:
255
+ assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions"
256
+ stride_r = T.stride[ragged_dim]
257
+ new_shape = (
258
+ T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT)
259
+ )
260
+ new_stride = (
261
+ T.stride[:ragged_dim]
262
+ + (stride_r,)
263
+ + T.stride[ragged_dim + 1 :]
264
+ + (BIG_INT_INV - stride_r, stride_r)
265
+ )
266
+ return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride))
267
 
268
+
269
+ @dsl_user_op
270
+ def offset_ragged_tensor(
271
+ T: cute.Tensor,
272
+ offset: Int32,
273
+ length: Int32,
274
+ ragged_dim: int = 0,
275
+ ptr_shift: bool = False,
276
+ *,
277
+ loc=None,
278
+ ip=None,
279
+ ) -> cute.Tensor:
280
+ rank = cute.rank(T)
281
+ if ragged_dim < 0:
282
+ ragged_dim += rank
283
+ big_int = cute.size(T, mode=[ragged_dim])
284
+ offset_val = big_int - length
285
+ if ptr_shift:
286
+ # 1-extra-dim: rank = original_rank + 1
287
+ assert rank >= ragged_dim + 2
288
+ offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2)
289
+ index_tuple = (None,) * (rank - 1) + (offset + length,)
290
  else:
291
+ # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims
292
+ assert rank >= ragged_dim + 3
293
+ offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3)
294
+ index_tuple = (None,) * (rank - 2) + (big_int, offset + length)
295
+ return cute.domain_offset(offset_tuple, T[index_tuple])
296
 
297
 
298
  def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
 
302
 
303
 
304
  def swizzle_ptr(ptr: cute.Pointer):
305
+ swz = ptr.type.swizzle_type
306
+ ptr_int = swizzle_int(ptr.toint(), swz.num_bits, swz.num_base, swz.num_shift)
307
  return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
308
 
309
 
310
  def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
311
  outer = tensor.layout
312
  width = tensor.element_type.width
313
+ swizzle_type = tensor.iterator.type.swizzle_type
314
+ inner = cute.make_swizzle(swizzle_type.num_bits, swizzle_type.num_base, swizzle_type.num_shift)
315
  # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
316
  # for 16 bits and <3, 2, 3> for 32 bits)
317
  new_layout = cute.recast_layout(
 
367
  raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
368
  is_m_major = layout_c.is_m_major_c()
369
  if elem_ty_c.width == 16:
370
+ return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
 
 
371
  else:
372
  return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
373
 
374
 
375
  def get_smem_store_atom(
376
+ arch: cutlass.Constexpr[int],
377
+ element_type: Type[cute.Numeric],
378
+ transpose: bool = False,
379
+ major_mode_size: Optional[int] = None,
380
  ) -> cute.CopyAtom:
381
  if const_expr(arch < 90 or element_type.width != 16):
382
  return cute.make_copy_atom(
 
385
  num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
386
  )
387
  else:
388
+ num_matrices = (
389
+ 4
390
+ if major_mode_size is None or major_mode_size % 16 == 0
391
+ else (2 if major_mode_size % 8 == 0 else 1)
392
+ )
393
  return cute.make_copy_atom(
394
+ warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
395
  element_type,
396
  )
397
 
398
 
399
  def get_smem_load_atom(
400
+ arch: cutlass.Constexpr[int],
401
+ element_type: Type[cute.Numeric],
402
+ transpose: bool = False,
403
+ major_mode_size: Optional[int] = None,
404
  ) -> cute.CopyAtom:
405
  if const_expr(arch < 90 or element_type.width != 16):
406
  return cute.make_copy_atom(
 
409
  num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
410
  )
411
  else:
412
+ num_matrices = (
413
+ 4
414
+ if major_mode_size is None or major_mode_size % 16 == 0
415
+ else (2 if major_mode_size % 8 == 0 else 1)
416
+ )
417
  return cute.make_copy_atom(
418
+ warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
419
  element_type,
420
  )
421
 
 
427
  arch: int,
428
  transpose: bool = False,
429
  position_independent=False,
430
+ major_mode_size: Optional[int] = None,
431
  ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
432
  dtype = sC.element_type
433
+ copy_atom = get_smem_store_atom(arch, dtype, transpose, major_mode_size=major_mode_size)
434
  tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
435
  thr_copy = tiled_copy.get_slice(tidx)
436
  if const_expr(not position_independent):
 
438
  else:
439
  tRS_sC = partition_D_position_independent(thr_copy, sC)
440
 
441
+ def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
442
+ dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx]
443
+ cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)
444
 
445
  return copy_fn, thr_copy, tRS_sC
446
 
 
465
  thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
466
  tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
467
 
468
+ def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
469
+ src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx]
470
+ return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)
 
471
 
472
  return copy_fn, thr_copy, tSR_sC
473
 
474
 
475
+ def epilog_smem_copy_atom(
476
+ tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
477
+ ) -> cute.TiledCopy:
478
+ copy_atom_C = cute.make_copy_atom(
479
+ warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2),
480
+ cutlass.Float16, # this is just to get the right source layout
481
+ )
482
+ tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
483
+ return tiled_copy_C_atom
484
+
485
+
486
+ def get_smem_store_epi(
487
+ tiled_mma: cute.TiledMma,
488
+ epi_tile: cute.Shape,
489
+ sC: Optional[cute.Tensor],
490
+ tidx: Int32,
491
+ arch: int,
492
+ transpose: bool = False,
493
+ position_independent=False,
494
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor, cute.Tensor]:
495
+ dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16
496
+ tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
497
+ copy_atom = get_smem_store_atom(arch, dtype, transpose)
498
+ tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom)
499
+ thr_copy = tiled_copy.get_slice(tidx)
500
+ tRS_sC = None
501
+ if const_expr(sC is not None):
502
+ if const_expr(not position_independent):
503
+ tRS_sC = thr_copy.partition_D(sC)
504
+ else:
505
+ tRS_sC = partition_D_position_independent(thr_copy, sC)
506
+ sC_shape = sC.shape[:2] if sC is not None else epi_tile
507
+ # (R2S, R2S_M, R2S_N, PIPE_C)
508
+ tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
509
+ tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)
510
+
511
+ def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
512
+ cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)
513
+
514
+ return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC
515
+
516
+
517
  def get_smem_store_A(
518
  tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
519
  ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
 
550
  tSR_sA = thr_copy.partition_S(sA)
551
  else:
552
  tSR_sA = partition_S_position_independent(thr_copy, sA)
 
 
553
  tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
554
 
555
  def copy_fn(src_idx: Int32, **new_kwargs):
 
563
  return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
564
 
565
 
566
+ @dsl_user_op
567
+ def cpasync_reduce_bulk_add_f32(
568
+ smem_ptr: cute.Pointer,
569
+ gmem_ptr: cute.Pointer,
570
+ store_bytes: int | Int32,
571
+ *,
572
+ loc=None,
573
+ ip=None,
574
+ ):
575
+ smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
576
+ # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
577
+ llvm.inline_asm(
578
+ None,
579
+ [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
580
+ "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
581
+ "l,r,r",
582
+ # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
583
+ # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
584
+ # "l,r,r,l",
585
+ has_side_effects=True,
586
+ is_align_stack=False,
587
+ )
588
+
589
+
590
+ @dsl_user_op
591
+ def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
592
+ """
593
+ Get the address of the TMA descriptor embedded in a TMA Copy Atom.
594
+
595
+ Extracts the constant memory address of the TMA descriptor for use with
596
+ custom PTX instructions.
597
+
598
+ :param tma_atom: TMA Copy Atom from make_tiled_tma_atom
599
+ :return: Pointer to TMA descriptor in constant memory
600
+
601
+ Example:
602
+ >>> desc_ptr = get_tma_descriptor_address(tma_atom)
603
+ """
604
+ exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
605
+ tma_desc_ptr_type = ir.Type.parse(
606
+ "!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
607
+ )
608
+ return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
609
+
610
+
611
+ @dsl_user_op
612
+ def tma_gather4_load(
613
+ tma_desc_ptr: cute.Pointer,
614
+ dst_smem_ptr: cute.Pointer,
615
+ mbarrier_ptr: cute.Pointer,
616
+ col_idx: Int32,
617
+ row_indices: Sequence[Int32],
618
+ *,
619
+ num_cta: int = 1,
620
+ multicast_mask=None,
621
+ loc=None,
622
+ ip=None,
623
+ ) -> None:
624
+ """
625
+ Perform TMA gather4 load from global memory to shared memory.
626
+
627
+ Issues PTX instruction:
628
+ cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
629
+ [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];
630
+
631
+ This loads 4 rows (specified by row_indices) from a 2D tensor at the given
632
+ column index into shared memory, using the TMA descriptor.
633
+
634
+ :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned)
635
+ :type tma_desc_ptr: Pointer
636
+ :param dst_smem_ptr: Destination address in shared memory
637
+ :type dst_smem_ptr: Pointer
638
+ :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
639
+ :type mbarrier_ptr: Pointer
640
+ :param col_idx: Column index
641
+ :type col_idx: Int32
642
+ :param row_indices: Sequence of exactly 4 row indices
643
+ :type row_indices: Sequence[Int32]
644
+ :param num_cta: Number of CTAs participating (default: 1)
645
+ :type num_cta: int
646
+ :param multicast_mask: Optional multicast mask
647
+ :type multicast_mask: Int16
648
+
649
+ Requirements:
650
+ - row_indices must contain exactly 4 elements
651
+ - Compute capability >= SM_100 (Blackwell)
652
+ - TMA descriptor must be properly initialized for 2D tensor
653
+
654
+ Example:
655
+ >>> from cutlass.cute.nvgpu import cpasync
656
+ >>> from cutlass.cute import core
657
+ >>>
658
+ >>> # Create TMA descriptor
659
+ >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
660
+ >>> tma_desc_ptr = get_tma_descriptor_address(tma_atom)
661
+ >>>
662
+ >>> # Compute indices (typically from kernel logic)
663
+ >>> col_idx = core.get(...) or 5 # Int32 value
664
+ >>> row_indices = [core.get(...) for _ in range(4)] # 4 Int32 values
665
+ >>>
666
+ >>> # Gather 4 rows at computed column
667
+ >>> tma_gather4_load(
668
+ ... tma_desc_ptr=tma_desc_ptr,
669
+ ... dst_smem_ptr=smem_ptr,
670
+ ... mbarrier_ptr=barrier_ptr,
671
+ ... col_idx=col_idx,
672
+ ... row_indices=row_indices
673
+ ... )
674
+ """
675
+ if len(row_indices) != 4:
676
+ raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
677
+ col_val = Int32(col_idx).ir_value()
678
+ row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
679
+ # Convert pointers to integer addresses
680
+ desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
681
+ dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
682
+ mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
683
+ if num_cta > 1:
684
+ # Executed by both CTAs. Set peer bit to 0 so that the
685
+ # transaction bytes will update CTA0's barrier.
686
+ mbar_addr = mbar_addr & Sm100MmaPeerBitMask
687
+ mbar_addr = mbar_addr.ir_value()
688
+ # Handle multicast_mask - may already be ir.Value or Python int
689
+ multicast_mask_val = None
690
+ if multicast_mask is not None:
691
+ multicast_mask_val = Int16(multicast_mask).ir_value()
692
+ assert multicast_mask_val is None, "multicast is not supported yet"
693
+ # Emit inline PTX for TMA gather4
694
+ # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
695
+ # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar];
696
+ ptx = (
697
+ f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
698
+ "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
699
+ )
700
+
701
+ llvm.inline_asm(
702
+ None,
703
+ [
704
+ dst_addr,
705
+ desc_addr,
706
+ col_val,
707
+ row_vals[0],
708
+ row_vals[1],
709
+ row_vals[2],
710
+ row_vals[3],
711
+ mbar_addr,
712
+ ],
713
+ ptx,
714
+ "r,l,r,r,r,r,r,r", # constraints: register, long, 6x register
715
+ has_side_effects=True,
716
+ is_align_stack=False,
717
+ loc=loc,
718
+ ip=ip,
719
+ )
720
+
721
+
722
+ def cpasync_bulk_get_copy_fn(
723
+ src_tensor: cute.Tensor,
724
+ dst_tensor: cute.Tensor,
725
+ single_stage: bool = False,
726
+ **kwargs,
727
+ ) -> Callable:
728
+ group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
729
+ group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
730
+ # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
731
+ src = cute.group_modes(src_tensor, 0, group_rank_src)
732
+ dst = cute.group_modes(dst_tensor, 0, group_rank_dst)
733
+
734
+ def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
735
+ atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
736
+ with cute.arch.elect_one():
737
+ cute.copy(
738
+ atom,
739
+ src[None, src_idx],
740
+ dst[None, dst_idx],
741
+ mbar_ptr=tma_bar_ptr,
742
+ **new_kwargs,
743
+ **kwargs,
744
+ )
745
+
746
+ def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
747
+ atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
748
+ with cute.arch.elect_one():
749
+ cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)
750
+
751
+ return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
752
+
753
+
754
+ @dsl_user_op
755
  def tma_get_copy_fn(
756
  atom: cute.CopyAtom,
757
  cta_coord: cute.Coord,
 
760
  dst_tensor: cute.Tensor,
761
  filter_zeros: bool = False,
762
  single_stage: bool = False,
763
+ *,
764
+ loc=None,
765
+ ip=None,
766
  **kwargs,
767
  ) -> Callable:
768
  src_is_smem = const_expr(
 
779
  cta_layout,
780
  cute.group_modes(smem_tensor, 0, group_rank_smem),
781
  cute.group_modes(gmem_tensor, 0, group_rank_gmem),
782
+ loc=loc,
783
+ ip=ip,
784
  )
785
  if const_expr(filter_zeros):
786
  s = cute.filter_zeros(s)
787
  g = cute.filter_zeros(g)
788
  src, dst = (s, g) if src_is_smem else (g, s)
789
 
790
+ @dsl_user_op
791
+ def copy_tma(src_idx, dst_idx, *, loc=None, ip=None, **new_kwargs):
792
+ cute.copy(
793
+ atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs, loc=loc, ip=ip
794
+ )
795
 
796
+ @dsl_user_op
797
+ def copy_tma_single_stage(*, loc=None, ip=None, **new_kwargs):
798
+ cute.copy(atom, src, dst, **new_kwargs, **kwargs, loc=loc, ip=ip)
799
 
800
  return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
801
 
 
816
  def gather_m_get_copy_fn(
817
  thr_copy_A: cute.ThrCopy,
818
  mA: cute.Tensor, # (whatever, K)
819
+ sA: cute.Tensor, # (tile_M, tile_K, STAGE)
820
  gsAIdx: cute.Tensor, # (tile_M), either gmem or smem
821
  limit_m: Int32,
822
  limit_k: Int32,
823
  ) -> Callable:
824
+ tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1])
825
+ tAsA = partition_D_position_independent(thr_copy_A, sA)
826
  # k-major
827
  assert tAsA.shape[2] == 1
828
  tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
829
 
830
+ is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0
831
  if const_expr(not is_even_m_smem):
832
+ limit_m = min(limit_m, tile_M)
833
  elems_per_load = cute.size(tAsA.shape[0][0])
834
+ cA = cute.make_identity_tensor((tile_M, tile_K))
835
  tAcA = thr_copy_A.partition_S(cA)
836
  t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
837
  # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
 
842
  # Read and cache indices for A
843
  rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
844
  cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
845
+ tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
846
  for m in cutlass.range(rows_per_thread, unroll_full=True):
847
  tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
848
+ m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
849
  for m in cutlass.range(rows_per_thread, unroll_full=True):
850
  row_idx = tAcA[0, m, 0][0]
851
  if tApA_m[m]:
 
853
  else:
854
  m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
855
 
856
+ mA_k = cute.logical_divide(mA, (None, tile_K))
857
 
858
  def copy_fn(src_idx, dst_idx, pred: bool = False):
859
  tApA_k = None
860
  if const_expr(pred):
861
+ tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
862
+ limit_k_cur = limit_k - src_idx * tile_K
863
  for k in cutlass.range(cols_per_thread, unroll_full=True):
864
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
865
  mA_cur = mA_k[None, (None, src_idx)]
 
884
  def gather_k_get_copy_fn(
885
  thr_copy_A: cute.ThrCopy,
886
  mA: cute.Tensor, # (tile_M, whatever)
887
+ sA: cute.Tensor, # (tile_M, tile_K, STAGE)
888
  gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem
889
  limit_m: Int32,
890
  limit_k: Int32,
 
916
  # Read and cache indices for A
917
  rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
918
  cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
919
+ tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
920
  for m in cutlass.range(rows_per_thread, unroll_full=True):
921
  tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
922
  threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
 
932
  # Prefetch mAIdx early, even before smem is free
933
  tApA_k = None
934
  if const_expr(pred):
935
+ tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
936
  limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
937
  for k in cutlass.range(cols_per_thread, unroll_full=True):
938
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
939
  gAIdx_cur = gAIdx[None, src_idx]
940
+ k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
941
  for k in cutlass.range(cols_per_thread):
942
  col_idx = tAcA[0, 0, k][1]
943
  if const_expr(not pred):
 
954
  ) -> Tuple[cute.Tensor, cute.Tensor]:
955
  tApA_k = None
956
  if const_expr(pred):
957
+ tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
958
  limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
959
  for k in cutlass.range(cols_per_thread, unroll_full=True):
960
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
961
  a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
962
  sAIdx_cur = sAIdx[None, dst_idx]
963
+ k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
964
  for k in cutlass.range(cols_per_thread):
965
  col_idx = tAcA[0, 0, k][1]
966
  k_idx[k] = sAIdx_cur[col_idx]
 
990
  return copy_fn, prefetch_from_gmem_fn if const_expr(
991
  gAIdx is not None
992
  ) else prefetch_from_smem_fn
993
+
994
+
995
+ @cute.jit
996
+ def gather_m_get_tma_copy_fn(
997
+ tma_atom: cute.CopyAtom,
998
+ mA: cute.Tensor, # (whatever, K)
999
+ sA: cute.Tensor, # ((4, 32), (64, 1), STAGE)
1000
+ sAIdx: cute.Tensor, # (tile_M),
1001
+ warp_idx: Int32,
1002
+ num_warps: int,
1003
+ num_cta: int = 1,
1004
+ ) -> Callable:
1005
+ tile_M = cute.size(sAIdx, mode=[0])
1006
+ tile_K = cute.size(sA[None, None, 0]) // tile_M
1007
+ assert tile_M % 4 == 0
1008
+ # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
1009
+ cta_group = num_cta # Somehow all tma_atom has CtaGroup.ONE inside the kernel
1010
+
1011
+ copy_AIdx_s2r = cute.make_tiled_copy_tv(
1012
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
1013
+ cute.make_layout(num_warps), # thr_layout
1014
+ cute.make_layout(4), # val_layout
1015
+ )
1016
+ warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
1017
+ tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
1018
+ # ((4, 1), 8, (64, 1), STAGE)
1019
+ tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
1020
+ tSR_rAIdx = load_s2r(tSR_sAIdx)
1021
+ tma_desc_ptr = get_tma_desc_addr(tma_atom)
1022
+ tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
1023
+
1024
+ def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
1025
+ tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
1026
+ col_idx = tile_K * src_idx
1027
+ for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1028
+ row_indices = [tSR_rAIdx[v, m] for v in range(4)]
1029
+ smem_ptr = tSR_sA_cur[None, m, None].iterator
1030
+ with cute.arch.elect_one():
1031
+ tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
1032
+
1033
+ return copy_fn
1034
+
1035
+
1036
+ @cute.jit
1037
+ def gather_k_get_tma_copy_fn(
1038
+ tma_atom: cute.CopyAtom,
1039
+ sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout
1040
+ sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem
1041
+ col_idx: Int32, # M offset in global tensor (contiguous dim for M-major)
1042
+ warp_idx: Int32,
1043
+ num_warps: int,
1044
+ num_cta: int = 1,
1045
+ ) -> Tuple[Callable, Callable]:
1046
+ """Build a copy function for TMA gather4 in K dimension (M-major A).
1047
+
1048
+ Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements.
1049
+ col_idx is the absolute M position in the global tensor.
1050
+ K indices come from sAIdx (prefetched to smem by the scheduler warp).
1051
+
1052
+ Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which:
1053
+ Issues gather4 calls with those K indices as row_indices
1054
+ """
1055
+ tile_K = cute.size(sAIdx, mode=[0])
1056
+ assert tile_K % 4 == 0
1057
+ cta_group = num_cta
1058
+
1059
+ # Tiled copy for loading K indices from smem to registers (4 per vector, across warps)
1060
+ copy_AIdx_s2r = cute.make_tiled_copy_tv(
1061
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
1062
+ cute.make_layout(num_warps), # thr_layout
1063
+ cute.make_layout(4), # val_layout — 4 K indices per gather4
1064
+ )
1065
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
1066
+ warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
1067
+ tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4))
1068
+ # ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192))
1069
+ tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA))
1070
+ tma_desc_ptr = get_tma_desc_addr(tma_atom)
1071
+ tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
1072
+
1073
+ def prefetch_from_smem_fn(
1074
+ a_prefetch_pipeline,
1075
+ src_idx,
1076
+ dst_idx,
1077
+ a_prefetch_consumer_state,
1078
+ ) -> cute.Tensor:
1079
+ a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
1080
+ tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx])
1081
+ cute.arch.sync_warp()
1082
+ with cute.arch.elect_one():
1083
+ a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
1084
+ return tSR_rAIdx
1085
+
1086
+ def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer):
1087
+ # Issue gather4: col_idx = M position, row_indices = 4 K positions
1088
+ tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
1089
+ gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64
1090
+ for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1091
+ row_indices = [tSR_rAIdx[v, k] for v in range(4)]
1092
+ for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True):
1093
+ smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator
1094
+ with cute.arch.elect_one():
1095
+ tma_gather4_load_fn(
1096
+ smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices
1097
+ )
1098
+
1099
+ return copy_fn, prefetch_from_smem_fn
1100
+
1101
+
1102
+ # ---------------------------------------------------------------------------
1103
+ # Store helpers
1104
+ # ---------------------------------------------------------------------------
1105
+
1106
+
1107
+ @dsl_user_op
1108
+ @cute.jit
1109
+ def store(
1110
+ ptr: cute.Pointer,
1111
+ val,
1112
+ pred: Optional[Boolean] = None,
1113
+ cop: cutlass.Constexpr = None,
1114
+ *,
1115
+ loc=None,
1116
+ ip=None,
1117
+ ):
1118
+ """Store a scalar value via cute.arch.store.
1119
+
1120
+ ptr: cute.Pointer (any address space).
1121
+ val: DSL Numeric value.
1122
+ pred: None → unconditional. DSL Boolean → skipped when pred == 0.
1123
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1124
+ """
1125
+ if const_expr(pred is None):
1126
+ cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1127
+ else:
1128
+ if pred:
1129
+ cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1130
+
1131
+
1132
+ @dsl_user_op
1133
+ @cute.jit
1134
+ def store_v2(
1135
+ ptr: cute.Pointer,
1136
+ v0,
1137
+ v1,
1138
+ pred: Optional[Boolean] = None,
1139
+ cop: cutlass.Constexpr = None,
1140
+ *,
1141
+ loc=None,
1142
+ ip=None,
1143
+ ):
1144
+ """Vectorized store of 2 elements via cute.arch.store.
1145
+
1146
+ Packs v0, v1 into an MLIR <2 x T> vector.
1147
+ ptr: cute.Pointer (any address space, must be aligned for vector width).
1148
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1149
+ """
1150
+ vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
1151
+ if const_expr(pred is None):
1152
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1153
+ else:
1154
+ if pred:
1155
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1156
+
1157
+
1158
+ @dsl_user_op
1159
+ @cute.jit
1160
+ def store_v4(
1161
+ ptr: cute.Pointer,
1162
+ v0,
1163
+ v1,
1164
+ v2,
1165
+ v3,
1166
+ pred: Optional[Boolean] = None,
1167
+ cop: cutlass.Constexpr = None,
1168
+ *,
1169
+ loc=None,
1170
+ ip=None,
1171
+ ):
1172
+ """Vectorized store of 4 elements via cute.arch.store.
1173
+
1174
+ Packs v0–v3 into an MLIR <4 x T> vector.
1175
+ ptr: cute.Pointer (any address space, must be aligned for vector width).
1176
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1177
+ """
1178
+ vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
1179
+ if const_expr(pred is None):
1180
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1181
+ else:
1182
+ if pred:
1183
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
build/torch-cuda/quack/cross_entropy.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ from typing import Optional, Type, Literal
6
+
7
+ import torch
8
+ from ._ops_compat import add_quack_op_namespace_prefix
9
+ from torch import Tensor
10
+
11
+ import cuda.bindings.driver as cuda
12
+
13
+ import cutlass
14
+ import cutlass.cute as cute
15
+ from cutlass import Int32, Int64, Float32, Boolean, const_expr
16
+
17
+ from . import utils as utils
18
+ from . import copy_utils as copy_utils
19
+ from . import layout_utils as layout_utils
20
+ from .compile_utils import make_fake_tensor as fake_tensor
21
+ from .reduce import row_reduce, online_softmax_reduce
22
+ from .reduction_base import ReductionBase
23
+ from .cache_utils import jit_cache
24
+ from .cute_dsl_utils import torch2cute_dtype_map
25
+ from cutlass.base_dsl import Arch
26
+
27
+
28
+ class CrossEntropy(ReductionBase):
29
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
30
+ self.online_softmax = online_softmax
31
+ # 2 stages: 1 for max, 1 for sum
32
+ super().__init__(
33
+ dtype,
34
+ N,
35
+ stage=2 if not self.online_softmax else 1,
36
+ reduction_dtype=Float32 if not self.online_softmax else Int64,
37
+ )
38
+ self.reload_from = None if N <= 16384 or self.online_softmax else "smem"
39
+
40
+ def _threads_per_row(self):
41
+ N = self.N
42
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
43
+ if N <= limit:
44
+ return threads
45
+ return 256
46
+
47
+ def _set_cluster_n(self):
48
+ arch = cutlass.base_dsl.BaseDSL._get_dsl().get_arch_enum()
49
+ # SM8x (Ampere/Ada) lacks cluster support
50
+ if arch < Arch.sm_90:
51
+ self.cluster_n = 1
52
+ return
53
+ # SM12x supports cluster up to 8
54
+ max_cluster = 8 if arch.major == 12 else 16
55
+ N = self.N
56
+ if arch.major == 12 and const_expr(self.dtype.width >= 32):
57
+ # SM12x 99 KB SMEM: fp32 needs tighter clustering (same limits as fp16)
58
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
59
+ elif const_expr(self.dtype.width == 16):
60
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
61
+ else:
62
+ thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
63
+ for limit, cluster in thresholds:
64
+ if N <= limit:
65
+ self.cluster_n = cluster
66
+ return
67
+ self.cluster_n = max_cluster
68
+
69
+ @cute.jit
70
+ def __call__(
71
+ self,
72
+ mX: cute.Tensor, # (M, N)
73
+ mTarget: cute.Tensor, # (M,)
74
+ mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX
75
+ mLoss: cute.Tensor, # (M,)
76
+ mLSE: Optional[cute.Tensor], # (M,)
77
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
78
+ ignore_index: Int32, # Index to ignore in loss computation
79
+ stream: cuda.CUstream,
80
+ ):
81
+ assert mX.element_type == self.dtype
82
+ if const_expr(mTargetLogit is None):
83
+ mTargetLogit = mX
84
+ if const_expr(mdX is not None):
85
+ assert mdX.element_type == self.dtype
86
+ self._set_cluster_n()
87
+ largest_dtype_width = const_expr(mX.element_type.width)
88
+ if const_expr(mdX is not None):
89
+ largest_dtype_width = const_expr(max(largest_dtype_width, mdX.element_type.width))
90
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
91
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
92
+ num_threads = tiled_copy.size
93
+ self.kernel(
94
+ mX,
95
+ mTarget,
96
+ mTargetLogit,
97
+ mLoss,
98
+ mLSE,
99
+ mdX,
100
+ ignore_index,
101
+ tiler_mn,
102
+ tiled_copy,
103
+ threads_per_row,
104
+ ).launch(
105
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
106
+ block=[num_threads, 1, 1],
107
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
108
+ stream=stream,
109
+ )
110
+
111
+ @cute.kernel
112
+ def kernel(
113
+ self,
114
+ mX: cute.Tensor, # (M, N)
115
+ mTarget: cute.Tensor, # (M,)
116
+ mTargetLogit: cute.Tensor, # (M, K) or (M,)
117
+ mLoss: cute.Tensor, # (M,)
118
+ mLSE: Optional[cute.Tensor], # (M,)
119
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
120
+ ignore_index: Int32, # Index to ignore in loss computation
121
+ tiler_mn: cute.Shape,
122
+ tiled_copy: cute.TiledCopy,
123
+ threads_per_row: cutlass.Constexpr[int],
124
+ ):
125
+ tidx, _, _ = cute.arch.thread_idx()
126
+ bidx, _, _ = cute.arch.block_idx()
127
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
128
+ tv_layout = tiled_copy.layout_tv_tiled
129
+
130
+ shape = mX.shape
131
+ idX = cute.make_identity_tensor(shape)
132
+ # slice for CTAs
133
+ gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
134
+
135
+ smem = cutlass.utils.SmemAllocator()
136
+ sX = smem.allocate_tensor(
137
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
138
+ )
139
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
140
+
141
+ thr_copy = tiled_copy.get_slice(tidx)
142
+
143
+ tXgX = thr_copy.partition_S(gX)
144
+ tXsX = thr_copy.partition_D(sX)
145
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
146
+ tXrX = cute.make_rmem_tensor_like(tXgX)
147
+
148
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
149
+ tXpX = (
150
+ None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
151
+ )
152
+ copy = partial(copy_utils.copy, pred=tXpX)
153
+
154
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
155
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
156
+
157
+ row = tXcX[0][0]
158
+ target = Int32.zero
159
+ if row < shape[0]:
160
+ target = Int32(mTarget[row])
161
+
162
+ if row < shape[0]:
163
+ copy(tXgX, tXsX, is_async=True)
164
+ cute.arch.cp_async_commit_group()
165
+ cute.arch.cp_async_wait_group(0)
166
+ # Fill OOB values with -inf
167
+ if const_expr(not is_even_N):
168
+ utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
169
+ cute.autovec_copy(tXsX, tXrX)
170
+ x = tXrX.load().to(Float32)
171
+
172
+ target_logit = Float32.zero
173
+ should_ignore = Boolean(target == ignore_index)
174
+ if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
175
+ # Only load target logit if not ignoring this index
176
+ if const_expr(cute.rank(mTargetLogit.shape) == 2):
177
+ target_logit = Float32(mTargetLogit[row, target])
178
+ else:
179
+ assert cute.rank(mTargetLogit.shape) == 1
180
+ target_logit = Float32(mTargetLogit[row])
181
+
182
+ if const_expr(not self.online_softmax):
183
+ max_x = row_reduce(
184
+ x,
185
+ cute.ReductionOp.MAX,
186
+ threads_per_row,
187
+ reduction_buffer[None, None, 0],
188
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
189
+ init_val=-Float32.inf,
190
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
191
+ )
192
+ if const_expr(self.reload_from == "smem"):
193
+ cute.autovec_copy(tXsX, tXrX)
194
+ x = tXrX.load().to(Float32)
195
+ log2_e = math.log2(math.e)
196
+ # This would use ffma instead of fadd then fmul
197
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
198
+ denom = row_reduce(
199
+ exp_x,
200
+ cute.ReductionOp.ADD,
201
+ threads_per_row,
202
+ reduction_buffer[None, None, 1],
203
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
204
+ init_val=0.0,
205
+ )
206
+ else:
207
+ max_x, denom, exp_x = online_softmax_reduce(
208
+ x,
209
+ threads_per_row,
210
+ reduction_buffer[None, None, 0],
211
+ mbar_ptr,
212
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
213
+ return_exp_x=const_expr(mdX is not None),
214
+ )
215
+
216
+ # Write loss and lse to gmem
217
+ if (
218
+ tXcX[0][1] == 0
219
+ and row < shape[0]
220
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
221
+ ):
222
+ lse = max_x + cute.math.log(denom, fastmath=True)
223
+ # Set loss to 0 if this index should be ignored, otherwise compute normally
224
+ loss_val = (lse - target_logit) if not should_ignore else Float32.zero
225
+ mLoss[row] = mLoss.element_type(loss_val)
226
+ if const_expr(mLSE is not None):
227
+ mLSE[row] = lse
228
+
229
+ # Compute gradient if mdX is provided
230
+ if const_expr(mdX is not None):
231
+ # Compute probabilities: exp(x) / sum(exp(x))
232
+ # If ignored, gradient should be zero
233
+ denom_inv = (
234
+ # 1.0 / denom
235
+ cute.arch.rcp_approx(denom)
236
+ if not (denom == 0.0 or denom != denom or should_ignore)
237
+ else Float32.zero
238
+ )
239
+ probs = exp_x * denom_inv
240
+ gdX = cute.local_tile(mdX, tiler_mn, (bidx, cluster_y))
241
+ tXgdX = thr_copy.partition_D(gdX)
242
+ tXrdX = cute.make_rmem_tensor_like(tXgdX)
243
+ tXcFull = thr_copy.partition_S(cX)
244
+ # Compute gradient: probs for all classes, (probs - 1) for target class
245
+ # If ignored, gradient is already zero
246
+ tXrdX_f32 = cute.make_rmem_tensor_like(tXrX, Float32)
247
+ tXrdX_f32.store(probs)
248
+ if not should_ignore:
249
+ for i in cutlass.range(cute.size(tXrX), unroll_full=True):
250
+ tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
251
+ tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
252
+ if row < shape[0]:
253
+ copy(tXrdX, tXgdX)
254
+
255
+
256
+ @jit_cache
257
+ def _compile_cross_entropy_fwd(
258
+ dtype, target_dtype, target_logit_dtype, N, has_lse, has_dx, target_logit_ndim
259
+ ):
260
+ batch_sym = cute.sym_int()
261
+ div = math.gcd(128 // dtype.width, N)
262
+ x_cute = fake_tensor(dtype, (batch_sym, N), div)
263
+ dx_cute = fake_tensor(dtype, (batch_sym, N), div) if has_dx else None
264
+ target_cute = fake_tensor(target_dtype, (batch_sym,))
265
+ if target_logit_dtype is not None:
266
+ if target_logit_ndim == 2:
267
+ target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym, cute.sym_int()), div)
268
+ else:
269
+ target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym,))
270
+ else:
271
+ target_logit_cute = None
272
+ loss_cute = fake_tensor(Float32, (batch_sym,))
273
+ lse_cute = fake_tensor(Float32, (batch_sym,)) if has_lse else None
274
+ # If there's dx, it's faster to not use online softmax since we want the exp(x - max)
275
+ cross_entropy_op = CrossEntropy(dtype, N, online_softmax=not has_dx)
276
+ return cute.compile(
277
+ cross_entropy_op,
278
+ x_cute,
279
+ target_cute,
280
+ target_logit_cute,
281
+ loss_cute,
282
+ lse_cute,
283
+ dx_cute,
284
+ Int32(0), # ignore_index, just for compilation
285
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
286
+ options="--enable-tvm-ffi",
287
+ )
288
+
289
+
290
+ @torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_fwd_out"), mutates_args={"loss", "lse", "dx"})
291
+ def cross_entropy_fwd_out(
292
+ x: Tensor,
293
+ target: Tensor,
294
+ target_logit: Optional[Tensor],
295
+ loss: Tensor,
296
+ lse: Optional[Tensor],
297
+ dx: Optional[Tensor],
298
+ ignore_index: int = -100,
299
+ ) -> None:
300
+ """Cross entropy forward pass.
301
+
302
+ Args:
303
+ x: Input logits tensor of shape (M, N)
304
+ target: Target class indices tensor of shape (M,)
305
+ target_logit: (M, K) or (M,).
306
+ If provided, the target logit will be read from this tensor instead of x.
307
+ loss: Output loss tensor of shape (M,)
308
+ lse: Optional output log-sum-exp tensor of shape (M,)
309
+ dx: Optional output gradient tensor of shape (M, N)
310
+ ignore_index: Index to ignore in loss computation
311
+
312
+ Returns:
313
+ None (mutates loss, lse, and optionally dx in-place)
314
+ """
315
+ assert x.dim() == 2, "Input must be 2D"
316
+ assert target.dim() == 1, "Target must be 1D"
317
+ assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
318
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
319
+ assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
320
+ if target_logit is not None:
321
+ assert target_logit.is_cuda, "Target logits must be on CUDA device"
322
+ assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
323
+ if dx is not None:
324
+ assert dx.is_cuda, "dx must be on CUDA device"
325
+ N = x.size(1)
326
+ dtype = torch2cute_dtype_map[x.dtype]
327
+ target_dtype = torch2cute_dtype_map[target.dtype]
328
+ target_logit_dtype = (
329
+ torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
330
+ )
331
+ target_logit_ndim = target_logit.ndim if target_logit is not None else None
332
+ _compile_cross_entropy_fwd(
333
+ dtype,
334
+ target_dtype,
335
+ target_logit_dtype,
336
+ N,
337
+ lse is not None,
338
+ dx is not None,
339
+ target_logit_ndim,
340
+ )(x, target, target_logit, loss, lse, dx, Int32(ignore_index))
341
+
342
+
343
+ @cross_entropy_fwd_out.register_fake
344
+ def _cross_entropy_fwd_out_fake(
345
+ x: Tensor,
346
+ target: Tensor,
347
+ target_logit: Optional[Tensor],
348
+ loss: Tensor,
349
+ lse: Optional[Tensor],
350
+ dx: Optional[Tensor],
351
+ ignore_index: int = -100,
352
+ ) -> None:
353
+ # See softmax.py _softmax_fwd_fake for why register_fake is needed.
354
+ from .cache_utils import COMPILE_ONLY
355
+
356
+ if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
357
+ N = x.size(1)
358
+ dtype = torch2cute_dtype_map[x.dtype]
359
+ target_dtype = torch2cute_dtype_map[target.dtype]
360
+ target_logit_dtype = (
361
+ torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
362
+ )
363
+ target_logit_ndim = target_logit.ndim if target_logit is not None else None
364
+ _compile_cross_entropy_fwd(
365
+ dtype,
366
+ target_dtype,
367
+ target_logit_dtype,
368
+ N,
369
+ lse is not None,
370
+ dx is not None,
371
+ target_logit_ndim,
372
+ )
373
+ _compile_cross_entropy_backward(dtype, target_dtype, N)
374
+
375
+
376
+ def cross_entropy_fwd(
377
+ x: torch.Tensor,
378
+ target: torch.Tensor,
379
+ target_logit: Optional[torch.Tensor] = None,
380
+ ignore_index: int = -100,
381
+ return_lse: bool = False,
382
+ return_dx: bool = False,
383
+ inplace_backward: bool = False,
384
+ ) -> torch.Tensor | tuple[torch.Tensor]:
385
+ M = x.size(0)
386
+ device = x.device
387
+ loss = torch.empty(M, device=device, dtype=torch.float32)
388
+ lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
389
+ dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None
390
+ cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)
391
+ if return_lse and return_dx:
392
+ return loss, lse, dx
393
+ elif return_lse:
394
+ return loss, lse
395
+ elif return_dx:
396
+ return loss, dx
397
+ else:
398
+ return loss
399
+
400
+
401
+ class CrossEntropyBackward:
402
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
403
+ self.dtype = dtype
404
+ self.N = N
405
+ self.vecsize = 128 // dtype.width
406
+
407
+ def _threads_per_row(self):
408
+ N = min(self.N, 16384) # We split by blocks of 16k
409
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
410
+ if N <= limit:
411
+ return threads
412
+ return 256
413
+
414
+ def _get_tiled_copy(self, vecsize: int):
415
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
416
+ N = min(self.N, 16384)
417
+ num_threads = 128 if N <= 16384 else 256
418
+ threads_per_row = self._threads_per_row()
419
+ cols_per_block = num_threads // threads_per_row
420
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
421
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
422
+ tiled_copy = copy_utils.tiled_copy_2d(
423
+ self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
424
+ )
425
+ return tiled_copy, tiler_mn, threads_per_row
426
+
427
+ @cute.jit
428
+ def __call__(
429
+ self,
430
+ mX: cute.Tensor,
431
+ mTarget: cute.Tensor,
432
+ mDLoss: cute.Tensor,
433
+ mdX: cute.Tensor,
434
+ mLSE: cute.Tensor,
435
+ ignore_index: Int32, # Index to ignore in gradient computation
436
+ stream: cuda.CUstream,
437
+ ):
438
+ assert mX.element_type == self.dtype
439
+ assert mdX.element_type == self.dtype
440
+ # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
441
+ vecsize = math.gcd(self.N, 128 // self.dtype.width)
442
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
443
+ num_threads = tiled_copy.size
444
+ # (M,) -> (M, N) with stride 0 in the N dimension
445
+ mDLoss, mTarget, mLSE = [
446
+ layout_utils.expand(X, dim=1, size=self.N) for X in (mDLoss, mTarget, mLSE)
447
+ ]
448
+ self.kernel(
449
+ mX,
450
+ mTarget,
451
+ mDLoss,
452
+ mdX,
453
+ mLSE,
454
+ ignore_index,
455
+ mX.shape,
456
+ tiler_mn,
457
+ tiled_copy,
458
+ threads_per_row,
459
+ ).launch(
460
+ grid=[
461
+ cute.ceil_div(mX.shape[0], tiler_mn[0]),
462
+ cute.ceil_div(mX.shape[1], tiler_mn[1]),
463
+ 1,
464
+ ],
465
+ block=[num_threads, 1, 1],
466
+ stream=stream,
467
+ )
468
+
469
+ @cute.kernel
470
+ def kernel(
471
+ self,
472
+ mX: cute.Tensor, # (M, N)
473
+ mTarget: cute.Tensor, # (M,)
474
+ mDLoss: cute.Tensor, # (M,)
475
+ mdX: cute.Tensor, # (M, N)
476
+ mLSE: cute.Tensor, # (M,)
477
+ ignore_index: Int32, # Index to ignore in gradient computation
478
+ shape: cute.Shape,
479
+ tiler_mn: cute.Shape,
480
+ tiled_copy: cute.TiledCopy,
481
+ threads_per_row: cutlass.Constexpr[int],
482
+ ):
483
+ tidx, _, _ = cute.arch.thread_idx()
484
+ bidx, bidy, _ = cute.arch.block_idx()
485
+
486
+ smem = cutlass.utils.SmemAllocator()
487
+ sX = smem.allocate_tensor(
488
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
489
+ )
490
+
491
+ idX = cute.make_identity_tensor(shape)
492
+ gX, gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)]
493
+
494
+ thr_copy = tiled_copy.get_slice(tidx)
495
+
496
+ tXgX = thr_copy.partition_S(gX)
497
+ tXsX = thr_copy.partition_D(sX)
498
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
499
+ tXcFull = thr_copy.partition_S(cX)
500
+ tXgdX = thr_copy.partition_D(gdX)
501
+ tXrX, tXrdX = [cute.make_rmem_tensor_like(thr) for thr in (tXgX, tXgdX)]
502
+
503
+ is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
504
+ tXpX = (
505
+ None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
506
+ )
507
+ copy = partial(copy_utils.copy, pred=tXpX)
508
+
509
+ row = tXcX[0][0]
510
+ if row < shape[0]:
511
+ copy(tXgX, tXsX, is_async=True)
512
+ cute.arch.cp_async_commit_group()
513
+ cute.arch.cp_async_wait_group(0)
514
+ if const_expr(not is_even_N):
515
+ utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
516
+ cute.autovec_copy(tXsX, tXrX)
517
+ x = tXrX.load().to(Float32)
518
+
519
+ target = Int32.zero
520
+ dloss = Float32.zero
521
+ lse = Float32.zero
522
+ if row < shape[0]:
523
+ target = Int32(mTarget[row])
524
+ should_ignore = Boolean(target == ignore_index)
525
+ # Set dloss to 0 if this index should be ignored
526
+ if not should_ignore:
527
+ dloss = Float32(mDLoss[row])
528
+ lse = Float32(mLSE[row])
529
+
530
+ log2_e = math.log2(math.e)
531
+ probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
532
+ prob_shifted = probs - 1.0
533
+ mask = cute.make_rmem_tensor_like(tXrX, Boolean)
534
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
535
+ mask[i] = tXcFull[i][1] == target
536
+ grad = cute.where(mask.load(), prob_shifted, probs)
537
+ grad = grad * dloss
538
+
539
+ tXrdX.store(grad.to(tXrdX.element_type))
540
+ if row < shape[0]:
541
+ copy(tXrdX, tXgdX)
542
+
543
+
544
+ @jit_cache
545
+ def _compile_cross_entropy_backward(dtype, target_dtype, N):
546
+ batch_sym = cute.sym_int()
547
+ div = math.gcd(128 // dtype.width, N)
548
+ x_cute, dx_cute = [fake_tensor(dtype, (batch_sym, N), div)] * 2
549
+ target_cute = fake_tensor(target_dtype, (batch_sym,))
550
+ dloss_cute, lse_cute = [fake_tensor(Float32, (batch_sym,))] * 2
551
+ cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
552
+ return cute.compile(
553
+ cross_entropy_backward_op,
554
+ x_cute,
555
+ target_cute,
556
+ dloss_cute,
557
+ dx_cute,
558
+ lse_cute,
559
+ Int32(0), # ignore_index, just for compilation
560
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
561
+ options="--enable-tvm-ffi",
562
+ )
563
+
564
+
565
+ def _cross_entropy_backward(
566
+ x: torch.Tensor,
567
+ target: torch.Tensor,
568
+ dloss: torch.Tensor,
569
+ lse: torch.Tensor,
570
+ dx: torch.Tensor,
571
+ ignore_index=-100,
572
+ ) -> None:
573
+ """Cross entropy backward pass.
574
+ Args:
575
+ x: Input logits tensor of shape (M, N)
576
+ target: Target class indices tensor of shape (M,)
577
+ dloss: Upstream gradients tensor of shape (M,)
578
+ lse: Log-sum-exp values tensor of shape (M,)
579
+ Returns:
580
+ Input gradients tensor of shape (M, N)
581
+ """
582
+ assert x.dim() == 2, "Input must be 2D"
583
+ assert target.dim() == 1, "Target must be 1D"
584
+ assert dloss.dim() == 1, "dloss must be 1D"
585
+ assert lse.dim() == 1, "lse must be 1D"
586
+ assert x.shape[0] == target.shape[0], "Batch dimensions must match"
587
+ assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
588
+ assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
589
+ assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
590
+ "Tensors must be on CUDA device"
591
+ )
592
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
593
+ assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
594
+ N = x.size(1)
595
+ dtype = torch2cute_dtype_map[x.dtype]
596
+ target_dtype = torch2cute_dtype_map[target.dtype]
597
+ _compile_cross_entropy_backward(dtype, target_dtype, N)(
598
+ x, target, dloss, dx, lse, Int32(ignore_index)
599
+ )
600
+
601
+
602
+ @torch.library.custom_op(add_quack_op_namespace_prefix("cross_entropy_bwd_out"), mutates_args={"dx"})
603
+ def cross_entropy_bwd_out(
604
+ x: torch.Tensor,
605
+ target: torch.Tensor,
606
+ dloss: torch.Tensor,
607
+ lse: torch.Tensor,
608
+ dx: torch.Tensor,
609
+ ignore_index: int = -100,
610
+ ) -> None:
611
+ _cross_entropy_backward(x, target, dloss, lse, dx, ignore_index)
612
+
613
+
614
+ @cross_entropy_bwd_out.register_fake
615
+ def _cross_entropy_bwd_out_fake(
616
+ x: torch.Tensor,
617
+ target: torch.Tensor,
618
+ dloss: torch.Tensor,
619
+ lse: torch.Tensor,
620
+ dx: torch.Tensor,
621
+ ignore_index: int = -100,
622
+ ) -> None:
623
+ # See softmax.py _softmax_fwd_fake for why register_fake is needed.
624
+ from .cache_utils import COMPILE_ONLY
625
+
626
+ if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
627
+ N = x.size(1)
628
+ dtype = torch2cute_dtype_map[x.dtype]
629
+ target_dtype = torch2cute_dtype_map[target.dtype]
630
+ _compile_cross_entropy_backward(dtype, target_dtype, N)
631
+
632
+
633
+ def cross_entropy_bwd(
634
+ x: torch.Tensor,
635
+ target: torch.Tensor,
636
+ dloss: torch.Tensor,
637
+ lse: torch.Tensor,
638
+ ignore_index: int = -100,
639
+ inplace_backward: bool = False,
640
+ ) -> None:
641
+ if inplace_backward and not torch.compiler.is_compiling():
642
+ dx = x
643
+ _cross_entropy_backward(
644
+ x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index
645
+ )
646
+ else:
647
+ dx = torch.empty_like(x)
648
+ cross_entropy_bwd_out(
649
+ x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
650
+ )
651
+ return dx
652
+
653
+
654
+ class CrossEntropyFunction(torch.autograd.Function):
655
+ @staticmethod
656
+ def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
657
+ if lse_partial is None:
658
+ loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
659
+ else:
660
+ # if we already compute partial lse, then to compute the final lse we treat
661
+ # @lse_partial as @x and @x as @target_logit
662
+ loss, lse = cross_entropy_fwd(
663
+ lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
664
+ )
665
+ ctx.save_for_backward(x, target, lse)
666
+ ctx.ignore_index = ignore_index
667
+ ctx.inplace_backward = inplace_backward
668
+ return loss
669
+
670
+ @staticmethod
671
+ def backward(ctx, dloss):
672
+ x, target, lse = ctx.saved_tensors
673
+ dx = cross_entropy_bwd(
674
+ x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
675
+ )
676
+ return dx, None, None, None, None
677
+
678
+
679
+ def cross_entropy(
680
+ x: torch.Tensor,
681
+ target: torch.Tensor,
682
+ lse_partial: Optional[torch.Tensor] = None,
683
+ ignore_index: int = -100,
684
+ reduction: Literal["none", "mean", "sum"] = "mean",
685
+ inplace_backward: bool = False,
686
+ ) -> torch.Tensor:
687
+ """Cross entropy loss with automatic differentiation support.
688
+
689
+ Args:
690
+ x: Input logits tensor of shape (M, N)
691
+ target: Target class indices tensor of shape (M,)
692
+ lse_partial: Optional precomputed log-sum-exp partial results
693
+ reduction: Specifies the reduction to apply to the output:
694
+ 'none': no reduction will be applied (default)
695
+ 'mean': the sum of the output will be divided by the number of elements
696
+ 'sum': the output will be summed
697
+ inplace_backward: Whether to perform backward pass in-place
698
+ ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)
699
+
700
+ Returns:
701
+ Cross entropy loss tensor:
702
+ - If reduction='none': tensor of shape (M,) with per-example losses
703
+ - If reduction='mean': scalar tensor with mean loss
704
+ - If reduction='sum': scalar tensor with sum of losses
705
+ """
706
+ loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
707
+ if reduction == "mean":
708
+ return loss.sum() / (target != ignore_index).sum().float()
709
+ elif reduction == "sum":
710
+ return loss.sum()
711
+ elif reduction == "none":
712
+ return loss
713
+ else:
714
+ raise ValueError(
715
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
716
+ )
build/torch-cuda/quack/cute_dsl_ptxas.py CHANGED
@@ -1,8 +1,16 @@
1
  """
2
  System ptxas replacement for CUTLASS DSL.
 
 
 
 
 
3
  Environment variables:
4
  CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
 
5
  CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
 
 
6
  """
7
 
8
  import os
@@ -16,29 +24,81 @@ import cutlass
16
 
17
 
18
  CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
 
 
 
19
  VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
20
 
21
  _original_load_cuda_library = None
 
22
  _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
23
 
24
 
25
- def _log(msg):
26
  if VERBOSE:
27
  print(f"[ptxas] {msg}", file=sys.stderr)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def _get_ptx(compiled_func) -> tuple[str, Path] | None:
31
- """Find and read PTX file, stripping null bytes."""
32
  func_name = getattr(compiled_func, "function_name", None)
33
  if not func_name:
 
34
  return None
35
 
36
- dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
37
- for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
38
- content = ptx_path.read_text().rstrip("\x00")
39
- if ".entry " in content and content.rstrip().endswith("}"):
40
- _log(f"Found PTX: {ptx_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return content, ptx_path
 
 
 
 
 
 
 
 
 
42
  return None
43
 
44
 
@@ -102,13 +162,15 @@ def _patched_load_cuda_library(self):
102
  _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
103
  return _original_load_cuda_library(self)
104
 
105
- # Register kernels on all devices
106
  _, cuda_load_to_device = self._get_cuda_init_and_load()
107
- lib_ptr = ctypes.c_void_p(int(library))
 
 
108
  dev_id = ctypes.c_int32(0)
109
  err_val = ctypes.c_int32(0)
110
  args = (ctypes.c_void_p * 3)(
111
- ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
112
  ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
113
  ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
114
  )
@@ -126,26 +188,50 @@ def _patched_load_cuda_library(self):
126
  if not _user_wanted_ptx:
127
  ptx_path.unlink(missing_ok=True)
128
 
129
- return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
 
 
 
 
 
 
 
 
 
 
130
 
131
 
132
  def patch():
133
  """Install system ptxas hook. Call before importing cutlass."""
134
- global _original_load_cuda_library, _user_wanted_ptx
135
 
136
  assert CUTE_DSL_PTXAS_PATH is not None
137
  if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
138
  raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
139
 
140
- # Track if user originally wanted PTX kept
141
  _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
142
- # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
143
  assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
144
  "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
145
  )
146
 
147
- cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
148
- _original_load_cuda_library = cls._load_cuda_library
149
- cls._load_cuda_library = _patched_load_cuda_library
150
- _log("Patch applied")
151
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  System ptxas replacement for CUTLASS DSL.
3
+
4
+ Usage::
5
+
6
+ CUTE_DSL_KEEP_PTX=1 CUTE_DSL_PTXAS_PATH=/usr/local/cuda/bin/ptxas pytest tests/
7
+
8
  Environment variables:
9
  CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
10
+ CUTE_DSL_KEEP_PTX - Must be set to 1 before cutlass is imported
11
  CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
12
+ CUTE_DSL_DUMP_DIR - Directory for dumped PTX files (default: cwd)
13
+ CUTE_DSL_KEEP_CUBIN - Set to 1 to save compiled cubin files
14
  """
15
 
16
  import os
 
24
 
25
 
26
  CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
27
+
28
+ if CUTE_DSL_PTXAS_PATH:
29
+ os.environ["CUTE_DSL_KEEP_PTX"] = "1"
30
  VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
31
 
32
  _original_load_cuda_library = None
33
+ _original_create_tvm_ffi_function = None
34
  _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
35
 
36
 
37
+ def _log(msg: str):
38
  if VERBOSE:
39
  print(f"[ptxas] {msg}", file=sys.stderr)
40
 
41
 
42
+ def _read_ptx(ptx_path: Path) -> str | None:
43
+ try:
44
+ return ptx_path.read_bytes().decode("utf-8", errors="ignore").rstrip("\x00")
45
+ except OSError as exc:
46
+ _log(f"Failed to read {ptx_path}: {exc}")
47
+ return None
48
+
49
+
50
+ def _read_complete_ptx(ptx_path: Path) -> str | None:
51
+ content = _read_ptx(ptx_path)
52
+ if content is None or not content.rstrip().endswith("}"):
53
+ return None
54
+ return content
55
+
56
+
57
  def _get_ptx(compiled_func) -> tuple[str, Path] | None:
58
+ """Find dumped PTX for the compiled function."""
59
  func_name = getattr(compiled_func, "function_name", None)
60
  if not func_name:
61
+ _log("Compiled function is missing function_name")
62
  return None
63
 
64
+ dump_dir = Path(os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()))
65
+ dump_dir.mkdir(parents=True, exist_ok=True)
66
+
67
+ ptx_paths = sorted(
68
+ dump_dir.rglob("*.ptx"), key=lambda path: path.stat().st_mtime_ns, reverse=True
69
+ )
70
+ _log(f"Searching dumped PTX for {func_name} in {dump_dir}")
71
+ _log(f"Found {len(ptx_paths)} PTX candidate files in {dump_dir}")
72
+
73
+ # Strategy 1: match by filename
74
+ filename_matches = [ptx_path for ptx_path in ptx_paths if func_name in ptx_path.name]
75
+ if filename_matches:
76
+ _log(f"Found {len(filename_matches)} filename matches for {func_name}")
77
+ for ptx_path in filename_matches:
78
+ content = _read_complete_ptx(ptx_path)
79
+ if content is None:
80
+ continue
81
+ _log(f"Using PTX filename match for {func_name}: {ptx_path}")
82
+ return content, ptx_path
83
+
84
+ # Strategy 2: match by .entry directive inside PTX
85
+ entry_pattern = re.compile(rf"\.entry\s+{re.escape(func_name)}(?:\s|\()", re.MULTILINE)
86
+ for ptx_path in ptx_paths:
87
+ content = _read_complete_ptx(ptx_path)
88
+ if content is None:
89
+ continue
90
+ if entry_pattern.search(content):
91
+ _log(f"Found PTX for {func_name}: {ptx_path}")
92
  return content, ptx_path
93
+
94
+ # Strategy 3: use sole candidate as fallback
95
+ if len(ptx_paths) == 1:
96
+ content = _read_complete_ptx(ptx_paths[0])
97
+ if content is not None:
98
+ _log(f"Using sole PTX candidate for {func_name}: {ptx_paths[0]}")
99
+ return content, ptx_paths[0]
100
+
101
+ _log(f"No PTX found for function {func_name} in {dump_dir}")
102
  return None
103
 
104
 
 
162
  _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
163
  return _original_load_cuda_library(self)
164
 
165
+ # Register kernels on all devices (must match cuda_load_to_device's void*** convention)
166
  _, cuda_load_to_device = self._get_cuda_init_and_load()
167
+ lib_handle = ctypes.c_void_p(int(library))
168
+ ptr_to_lib = ctypes.pointer(lib_handle)
169
+ ptr_to_ptr_to_lib = ctypes.pointer(ptr_to_lib)
170
  dev_id = ctypes.c_int32(0)
171
  err_val = ctypes.c_int32(0)
172
  args = (ctypes.c_void_p * 3)(
173
+ ctypes.cast(ptr_to_ptr_to_lib, ctypes.c_void_p),
174
  ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
175
  ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
176
  )
 
188
  if not _user_wanted_ptx:
189
  ptx_path.unlink(missing_ok=True)
190
 
191
+ return [cuda_runtime.cudaLibrary_t(lib_handle.value)]
192
+
193
+
194
+ def _patched_create_tvm_ffi_function(self):
195
+ # Ensure CUDA library is loaded before TVM FFI creation
196
+ if getattr(self, "_ptxas_cuda_library", None) is None:
197
+ self._ptxas_cuda_library = self._load_cuda_library()
198
+ _log(
199
+ f"Loaded {len(self._ptxas_cuda_library)} CUDA libraries before creating TVM FFI function"
200
+ )
201
+ return _original_create_tvm_ffi_function(self)
202
 
203
 
204
  def patch():
205
  """Install system ptxas hook. Call before importing cutlass."""
206
+ global _original_load_cuda_library, _original_create_tvm_ffi_function, _user_wanted_ptx
207
 
208
  assert CUTE_DSL_PTXAS_PATH is not None
209
  if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
210
  raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
211
 
 
212
  _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
 
213
  assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
214
  "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
215
  )
216
 
217
+ patched = False
218
+ cuda_jit_function_cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
219
+ if cuda_jit_function_cls._load_cuda_library is not _patched_load_cuda_library:
220
+ _original_load_cuda_library = cuda_jit_function_cls._load_cuda_library
221
+ cuda_jit_function_cls._load_cuda_library = _patched_load_cuda_library
222
+ patched = True
223
+
224
+ from cutlass.cutlass_dsl.tvm_ffi_provider import TVMFFIJitCompiledFunctionBase
225
+
226
+ if (
227
+ TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
228
+ is not _patched_create_tvm_ffi_function
229
+ ):
230
+ _original_create_tvm_ffi_function = TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
231
+ TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function = _patched_create_tvm_ffi_function
232
+ patched = True
233
+
234
+ if patched:
235
+ _log(f"Installed system ptxas patch with {CUTE_DSL_PTXAS_PATH}")
236
+ else:
237
+ _log("System ptxas patch already installed")
build/torch-cuda/quack/cute_dsl_utils.py CHANGED
@@ -1,9 +1,12 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
- from typing import Tuple
4
  from functools import lru_cache
5
  from dataclasses import dataclass, fields
6
 
 
 
 
7
  import torch
8
 
9
  try:
@@ -14,7 +17,7 @@ except ImportError:
14
  import cutlass
15
  import cutlass.cute as cute
16
  from cutlass import Int32, Int64, Float16, BFloat16, Float32
17
- from cutlass.base_dsl.typing import JitArgument
18
  from cutlass.cutlass_dsl import NumericMeta
19
 
20
 
@@ -25,6 +28,31 @@ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
25
  cute_compile_og = cute.compile
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  torch2cute_dtype_map = {
29
  torch.float16: Float16,
30
  torch.bfloat16: BFloat16,
@@ -39,66 +67,110 @@ def get_max_active_clusters(cluster_size):
39
  return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
40
 
41
 
 
 
 
 
 
 
 
 
 
42
  @lru_cache
43
- def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
 
 
 
 
 
 
 
 
44
  return torch.cuda.get_device_capability(device)
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @dataclass
48
  class ParamsBase:
49
  def __extract_mlir_values__(self):
50
- all_fields = [getattr(self, field.name) for field in fields(self)]
51
- non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
52
  values, self._values_pos = [], []
53
- for obj in non_constexpr_fields:
54
  obj_values = cutlass.extract_mlir_values(obj)
55
  values += obj_values
56
  self._values_pos.append(len(obj_values))
57
  return values
58
 
59
- def __new_from_mlir_values__(self, values):
60
- all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
61
- constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
62
- non_constexpr_fields = {
63
- n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
64
- }
65
- for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
66
- non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
67
- values = values[n_items:]
68
- return self.__class__(**non_constexpr_fields, **constexpr_fields)
69
-
70
-
71
- @dataclass
72
- class ArgumentsBase(JitArgument):
73
- def __c_pointers__(self):
74
- all_fields = [getattr(self, field.name) for field in fields(self)]
75
- non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
76
- c_ptrs = []
77
- for obj in non_constexpr_fields:
78
- if hasattr(obj, "__c_pointers__"):
79
- c_ptrs.extend(obj.__c_pointers__())
80
- return c_ptrs
81
-
82
- def __get_mlir_types__(self):
83
- all_fields = [getattr(self, field.name) for field in fields(self)]
84
- non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
85
- types, self._values_pos = [], []
86
- for obj in non_constexpr_fields:
87
- if hasattr(obj, "__get_mlir_types__"):
88
- obj_types = obj.__get_mlir_types__()
89
- types.extend(obj_types)
90
- self._values_pos.append(len(obj_types))
91
- else:
92
- self._values_pos.append(0)
93
- return types
94
-
95
- def __new_from_mlir_values__(self, values):
96
- all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
97
- constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
98
- non_constexpr_fields = {
99
- n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
100
- }
101
- for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
102
- non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
103
- values = values[n_items:]
104
- return self.__class__(**non_constexpr_fields, **constexpr_fields)
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
+ from typing import Tuple, get_origin
4
  from functools import lru_cache
5
  from dataclasses import dataclass, fields
6
 
7
+ import os
8
+ import re
9
+
10
  import torch
11
 
12
  try:
 
17
  import cutlass
18
  import cutlass.cute as cute
19
  from cutlass import Int32, Int64, Float16, BFloat16, Float32
20
+ from cutlass.base_dsl.tvm_ffi_builder import spec
21
  from cutlass.cutlass_dsl import NumericMeta
22
 
23
 
 
28
  cute_compile_og = cute.compile
29
 
30
 
31
+ # Patch TVM-FFI converter to handle Constexpr type annotations as compile-time constants.
32
+ # Fields annotated with cutlass.Constexpr[T] are emitted as ConstNone (not runtime args).
33
+ # At call time, pass None for these fields; the compile-time value is baked in.
34
+ import cutlass.cute._tvm_ffi_args_spec_converter as _converter_module # noqa
35
+
36
+ _original_convert_single_arg = _converter_module._convert_single_arg
37
+
38
+
39
+ def _patched_convert_single_arg(arg, arg_name, arg_type, ctx):
40
+ if arg_type is not None and get_origin(arg_type) is cutlass.Constexpr:
41
+ return spec.ConstNone(arg_name)
42
+ # If arg is a NamedTuple but arg_type doesn't have _fields (e.g. annotated as tuple),
43
+ # redirect so the converter uses the NamedTuple's own type hints.
44
+ if (
45
+ isinstance(arg, tuple)
46
+ and hasattr(type(arg), "_fields")
47
+ and (arg_type is None or not hasattr(arg_type, "_fields"))
48
+ ):
49
+ return _original_convert_single_arg(arg, arg_name, type(arg), ctx)
50
+ return _original_convert_single_arg(arg, arg_name, arg_type, ctx)
51
+
52
+
53
+ _converter_module._convert_single_arg = _patched_convert_single_arg
54
+
55
+
56
  torch2cute_dtype_map = {
57
  torch.float16: Float16,
58
  torch.bfloat16: BFloat16,
 
67
  return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
68
 
69
 
70
+ def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
71
+ """Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple."""
72
+ match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE)
73
+ if not match:
74
+ raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')")
75
+ major, minor, _ = match.groups()
76
+ return int(major), int(minor)
77
+
78
+
79
  @lru_cache
80
+ def _get_device_capacity_cached(device: torch.device = None) -> Tuple[int, int]:
81
+ """Return (major, minor) device capability.
82
+
83
+ Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
84
+ without a GPU present.
85
+ """
86
+ arch_override = os.environ.get("QUACK_ARCH")
87
+ if arch_override is not None:
88
+ return _parse_arch_str(arch_override)
89
  return torch.cuda.get_device_capability(device)
90
 
91
 
92
+ def get_device_capacity(
93
+ device: torch.device | torch.Tensor | None = None,
94
+ ) -> Tuple[int, int]:
95
+ """Return (major, minor) device capability.
96
+
97
+ Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
98
+ without a GPU present.
99
+
100
+ Accepts either a ``torch.device`` or a tensor and canonicalizes to the
101
+ underlying device before consulting the cached helper. This avoids leaking
102
+ tensors through the LRU cache key.
103
+ """
104
+ if isinstance(device, torch.Tensor):
105
+ device = device.device
106
+ return _get_device_capacity_cached(device)
107
+
108
+
109
+ def _partition_fields(obj):
110
+ """Split dataclass fields into (constexpr_dict, non_constexpr_dict) by type."""
111
+ all_fields = {field.name: getattr(obj, field.name) for field in fields(obj)}
112
+ constexpr = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
113
+ non_constexpr = {n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)}
114
+ return constexpr, non_constexpr
115
+
116
+
117
+ def _new_from_mlir_values(self, values):
118
+ constexpr_fields, non_constexpr_fields = _partition_fields(self)
119
+ for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
120
+ non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
121
+ values = values[n_items:]
122
+ return self.__class__(**non_constexpr_fields, **constexpr_fields)
123
+
124
+
125
+ def _namedtuple_new_from_mlir_values(self, values):
126
+ """Generic __new_from_mlir_values__ for NamedTuples.
127
+
128
+ Applied to NamedTuple classes via the ``@mlir_namedtuple`` decorator.
129
+
130
+ Fields that are None or Constexpr (StaticTypes) are preserved from ``self`` (the compile-time
131
+ template). Only non-static fields consume MLIR values. Multi-value fields (e.g. cute.Tensor)
132
+ consume the correct number of values via ``cutlass.new_from_mlir_values``.
133
+
134
+ Constexpr fields (annotated ``cutlass.Constexpr[T]``) are baked into the compiled kernel via
135
+ a converter patch (see above). At call time, pass None for these fields.
136
+ """
137
+ from cutlass.base_dsl.typing import get_mlir_types
138
+
139
+ values = list(values)
140
+ new_fields = []
141
+ for field_val in self:
142
+ if field_val is None or isinstance(field_val, StaticTypes):
143
+ new_fields.append(field_val)
144
+ else:
145
+ n_items = len(get_mlir_types(field_val))
146
+ new_fields.append(cutlass.new_from_mlir_values(field_val, values[:n_items]))
147
+ values = values[n_items:]
148
+ return self.__class__(*new_fields)
149
+
150
+
151
+ def mlir_namedtuple(cls):
152
+ """Decorator that adds MLIR value reconstruction to a NamedTuple class.
153
+
154
+ Usage::
155
+
156
+ @mlir_namedtuple
157
+ class MyArgs(NamedTuple):
158
+ tensor_arg: cute.Tensor
159
+ const_arg: cutlass.Constexpr[int] = 0
160
+ """
161
+ cls.__new_from_mlir_values__ = _namedtuple_new_from_mlir_values
162
+ return cls
163
+
164
+
165
  @dataclass
166
  class ParamsBase:
167
  def __extract_mlir_values__(self):
168
+ _, non_constexpr_fields = _partition_fields(self)
 
169
  values, self._values_pos = [], []
170
+ for obj in non_constexpr_fields.values():
171
  obj_values = cutlass.extract_mlir_values(obj)
172
  values += obj_values
173
  self._values_pos.append(len(obj_values))
174
  return values
175
 
176
+ __new_from_mlir_values__ = _new_from_mlir_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/quack/epi_composable.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ """ComposableEpiMixin: composes EpiOps into epilogue hook methods.
3
+
4
+ Subclasses declare _epi_ops as a tuple of EpiOp instances. The mixin auto-generates
5
+ epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors, epi_begin,
6
+ epi_begin_loop, epi_end, and EpilogueParams by querying each op.
7
+
8
+ epi_begin and epi_begin_loop return dicts keyed by op name, so epi_visit_subtile
9
+ can access values by name (e.g. epi_loop_tensors["alpha"]).
10
+
11
+ EpilogueParams is auto-generated from _epi_ops (via param_fields()) plus any
12
+ _extra_param_fields declared on the subclass. Subclasses still define
13
+ EpilogueArguments and epi_to_underlying_arguments manually.
14
+ """
15
+
16
+ from dataclasses import make_dataclass, MISSING
17
+
18
+ import cutlass.cute as cute
19
+ from cutlass import const_expr
20
+
21
+ from .epi_ops import EpiContext, Scalar
22
+
23
+
24
+ def _compute_smem_map(ops):
25
+ """Pre-compute name → smem tensor index for each non-Scalar op."""
26
+ smem_map = {}
27
+ idx = 0
28
+ for op in ops:
29
+ if not isinstance(op, Scalar):
30
+ smem_map[op.name] = idx
31
+ idx += 1
32
+ return smem_map
33
+
34
+
35
+ def _make_epi_params(epi_ops, extra_fields, bases):
36
+ """Build EpilogueParams dataclass from epi_ops + extra fields.
37
+
38
+ Required fields (default=MISSING) are placed first, then optional fields.
39
+ """
40
+ required, optional = [], []
41
+ for op in epi_ops:
42
+ for name, typ, default in op.param_fields():
43
+ (required if default is MISSING else optional).append((name, typ, default))
44
+ for name, typ, default in extra_fields:
45
+ (required if default is MISSING else optional).append((name, typ, default))
46
+ fields = [(n, t) for n, t, _ in required] + [(n, t, d) for n, t, d in optional]
47
+ return make_dataclass("EpilogueParams", fields, bases=bases)
48
+
49
+
50
+ class ComposableEpiMixin:
51
+ """Base mixin that composes EpiOps into the standard epilogue hooks."""
52
+
53
+ _epi_ops = ()
54
+ _extra_param_fields = () # [(name, type, default), ...] for non-op params (e.g. act_fn)
55
+ _epi_param_bases = () # Base classes for EpilogueParams (e.g. (ParamsBase,))
56
+ _epi_smem_map = {}
57
+ _epi_has_async_ops = False
58
+
59
+ def __init_subclass__(cls, **kwargs):
60
+ super().__init_subclass__(**kwargs)
61
+ if cls._epi_ops:
62
+ cls._epi_smem_map = _compute_smem_map(cls._epi_ops)
63
+ cls._epi_has_async_ops = any(op.needs_async_fence() for op in cls._epi_ops)
64
+ # Auto-generate EpilogueParams if not explicitly defined on this class
65
+ if "EpilogueParams" not in cls.__dict__:
66
+ cls.EpilogueParams = _make_epi_params(
67
+ cls._epi_ops, cls._extra_param_fields, cls._epi_param_bases
68
+ )
69
+
70
+ # --- Host-side: args → params ---
71
+
72
+ def _epi_ops_to_params_dict(self, args):
73
+ """Merge each op's to_params into a single dict. Subclasses call this,
74
+ add custom fields, then construct self.EpilogueParams(**d)."""
75
+ d = {}
76
+ for op in self._epi_ops:
77
+ d.update(op.to_params(self, args))
78
+ return d
79
+
80
+ # --- Host-side: smem allocation (queried from ops) ---
81
+
82
+ @classmethod
83
+ def epi_smem_bytes_per_stage(cls, args, cta_tile_shape_mnk, epi_tile):
84
+ return sum(
85
+ op.smem_bytes(getattr(args, op.name, None), cta_tile_shape_mnk, epi_tile)
86
+ for op in cls._epi_ops
87
+ )
88
+
89
+ def epi_get_smem_struct(self, params):
90
+ fields = {}
91
+ for op in self._epi_ops:
92
+ result = op.smem_struct_field(self, params)
93
+ if result is not None:
94
+ name, ftype = result
95
+ fields[name] = ftype
96
+ EpiSharedStorage = type("EpiSharedStorage", (), {"__annotations__": fields})
97
+ return cute.struct(EpiSharedStorage)
98
+
99
+ def epi_get_smem_tensors(self, params, storage):
100
+ return tuple(
101
+ op.get_smem_tensor(self, params, storage.epi)
102
+ for op in self._epi_ops
103
+ if not isinstance(op, Scalar)
104
+ )
105
+
106
+ def epi_get_tma_atoms(self, params, *, loc=None, ip=None):
107
+ atoms = []
108
+ for op in self._epi_ops:
109
+ atoms.extend(op.tma_atoms(self, params))
110
+ return atoms
111
+
112
+ # --- Device-side: kernel execution (delegates to ops) ---
113
+
114
+ @cute.jit
115
+ def epi_begin(
116
+ self,
117
+ params,
118
+ epi_smem_tensors,
119
+ epi_tile,
120
+ tiled_copy_t2r,
121
+ tiled_copy_r2s,
122
+ tile_coord_mnkl,
123
+ varlen_manager,
124
+ epilogue_barrier,
125
+ tidx,
126
+ ):
127
+ ctx = EpiContext(
128
+ self,
129
+ epi_tile,
130
+ tiled_copy_t2r,
131
+ tiled_copy_r2s,
132
+ tile_coord_mnkl,
133
+ varlen_manager,
134
+ epilogue_barrier,
135
+ tidx,
136
+ )
137
+ smem_map = self._epi_smem_map
138
+ results = {
139
+ op.name: op.begin(
140
+ self,
141
+ getattr(params, op.name, None),
142
+ epi_smem_tensors[smem_map[op.name]] if op.name in smem_map else None,
143
+ ctx,
144
+ )
145
+ for op in self._epi_ops
146
+ }
147
+ if const_expr(self._epi_has_async_ops):
148
+ has_async_data = any(
149
+ getattr(params, op.name, None) is not None
150
+ for op in self._epi_ops
151
+ if op.needs_async_fence()
152
+ )
153
+ if const_expr(has_async_data):
154
+ cute.arch.cp_async_commit_group()
155
+ cute.arch.cp_async_wait_group(0)
156
+ epilogue_barrier.arrive_and_wait()
157
+ return results
158
+
159
+ def epi_begin_loop(self, params, epi_tensors, epi_coord):
160
+ return {
161
+ op.name: op.begin_loop(self, epi_tensors[op.name], epi_coord) for op in self._epi_ops
162
+ }
163
+
164
+ @cute.jit
165
+ def epi_end(
166
+ self,
167
+ params,
168
+ epi_tensors,
169
+ epi_tile,
170
+ tiled_copy_t2r,
171
+ tiled_copy_r2s,
172
+ tile_coord_mnkl,
173
+ varlen_manager,
174
+ tidx,
175
+ ):
176
+ for op in self._epi_ops:
177
+ op.end(
178
+ self,
179
+ getattr(params, op.name, None),
180
+ epi_tensors[op.name],
181
+ epi_tile,
182
+ tiled_copy_t2r,
183
+ tiled_copy_r2s,
184
+ tile_coord_mnkl,
185
+ varlen_manager,
186
+ tidx,
187
+ )
build/torch-cuda/quack/epi_ops.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ """Composable epilogue operations (EpiOps) for GEMM kernels.
3
+
4
+ Each EpiOp encapsulates a single tensor kind's behavior across the epilogue lifecycle:
5
+ smem allocation, begin (one-time per-tile setup), begin_loop (per-subtile extraction),
6
+ end (cleanup).
7
+
8
+ The ops are composed via ComposableEpiMixin which iterates over a static _epi_ops tuple
9
+ to generate epi_smem_bytes_per_stage, epi_get_smem_struct, epi_get_smem_tensors,
10
+ epi_begin, and epi_begin_loop automatically.
11
+ """
12
+
13
+ import math
14
+ import operator
15
+ from functools import partial
16
+
17
+ import cutlass
18
+ import cutlass.cute as cute
19
+ from cutlass import Boolean, Float32, const_expr
20
+
21
+ from .epi_utils import assume_stride_divisibility, setup_epi_tensor
22
+ from .sm90_utils import partition_for_epilogue
23
+ from . import utils as utils
24
+ from . import copy_utils as copy_utils
25
+ from . import layout_utils as layout_utils
26
+
27
+
28
+ class EpiContext:
29
+ """Shared context passed to EpiOp.begin methods. Bundles common arguments."""
30
+
31
+ __slots__ = (
32
+ "epi_tile",
33
+ "tiled_copy_t2r",
34
+ "tiled_copy_r2s",
35
+ "tile_coord_mnkl",
36
+ "varlen_manager",
37
+ "epilogue_barrier",
38
+ "tidx",
39
+ "partition_for_epilogue_fn",
40
+ "num_epi_threads",
41
+ "batch_idx",
42
+ "tile_M",
43
+ "tile_N",
44
+ )
45
+
46
+ def __init__(
47
+ self,
48
+ gemm,
49
+ epi_tile,
50
+ tiled_copy_t2r,
51
+ tiled_copy_r2s,
52
+ tile_coord_mnkl,
53
+ varlen_manager,
54
+ epilogue_barrier,
55
+ tidx,
56
+ ):
57
+ self.epi_tile = epi_tile
58
+ self.tiled_copy_t2r = tiled_copy_t2r
59
+ self.tiled_copy_r2s = tiled_copy_r2s
60
+ self.tile_coord_mnkl = tile_coord_mnkl
61
+ self.varlen_manager = varlen_manager
62
+ self.epilogue_barrier = epilogue_barrier
63
+ self.tidx = tidx
64
+ self.tile_M = gemm.cta_tile_shape_mnk[0]
65
+ self.tile_N = gemm.cta_tile_shape_mnk[1]
66
+ self.batch_idx = tile_coord_mnkl[3]
67
+ self.num_epi_threads = gemm.num_epi_warps * cute.arch.WARP_SIZE
68
+ self.partition_for_epilogue_fn = partial(
69
+ partition_for_epilogue,
70
+ epi_tile=epi_tile,
71
+ tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
72
+ tidx=tidx,
73
+ reference_src=tiled_copy_t2r is None,
74
+ )
75
+
76
+
77
+ def _get_lane_warp_layouts(tiled_copy, reference_src=True):
78
+ """Derive lane and warp layouts along M and N from the epilogue tiled_copy.
79
+
80
+ Follows the CUTLASS Sm90RowReduction / Sm90ColReduction pattern.
81
+ Uses layout_src_tv_tiled (SM90, reference_src=True) or
82
+ layout_dst_tv_tiled (SM100, reference_src=False), matching the C++ impl's
83
+ get_layoutS_TV / get_layoutD_TV selection.
84
+
85
+ Returns (lane_layout_MN, warp_layout_MN) where each is a 2D layout (M, N):
86
+ lane_layout_MN[0] = lane_M: (lanes_in_M):(lane_stride_M) — e.g. 8:4
87
+ lane_layout_MN[1] = lane_N: (lanes_in_N):(lane_stride_N) — e.g. 4:1
88
+ warp_layout_MN[0] = warp_M: (warps_in_M):(warp_stride_M) — e.g. 4:1
89
+ warp_layout_MN[1] = warp_N: (warps_in_N):(warp_stride_N) — e.g. 1:0
90
+
91
+ For RowVecReduce (reduce along M): shuffle across lane_M, smem reduce across warp_M.
92
+ For ColVecReduce (reduce along N): shuffle across lane_N, direct write (warps_in_N == 1).
93
+ """
94
+ # right_inverse of the TV layout gives tile_element_idx -> tv_idx.
95
+ # SM90: use src (register) layout; SM100: use dst (smem) layout.
96
+ layout_tv = tiled_copy.layout_src_tv_tiled if reference_src else tiled_copy.layout_dst_tv_tiled
97
+ ref_layout = cute.right_inverse(layout_tv)
98
+ tile_M_size, tile_N_size = cute.size(tiled_copy.tiler_mn[0]), cute.size(tiled_copy.tiler_mn[1])
99
+ ref_layout_MN = cute.composition(
100
+ ref_layout, cute.make_layout((tile_M_size, tile_N_size))
101
+ ) # (tile_M, tile_N) -> tv_idx
102
+
103
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
104
+
105
+ # tv2lane: tv_idx -> lane_idx (lane = tv_idx % 32)
106
+ tv2lane = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(1, 0, 0))
107
+ ref2lane = cute.composition(tv2lane, ref_layout_MN) # (tile_M, tile_N) -> lane_idx
108
+ # select mode [0] = M part, [1] = N part; filter removes stride-0
109
+ lane_M = cute.filter(cute.select(ref2lane, [0])) # lane_m -> lane_idx
110
+ lane_N = cute.filter(cute.select(ref2lane, [1])) # lane_n -> lane_idx
111
+ lane_layout_MN = layout_utils.concat_layout(lane_M, lane_N) # (lane_M, lane_N) -> lane_idx
112
+
113
+ # tv2warp: tv_idx -> warp_idx (warp = tv_idx / 32)
114
+ tv2warp = cute.make_layout((cute.arch.WARP_SIZE, num_warps, 1), stride=(0, 1, 0))
115
+ ref2warp = cute.composition(tv2warp, ref_layout_MN) # (tile_M, tile_N) -> warp_idx
116
+ warp_M = cute.filter(cute.select(ref2warp, [0])) # warp_m -> warp_idx
117
+ warp_N = cute.filter(cute.select(ref2warp, [1])) # warp_n -> warp_idx
118
+ warp_layout_MN = layout_utils.concat_layout(warp_M, warp_N) # (warp_M, warp_N) -> warp_idx
119
+
120
+ return lane_layout_MN, warp_layout_MN
121
+
122
+
123
+ class EpiOp:
124
+ """Base class for composable epilogue operations."""
125
+
126
+ def __init__(self, name):
127
+ self.name = name
128
+
129
+ # --- Host-side: args → params ---
130
+ def param_fields(self):
131
+ """Return [(field_name, type, default), ...] for auto-generating EpilogueParams.
132
+ Must match the keys returned by to_params()."""
133
+ return []
134
+
135
+ def to_params(self, gemm, args):
136
+ """Convert this op's arg field(s) to param dict entries.
137
+ Returns dict of {param_name: value}. Like EVT's to_underlying_arguments."""
138
+ return {}
139
+
140
+ # --- Host-side: smem allocation ---
141
+ def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
142
+ """Bytes of smem needed per stage. arg_tensor is the EpilogueArguments field."""
143
+ return 0
144
+
145
+ def smem_struct_field(self, gemm, params):
146
+ """Return (field_name, field_type) for @cute.struct, or None if no smem needed.
147
+ params is the full EpilogueParams object."""
148
+ return None
149
+
150
+ def get_smem_tensor(self, gemm, params, storage_epi):
151
+ """Extract smem tensor from storage.epi. Returns tensor or None.
152
+ params is the full EpilogueParams object."""
153
+ return None
154
+
155
+ def tma_atoms(self, gemm, params):
156
+ """Return list of TMA atoms for this op."""
157
+ return []
158
+
159
+ # --- Device-side: kernel execution ---
160
+ @cute.jit
161
+ def begin(self, gemm, param, smem_tensor, ctx):
162
+ """One-time per-tile setup. Returns state for begin_loop."""
163
+ return None
164
+
165
+ def begin_loop(self, gemm, state, epi_coord):
166
+ """Per-subtile extraction. Returns value for epi_visit_subtile."""
167
+ return state
168
+
169
+ def needs_async_fence(self):
170
+ """Whether this op issues async copies that need a fence."""
171
+ return False
172
+
173
+ def end(
174
+ self,
175
+ gemm,
176
+ param,
177
+ state,
178
+ epi_tile,
179
+ tiled_copy_t2r,
180
+ tiled_copy_r2s,
181
+ tile_coord_mnkl,
182
+ varlen_manager,
183
+ tidx,
184
+ ):
185
+ """Cleanup after all subtiles (reductions, direct writes)."""
186
+ pass
187
+
188
+
189
+ class Scalar(EpiOp):
190
+ """Loads a scalar value or device pointer once per tile. No smem."""
191
+
192
+ def __init__(self, name, dtype=None):
193
+ super().__init__(name)
194
+ self.dtype = dtype
195
+
196
+ def param_fields(self):
197
+ return [(self.name, object, None)]
198
+
199
+ def to_params(self, gemm, args):
200
+ return {self.name: getattr(args, self.name)}
201
+
202
+ @cute.jit
203
+ def begin(self, gemm, param, smem_tensor, ctx):
204
+ result = None
205
+ if const_expr(param is not None):
206
+ result = (
207
+ utils.load_scalar_or_pointer(param, dtype=self.dtype)
208
+ if const_expr(self.dtype is not None)
209
+ else utils.load_scalar_or_pointer(param)
210
+ )
211
+ return result
212
+
213
+
214
+ class VecLoad(EpiOp):
215
+ """Base class for broadcast vector loads (row or col) via cp_async.
216
+
217
+ Subclasses set `dim` to 0 (M/col) or 1 (N/row) and override `_get_gmem_vec`
218
+ for varlen handling.
219
+ """
220
+
221
+ dim = None # 0 for col (M), 1 for row (N)
222
+
223
+ def param_fields(self):
224
+ return [(self.name, object, None)]
225
+
226
+ def to_params(self, gemm, args):
227
+ return {self.name: assume_stride_divisibility(getattr(args, self.name))}
228
+
229
+ def _tile_size(self, cta_tile_shape_mnk):
230
+ return cta_tile_shape_mnk[self.dim]
231
+
232
+ def _broadcast_stride(self):
233
+ # Row: stride (0,1) — broadcast along M. Col: stride (1,0) — broadcast along N.
234
+ return (0, 1) if self.dim == 1 else (1, 0)
235
+
236
+ def _tile_dim(self, ctx):
237
+ return ctx.tile_N if self.dim == 1 else ctx.tile_M
238
+
239
+ def _coord_idx(self):
240
+ return 1 if self.dim == 1 else 0
241
+
242
+ def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
243
+ if arg_tensor is None:
244
+ return 0
245
+ return self._tile_size(cta_tile_shape_mnk) * (arg_tensor.element_type.width // 8)
246
+
247
+ def smem_struct_field(self, gemm, params):
248
+ tensor = getattr(params, self.name, None)
249
+ if tensor is None:
250
+ size, dtype = 0, Float32
251
+ else:
252
+ size = self._tile_size(gemm.cta_tile_shape_mnk)
253
+ dtype = tensor.element_type
254
+ return (f"s_{self.name}", cute.struct.Align[cute.struct.MemRange[dtype, size], 16])
255
+
256
+ def get_smem_tensor(self, gemm, params, storage_epi):
257
+ if getattr(params, self.name, None) is None:
258
+ return None
259
+ return getattr(storage_epi, f"s_{self.name}").get_tensor(
260
+ cute.make_layout(self._tile_size(gemm.cta_tile_shape_mnk))
261
+ )
262
+
263
+ def needs_async_fence(self):
264
+ return True
265
+
266
+ def _get_gmem_vec(self, param, ctx):
267
+ """Get the global memory vector for this tile. Override for varlen."""
268
+ return param[ctx.batch_idx, None]
269
+
270
+ @cute.jit
271
+ def begin(self, gemm, param, smem_tensor, ctx):
272
+ tDsV = None
273
+ if const_expr(param is not None):
274
+ dtype = param.element_type
275
+ num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width
276
+ thr_copy = copy_utils.tiled_copy_1d(
277
+ dtype, ctx.num_epi_threads, num_copy_elems, is_async=True
278
+ ).get_slice(ctx.tidx)
279
+ mVec = self._get_gmem_vec(param, ctx)
280
+ tile_dim = self._tile_dim(ctx)
281
+ coord_idx = ctx.tile_coord_mnkl[self._coord_idx()]
282
+ gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,))
283
+ tVgV = thr_copy.partition_S(gVec)
284
+ tVsV = thr_copy.partition_D(smem_tensor)
285
+ tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim))
286
+ limit = min(cute.size(mVec, mode=[0]) - coord_idx * tile_dim, tile_dim)
287
+ pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean)
288
+ for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True):
289
+ pred[0, m] = tVcV[0, m] < limit
290
+ cute.copy(thr_copy, tVgV, tVsV, pred=pred)
291
+ tDsV = ctx.partition_for_epilogue_fn(
292
+ cute.make_tensor(
293
+ smem_tensor.iterator,
294
+ cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()),
295
+ )
296
+ )
297
+ if const_expr(ctx.tiled_copy_t2r is not None):
298
+ tDsV = ctx.tiled_copy_r2s.retile(tDsV)
299
+ return tDsV
300
+
301
+ @cute.jit
302
+ def begin_loop(self, gemm, state, epi_coord):
303
+ tDrV_cvt = None
304
+ if const_expr(state is not None):
305
+ tDsV_cur = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord]
306
+ tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type)
307
+ cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV))
308
+ tDrV_cvt = cute.make_rmem_tensor_like(tDrV, gemm.acc_dtype)
309
+ tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype))
310
+ return tDrV_cvt
311
+
312
+
313
+ class RowVecLoad(VecLoad):
314
+ """Loads a row vector (N,) via cp_async, broadcasts along M with stride (0,1)."""
315
+
316
+ dim = 1
317
+
318
+
319
+ class ColVecLoad(VecLoad):
320
+ """Loads a col vector (M,) via cp_async, broadcasts along N with stride (1,0).
321
+
322
+ Optimization: with N-major subtile loop, consecutive epi_n iterations for the same
323
+ epi_m share the same column data. The smem→register copy only runs when epi_n == 0.
324
+ Supports varlen_m via domain_offset.
325
+ """
326
+
327
+ dim = 0
328
+
329
+ @cute.jit
330
+ def _get_gmem_vec(self, param, ctx):
331
+ if const_expr(not ctx.varlen_manager.varlen_m):
332
+ mVec = param[ctx.batch_idx, None]
333
+ else:
334
+ mVec = cute.domain_offset(
335
+ (ctx.varlen_manager.params.cu_seqlens_m[ctx.batch_idx],), param
336
+ )
337
+ return mVec
338
+
339
+ @cute.jit
340
+ def begin(self, gemm, param, smem_tensor, ctx):
341
+ tDsV = None
342
+ tDrV_cvt = None
343
+ if const_expr(param is not None):
344
+ dtype = param.element_type
345
+ num_copy_elems = const_expr(max(32, dtype.width)) // dtype.width
346
+ thr_copy = copy_utils.tiled_copy_1d(
347
+ dtype, ctx.num_epi_threads, num_copy_elems, is_async=True
348
+ ).get_slice(ctx.tidx)
349
+ mVec = self._get_gmem_vec(param, ctx)
350
+ tile_dim = self._tile_dim(ctx)
351
+ coord_idx = ctx.tile_coord_mnkl[self._coord_idx()]
352
+ gVec = cute.local_tile(mVec, (tile_dim,), (coord_idx,))
353
+ tVgV = thr_copy.partition_S(gVec)
354
+ tVsV = thr_copy.partition_D(smem_tensor)
355
+ tVcV = thr_copy.partition_S(cute.make_identity_tensor(tile_dim))
356
+ # ColVec uses varlen-aware limit
357
+ limit = min(
358
+ ctx.varlen_manager.len_m(ctx.batch_idx) - coord_idx * tile_dim,
359
+ tile_dim,
360
+ )
361
+ pred = cute.make_rmem_tensor((1, cute.size(tVsV.shape[1])), Boolean)
362
+ for m in cutlass.range(cute.size(tVsV.shape[1]), unroll_full=True):
363
+ pred[0, m] = tVcV[0, m] < limit
364
+ cute.copy(thr_copy, tVgV, tVsV, pred=pred)
365
+ tDsV = ctx.partition_for_epilogue_fn(
366
+ cute.make_tensor(
367
+ smem_tensor.iterator,
368
+ cute.make_layout((ctx.tile_M, ctx.tile_N), stride=self._broadcast_stride()),
369
+ )
370
+ )
371
+ if const_expr(ctx.tiled_copy_t2r is not None):
372
+ tDsV = ctx.tiled_copy_r2s.retile(tDsV)
373
+ # Pre-allocate register tensor reused across begin_loop calls
374
+ tDsV_sub = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, 0]
375
+ tDrV_cvt = cute.make_rmem_tensor(tDsV_sub.layout, gemm.acc_dtype)
376
+ return [tDsV, tDrV_cvt]
377
+
378
+ @cute.jit
379
+ def begin_loop(self, gemm, state, epi_coord):
380
+ tDsV, tDrV_cvt = state[0], state[1]
381
+ if const_expr(tDsV is not None):
382
+ # Col vector is constant across N subtiles — only copy on first N subtile.
383
+ # Assumes N-major epi subtile order: epi_tile_layout = ordered_layout(..., order=(1,0))
384
+ epi_n = epi_coord[1]
385
+ if epi_n == 0:
386
+ tDsV_cur = cute.group_modes(tDsV, 3, cute.rank(tDsV))[None, None, None, epi_coord]
387
+ tDrV = cute.make_rmem_tensor(tDsV_cur.layout, tDsV_cur.element_type)
388
+ cute.autovec_copy(cute.filter_zeros(tDsV_cur), cute.filter_zeros(tDrV))
389
+ tDrV_cvt.store(tDrV.load().to(gemm.acc_dtype))
390
+ return tDrV_cvt
391
+
392
+
393
+ class TileStore(EpiOp):
394
+ """Tile-sized output tensor stored via TMA (e.g. postact).
395
+
396
+ Args:
397
+ name: field name in EpilogueArguments/Params (e.g. "mPostAct")
398
+ epi_tile_fn: optional (gemm, epi_tile) -> epi_tile for half-tile (GemmGated)
399
+ """
400
+
401
+ def __init__(self, name, epi_tile_fn=None):
402
+ super().__init__(name)
403
+ self.epi_tile_fn = epi_tile_fn
404
+
405
+ def _tma_atom_key(self):
406
+ return f"tma_atom_{self.name}"
407
+
408
+ def _smem_layout_key(self):
409
+ return f"epi_{self.name}_smem_layout_staged"
410
+
411
+ def _epi_tile_key(self):
412
+ return f"epi_tile_{self.name}"
413
+
414
+ def param_fields(self):
415
+ from dataclasses import MISSING
416
+
417
+ return [
418
+ (self._tma_atom_key(), object, MISSING),
419
+ (self.name, object, MISSING),
420
+ (self._smem_layout_key(), object, MISSING),
421
+ (self._epi_tile_key(), object, MISSING),
422
+ ]
423
+
424
+ def to_params(self, gemm, args):
425
+ tensor = getattr(args, self.name)
426
+ epi_tile = self.epi_tile_fn(gemm, gemm.epi_tile) if self.epi_tile_fn else None
427
+ tma_atom, tma_tensor, smem_layout, epi_tile_out = setup_epi_tensor(
428
+ gemm, tensor, epi_tile=epi_tile
429
+ )
430
+ return {
431
+ self._tma_atom_key(): tma_atom,
432
+ self.name: tma_tensor,
433
+ self._smem_layout_key(): smem_layout,
434
+ self._epi_tile_key(): epi_tile_out,
435
+ }
436
+
437
+ def smem_bytes(self, arg_tensor, cta_tile_shape_mnk, epi_tile):
438
+ if arg_tensor is None:
439
+ return 0
440
+ if self.epi_tile_fn is not None:
441
+ epi_tile = self.epi_tile_fn(None, epi_tile)
442
+ return cute.size(cute.shape(epi_tile)) * (arg_tensor.element_type.width // 8)
443
+
444
+ def smem_struct_field(self, gemm, params):
445
+ smem_layout_key = self._smem_layout_key()
446
+ if not hasattr(params, smem_layout_key):
447
+ return (f"s_{self.name}", cute.struct.MemRange[Float32, 0])
448
+ return (
449
+ f"s_{self.name}",
450
+ cute.struct.Align[
451
+ cute.struct.MemRange[
452
+ gemm.postact_dtype,
453
+ cute.cosize(getattr(params, smem_layout_key)),
454
+ ],
455
+ gemm.buffer_align_bytes,
456
+ ],
457
+ )
458
+
459
+ def get_smem_tensor(self, gemm, params, storage_epi):
460
+ smem_layout_key = self._smem_layout_key()
461
+ if not hasattr(params, smem_layout_key):
462
+ return None
463
+ smem_layout = getattr(params, smem_layout_key)
464
+ return getattr(storage_epi, f"s_{self.name}").get_tensor(
465
+ smem_layout.outer,
466
+ swizzle=smem_layout.inner,
467
+ )
468
+
469
+ def tma_atoms(self, gemm, params):
470
+ tma_key = self._tma_atom_key()
471
+ if hasattr(params, tma_key):
472
+ return [getattr(params, tma_key)]
473
+ return []
474
+
475
+
476
+ @cute.jit
477
+ def vec_multiply(gemm, tRS_rD, tDrColVec, tDrRowVec):
478
+ """Multiply tRS_rD by colvec and/or rowvec in-place. Uses packed f32x2 on SM100+."""
479
+ if const_expr(tDrColVec is not None):
480
+ if const_expr(gemm.arch < 100):
481
+ for i in cutlass.range(cute.size(tDrColVec), unroll_full=True):
482
+ tRS_rD[i] *= tDrColVec[i]
483
+ else:
484
+ for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True):
485
+ tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2(
486
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
487
+ (tDrColVec[2 * i], tDrColVec[2 * i + 1]),
488
+ )
489
+ if const_expr(tDrRowVec is not None):
490
+ if const_expr(gemm.arch < 100):
491
+ for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True):
492
+ tRS_rD[i] *= tDrRowVec[i]
493
+ else:
494
+ for i in cutlass.range(cute.size(tRS_rD) // 2, unroll_full=True):
495
+ tRS_rD[2 * i], tRS_rD[2 * i + 1] = cute.arch.mul_packed_f32x2(
496
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
497
+ (tDrRowVec[2 * i], tDrRowVec[2 * i + 1]),
498
+ )
499
+
500
+
501
+ @cute.jit
502
+ def colvec_reduce_accumulate(gemm, tDrReduce, tRS_rInput, transform_fn=None, rScale=None):
503
+ """Accumulate transform_fn(input) or input * rScale into a ColVecReduce buffer.
504
+
505
+ If transform_fn is provided, accumulates transform_fn(input[i]).
506
+ If rScale is provided, accumulates input[i] * rScale[i] (uses mul/fma for SM100).
507
+ If neither, accumulates input directly (identity).
508
+ """
509
+ if const_expr(tDrReduce is not None):
510
+ if const_expr(transform_fn is None):
511
+ transform_fn = lambda x: x
512
+ if const_expr(gemm.arch < 100):
513
+ for i in cutlass.range(cute.size(tDrReduce), unroll_full=True):
514
+ val = transform_fn(tRS_rInput[i])
515
+ tDrReduce[i] += val * rScale[i] if const_expr(rScale is not None) else val
516
+ else:
517
+ tDrReduce_mn = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout)
518
+ tRS_rInput_mn = layout_utils.convert_layout_zero_stride(tRS_rInput, tDrReduce.layout)
519
+ if const_expr(rScale is not None):
520
+ rScale_mn = layout_utils.convert_layout_zero_stride(rScale, tDrReduce.layout)
521
+ for m in cutlass.range(cute.size(tDrReduce_mn, mode=[0]), unroll_full=True):
522
+ inp = lambda n: (tRS_rInput_mn[m, 2 * n], tRS_rInput_mn[m, 2 * n + 1])
523
+ val0 = transform_fn(inp(0))
524
+ if const_expr(rScale is not None):
525
+ row_sum = cute.arch.mul_packed_f32x2(val0, (rScale_mn[m, 0], rScale_mn[m, 1]))
526
+ else:
527
+ row_sum = val0
528
+ for n in cutlass.range(1, cute.size(tDrReduce_mn, mode=[1]) // 2, unroll_full=True):
529
+ val = transform_fn(inp(n))
530
+ if const_expr(rScale is not None):
531
+ row_sum = cute.arch.fma_packed_f32x2(
532
+ val, (rScale_mn[m, 2 * n], rScale_mn[m, 2 * n + 1]), row_sum
533
+ )
534
+ else:
535
+ row_sum = cute.arch.add_packed_f32x2(val, row_sum)
536
+ tDrReduce_mn[m, 0] += row_sum[0] + row_sum[1]
537
+
538
+
539
+ class ColVecReduce(EpiOp):
540
+ """Column vector reduction: accumulates across N subtiles in registers,
541
+ then warp-reduces and writes to gmem in epi_end.
542
+
543
+ No smem. The accumulation itself happens in epi_visit_subtile (user code).
544
+ This op handles the register allocation (begin), per-subtile slicing (begin_loop),
545
+ and final warp reduction + gmem write (end).
546
+ """
547
+
548
+ def param_fields(self):
549
+ return [(self.name, object, None)]
550
+
551
+ def to_params(self, gemm, args):
552
+ return {self.name: assume_stride_divisibility(getattr(args, self.name))}
553
+
554
+ @cute.jit
555
+ def begin(self, gemm, param, smem_tensor, ctx):
556
+ tDrReduce = None
557
+ if const_expr(param is not None):
558
+ colvec_mma_layout = cute.make_layout((ctx.tile_M, ctx.tile_N), stride=(1, 0))
559
+ tDrReduce_layout = ctx.partition_for_epilogue_fn(
560
+ cute.make_rmem_tensor(colvec_mma_layout, Float32)
561
+ ).layout
562
+ tDrReduce = cute.make_rmem_tensor(tDrReduce_layout, Float32)
563
+ cute.filter_zeros(tDrReduce).fill(0.0)
564
+ return tDrReduce
565
+
566
+ @cute.jit
567
+ def begin_loop(self, gemm, state, epi_coord):
568
+ result = None
569
+ if const_expr(state is not None):
570
+ result = cute.group_modes(state, 3, cute.rank(state))[None, None, None, epi_coord]
571
+ return result
572
+
573
+ @cute.jit
574
+ def end(
575
+ self,
576
+ gemm,
577
+ param,
578
+ state,
579
+ epi_tile,
580
+ tiled_copy_t2r,
581
+ tiled_copy_r2s,
582
+ tile_coord_mnkl,
583
+ varlen_manager,
584
+ tidx,
585
+ ):
586
+ """Intra-warp shuffle reduction across N lanes, then direct gmem write."""
587
+ if const_expr(param is not None):
588
+ tDrReduce = state
589
+ tiled_copy = tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s
590
+ reference_src = tiled_copy_t2r is None
591
+
592
+ # ── Derive lane layout from tiled_copy ──
593
+ lane_layout_MN, warp_layout_MN = _get_lane_warp_layouts(tiled_copy, reference_src)
594
+ # For ColVecReduce: reduce across N lanes (lanes_in_N threads share same M row)
595
+ lanes_in_N = cute.size(lane_layout_MN, mode=[1])
596
+ # Typically lanes_in_N is 4 for Sm90
597
+ assert lanes_in_N == 1 << int(math.log2(lanes_in_N)), (
598
+ "lanes_in_N must be a power of 2 for butterfly reduction"
599
+ )
600
+
601
+ # ── Intra-warp shuffle reduction across N lanes ──
602
+ if const_expr(lanes_in_N > 1):
603
+ assert lane_layout_MN.stride[1] == 1
604
+ tDrReduce_flt = cute.filter_zeros(tDrReduce)
605
+ for i in cutlass.range(cute.size(tDrReduce_flt), unroll_full=True):
606
+ tDrReduce_flt[i] = cute.arch.warp_reduction(
607
+ tDrReduce_flt[i], operator.add, threads_in_group=lanes_in_N
608
+ )
609
+
610
+ warp_N = warp_layout_MN[1]
611
+ assert cute.size(warp_N) == 1, (
612
+ "ColVecReduce assumes all reduction cols are within the same warp"
613
+ )
614
+
615
+ # ── Direct gmem write (no inter-warp reduction needed: warps_in_N == 1) ──
616
+ partition_for_epilogue_fn = partial(
617
+ partition_for_epilogue,
618
+ epi_tile=epi_tile,
619
+ tiled_copy=tiled_copy,
620
+ tidx=tidx,
621
+ reference_src=tiled_copy_t2r is None,
622
+ )
623
+ tile_M, tile_N = gemm.cta_tile_shape_mnk[:2]
624
+ batch_idx = tile_coord_mnkl[3]
625
+ limit_n = param.shape[2] if not varlen_manager.varlen_m else param.shape[1]
626
+ if tile_coord_mnkl[1] < limit_n:
627
+ if const_expr(not varlen_manager.varlen_m):
628
+ mColVec = param[batch_idx, None, tile_coord_mnkl[1]]
629
+ else:
630
+ mColVec = cute.domain_offset(
631
+ (varlen_manager.params.cu_seqlens_m[batch_idx],),
632
+ param[None, tile_coord_mnkl[1]],
633
+ )
634
+ gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
635
+ limit_m = min(
636
+ varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M,
637
+ tile_M,
638
+ )
639
+ tDcD = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N)))
640
+ tDrReduce_m = layout_utils.convert_layout_zero_stride(tDrReduce, tDrReduce.layout)[
641
+ None, 0
642
+ ]
643
+ tDcD_m = layout_utils.convert_layout_zero_stride(tDcD, tDrReduce.layout)[None, 0]
644
+ if tDcD_m[0][1] == 0:
645
+ for m in cutlass.range(cute.size(tDcD_m, mode=[0])):
646
+ row_idx = tDcD_m[m][0]
647
+ if row_idx < limit_m:
648
+ gColVec[row_idx] = tDrReduce_m[m]
build/torch-cuda/quack/epi_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ """Epilogue utilities: shared helpers for epilogue mixin classes."""
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ import cutlass.utils.blackwell_helpers as sm100_utils
7
+
8
+ from . import sm90_utils as sm90_utils
9
+ from . import copy_utils as copy_utils
10
+
11
+
12
+ def assume_stride_divisibility(tensor):
13
+ """Assume all strides are divisible by 32 bits (except static strides).
14
+
15
+ Used for broadcast vectors and similar tensors where stride alignment is guaranteed.
16
+ Returns a new tensor with the assumed strides.
17
+ """
18
+ if tensor is None:
19
+ return None
20
+ new_stride = tuple(
21
+ cute.assume(s, divby=32 // tensor.element_type.width) if not cute.is_static(s) else s
22
+ for s in tensor.stride
23
+ )
24
+ return cute.make_tensor(tensor.iterator, cute.make_layout(tensor.shape, stride=new_stride))
25
+
26
+
27
+ def assume_broadcast_strides(*tensors):
28
+ """Apply stride divisibility assumptions to multiple broadcast vectors.
29
+
30
+ Returns a list with None preserved for None inputs.
31
+ """
32
+ return [assume_stride_divisibility(t) for t in tensors]
33
+
34
+
35
+ def setup_epi_tensor(gemm, tensor, epi_tile=None, op_type="store"):
36
+ """Create TMA atom + smem layout for a supplemental epilogue tensor.
37
+
38
+ Args:
39
+ gemm: The GEMM object (provides arch, epi_stage, _make_tma_epi_atoms_and_tensors).
40
+ tensor: The global memory tensor to set up TMA for.
41
+ epi_tile: Epilogue tile shape. Defaults to gemm.epi_tile.
42
+ op_type: "store" or "load".
43
+
44
+ Returns:
45
+ (tma_atom, tma_tensor, smem_layout_staged, epi_tile)
46
+ """
47
+ if epi_tile is None:
48
+ epi_tile = gemm.epi_tile
49
+ dtype = tensor.element_type
50
+ layout = cutlass.utils.LayoutEnum.from_tensor(tensor)
51
+ utils_cls = sm100_utils if gemm.arch >= 100 else sm90_utils
52
+ smem_layout_staged = utils_cls.make_smem_layout_epi(dtype, layout, epi_tile, gemm.epi_stage)
53
+ tma_input = (
54
+ copy_utils.create_ragged_tensor_for_tma(tensor, ragged_dim=0, ptr_shift=True)
55
+ if cute.rank(tensor) == 2
56
+ else tensor
57
+ )
58
+ tma_atom, tma_tensor = gemm._make_tma_epi_atoms_and_tensors(
59
+ tma_input,
60
+ smem_layout_staged,
61
+ epi_tile,
62
+ op_type=op_type,
63
+ )
64
+ return tma_atom, tma_tensor, smem_layout_staged, epi_tile
build/torch-cuda/quack/fast_math.py CHANGED
@@ -1,80 +1,33 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
- from typing import Tuple
4
- from dataclasses import dataclass
5
-
6
  import cutlass
7
  import cutlass.cute as cute
8
- from cutlass import Int32, Uint32
9
- from cutlass.cutlass_dsl import T, dsl_user_op
10
- from cutlass._mlir.dialects import llvm
11
-
12
- from .cute_dsl_utils import ParamsBase
13
-
14
-
15
- @cute.jit
16
- def clz(x: Int32) -> Int32:
17
- # for i in cutlass.range_constexpr(32):
18
- # if (1 << (31 - i)) & x:
19
- # return Int32(i)
20
- # return Int32(32)
21
- # Early exit is not supported yet
22
- res = Int32(32)
23
- done = False
24
- for i in cutlass.range(32):
25
- if ((1 << (31 - i)) & x) and not done:
26
- res = Int32(i)
27
- done = True
28
- return res
29
-
30
-
31
- def find_log2(x: Int32) -> Int32:
32
- a: Int32 = Int32(31 - clz(x))
33
- return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2.
34
-
35
-
36
- @dsl_user_op
37
- def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
38
- return Uint32(
39
- llvm.inline_asm(
40
- T.i32(),
41
- [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
42
- "mul.hi.u32 $0, $1, $2;",
43
- "=r,r,r",
44
- has_side_effects=False,
45
- is_align_stack=False,
46
- asm_dialect=llvm.AsmDialect.AD_ATT,
47
- )
48
- )
49
-
50
-
51
- @dataclass
52
- class FastDivmod(ParamsBase):
53
- divisor: Int32
54
- multiplier: Uint32
55
- shift_right: Uint32
56
-
57
- # called by host
58
- @staticmethod
59
- def create(divisor: Int32) -> "FastDivmod":
60
- """Construct the FastDivmod object, in host code.
61
- This precomputes some values based on the divisor and is computationally expensive.
62
- """
63
- p = Uint32(31 + find_log2(divisor))
64
- divisor_u32 = Uint32(divisor)
65
- multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
66
- shift_right = Uint32(p - 32)
67
- return FastDivmod(divisor, multiplier, shift_right)
68
-
69
- @cute.jit
70
- def div(self, dividend: Int32) -> Int32:
71
- return (
72
- Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
73
- if self.divisor != 1
74
- else dividend
75
- )
76
-
77
- def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
78
- quotient = self.div(dividend)
79
- remainder = dividend - quotient * self.divisor
80
- return quotient, remainder
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
 
 
 
3
  import cutlass
4
  import cutlass.cute as cute
5
+ from cutlass.base_dsl.typing import Integer
6
+ from cutlass.cutlass_dsl import dsl_user_op
7
+
8
+
9
+ class FastDivmod(cute.FastDivmodDivisor):
10
+ """We store the divisor along with the FastDivmodDivisor."""
11
+
12
+ @dsl_user_op
13
+ def __init__(
14
+ self,
15
+ divisor: Integer,
16
+ is_power_of_2: bool = None,
17
+ *,
18
+ loc=None,
19
+ ip=None,
20
+ ):
21
+ super().__init__(divisor, is_power_of_2=is_power_of_2, loc=loc, ip=ip)
22
+ self.divisor = divisor
23
+
24
+ def __extract_mlir_values__(self):
25
+ """Extract MLIR values for Host->Device transfer."""
26
+ return [self._divisor] + cutlass.extract_mlir_values(self.divisor)
27
+
28
+ def __new_from_mlir_values__(self, values):
29
+ """Reconstruct FastDivmodDivisor from MLIR values."""
30
+ new_obj = object.__new__(FastDivmod)
31
+ new_obj._divisor = values[0]
32
+ new_obj.divisor = cutlass.new_from_mlir_values(self.divisor, values[1:])
33
+ return new_obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/quack/gemm.py CHANGED
@@ -1,16 +1,141 @@
 
 
 
1
  from typing import Optional
2
- from functools import partial
3
 
4
  from torch import Tensor
5
 
6
  import cutlass.cute as cute
7
- import cutlass.torch as cutlass_torch
8
- from cutlass import Float32
9
- from cutlass.cute.runtime import from_dlpack, make_ptr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
12
- from .gemm_wrapper_utils import GemmWrapperBase
13
- from .gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def gemm(
@@ -26,6 +151,7 @@ def gemm(
26
  cluster_N: int,
27
  pingpong: bool = False,
28
  persistent: bool = True,
 
29
  max_swizzle_size: int = 8,
30
  rowvec_bias: Optional[Tensor] = None, # (l, n)
31
  colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
@@ -36,159 +162,121 @@ def gemm(
36
  A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
37
  batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
38
  add_to_output: bool = False,
 
 
 
 
 
39
  ) -> None:
40
- varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
41
- assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
42
- "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
43
- )
44
  gather_A = A_idx is not None
 
45
  if gather_A:
46
- assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
47
  assert cluster_N == 1, "gather_A requires cluster_N=1"
48
  if varlen:
49
  assert persistent, "varlen requires persistent=True"
50
  if add_to_output:
51
- assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
52
- if cu_seqlens_m is not None:
53
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
54
  assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
55
- if cu_seqlens_k is not None:
56
  assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
57
  assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
58
 
59
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
60
- A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
- GemmWrapperBase.permute_tensors(
63
- tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  )
65
- GemmWrapperBase.extract_dtypes(tensor_infos)
66
- major_configs = {
67
- "A": ("m", "k", "l"),
68
- "B": ("n", "k", "l"),
69
- "D": ("m", "n", "l"),
70
- "C": ("m", "n", "l"),
71
- }
72
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
73
 
74
- device_capacity = get_device_capacity(A.device)
75
- assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
76
- GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90
77
-
78
- acc_dtype = Float32
79
- tile_shape_mn = (tile_M, tile_N)
80
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
81
- if not GemmCls.is_valid_dtypes(
82
- tensor_infos["A"].dtype,
83
- tensor_infos["B"].dtype,
84
- acc_dtype,
85
- tensor_infos["D"].dtype,
86
- tensor_infos["A"].major,
87
- tensor_infos["B"].major,
88
- ):
89
- raise TypeError("Skipping due to unsupported combination of types and majors")
90
 
91
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
92
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
93
 
94
- def scalar_arg(scalar: float | Tensor):
95
- if isinstance(scalar, float):
96
- return Float32(scalar) if scalar != 1.0 else None
 
 
97
  else:
98
- assert isinstance(scalar, Tensor)
99
- return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
100
 
101
- epi_args = GemmCls.EpilogueArguments(
102
- scalar_arg(alpha),
103
- scalar_arg(beta),
104
- mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
105
- leading_dim=1
106
- )
107
- if rowvec_bias is not None
108
- else None,
109
- mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
110
- leading_dim=1 if cu_seqlens_m is None else 0
111
- )
112
- if colvec_bias is not None
113
- else None,
114
- add_to_output=add_to_output,
115
  )
116
- scheduler_args = GemmWrapperBase.create_scheduler_args(
117
  max_active_clusters,
 
118
  tile_count_semaphore,
119
  batch_idx_permute,
120
- max_swizzle_size,
121
- )
122
-
123
- # Create varlen arguments if needed (assumes persistent=True when varlen)
124
- varlen_args = GemmWrapperBase.create_varlen_args(
125
- cu_seqlens_m,
126
- cu_seqlens_k,
127
- A_idx,
128
- max_active_clusters,
129
- cluster_shape_mnk,
130
- tensor_infos,
131
- GemmCls.num_epi_tensormaps,
132
- pingpong,
133
  )
 
134
 
135
- current_stream = cutlass_torch.current_stream()
136
- compile_key = GemmWrapperBase.get_compile_key(
137
- tensor_infos,
138
- None, # activation
139
- tile_shape_mn,
140
- cluster_shape_mnk,
141
- pingpong,
142
- persistent,
143
- tile_count_semaphore is not None,
144
- device_capacity,
145
- # Technically we don't need to recompile for different max_swizzle_size, but currently
146
- # not recompiling will skew the autotuning results due to power throttling.
147
- # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
148
- max_swizzle_size,
149
- rowvec_bias.dtype if rowvec_bias is not None else None,
150
- colvec_bias.dtype if colvec_bias is not None else None,
151
- 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
152
- 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
153
- add_to_output,
154
- cu_seqlens_m is not None,
155
- cu_seqlens_k is not None,
156
- gather_A,
157
- batch_idx_permute is not None,
158
- key_tensor_names=("A", "B", "D", "C"),
159
- )
160
- cache = gemm.compile_cache
161
- if compile_key not in cache:
162
- if device_capacity[0] == 9:
163
- GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
164
- gemm_obj = GemmCls(
165
- acc_dtype,
166
- tensor_infos["A"].dtype,
167
- tile_shape_mn,
168
- cluster_shape_mnk,
169
- gather_A=gather_A,
170
- )
171
- cache[compile_key] = cute.compile(
172
- gemm_obj,
173
- tensor_infos["A"].cute_tensor,
174
- tensor_infos["B"].cute_tensor,
175
- tensor_infos["D"].cute_tensor,
176
- tensor_infos["C"].cute_tensor,
177
- epi_args,
178
- scheduler_args,
179
- varlen_args,
180
- current_stream,
181
  )
182
- cache[compile_key](
183
- tensor_infos["A"].cute_tensor,
184
- tensor_infos["B"].cute_tensor,
185
- tensor_infos["D"].cute_tensor,
186
- tensor_infos["C"].cute_tensor,
187
- epi_args,
188
- scheduler_args,
189
- varlen_args,
190
- current_stream,
191
- )
192
-
193
-
194
- gemm.compile_cache = {}
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ # GEMM compilation via TVM-FFI with fake tensors and NamedTuple args.
3
+
4
  from typing import Optional
 
5
 
6
  from torch import Tensor
7
 
8
  import cutlass.cute as cute
9
+ from cutlass import Int32, Float32
10
+ from cutlass.cute.runtime import make_ptr
11
+
12
+ from .cache_utils import jit_cache
13
+ from .compile_utils import make_fake_tensor as fake_tensor
14
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map
15
+ from .gemm_default_epi import (
16
+ GemmDefaultEpiMixin,
17
+ GemmDefaultSm90,
18
+ GemmDefaultSm100,
19
+ GemmDefaultSm120,
20
+ )
21
+ from .rounding import RoundingMode
22
+ from .gemm_tvm_ffi_utils import (
23
+ get_majors,
24
+ get_dtypes,
25
+ perm3d,
26
+ make_scheduler_args,
27
+ make_varlen_args,
28
+ make_fake_scheduler_args,
29
+ make_fake_varlen_args,
30
+ make_fake_gemm_tensors,
31
+ compile_gemm_kernel,
32
+ )
33
+
34
+
35
+ @jit_cache
36
+ def _compile_gemm(
37
+ a_dtype,
38
+ b_dtype,
39
+ d_dtype,
40
+ c_dtype,
41
+ a_major,
42
+ b_major,
43
+ d_major,
44
+ c_major,
45
+ tile_shape_mn,
46
+ cluster_shape_mnk,
47
+ pingpong,
48
+ persistent,
49
+ is_dynamic_persistent,
50
+ rowvec_dtype,
51
+ colvec_dtype,
52
+ colvec_ndim,
53
+ alpha_mode,
54
+ beta_mode,
55
+ add_to_output,
56
+ concat_layout,
57
+ varlen_m,
58
+ varlen_k,
59
+ gather_A,
60
+ use_tma_gather,
61
+ has_batch_idx_permute,
62
+ device_capacity,
63
+ rounding_mode,
64
+ sr_seed_mode,
65
+ has_trace_ptr,
66
+ ):
67
+ sm_to_cls = {
68
+ 9: GemmDefaultSm90,
69
+ 10: GemmDefaultSm100,
70
+ 11: GemmDefaultSm100,
71
+ 12: GemmDefaultSm120,
72
+ }
73
+ GemmCls = sm_to_cls[device_capacity[0]]
74
+ mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
75
+ a_dtype,
76
+ b_dtype,
77
+ d_dtype,
78
+ c_dtype,
79
+ a_major,
80
+ b_major,
81
+ d_major,
82
+ c_major,
83
+ varlen_m=varlen_m,
84
+ varlen_k=varlen_k,
85
+ gather_A=gather_A,
86
+ )
87
+
88
+ def fake_scalar(mode, dtype=Float32):
89
+ if mode == 0:
90
+ return None
91
+ elif mode == 1:
92
+ return dtype(1.0 if dtype == Float32 else 0)
93
+ else:
94
+ return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
95
+
96
+ mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
97
+ if colvec_ndim == 2:
98
+ mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
99
+ elif colvec_ndim == 1: # m is total_m in this case
100
+ mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
101
+ else:
102
+ mColVec = None
103
 
104
+ epi_args = GemmCls.EpilogueArguments(
105
+ alpha=fake_scalar(alpha_mode),
106
+ beta=fake_scalar(beta_mode),
107
+ mRowVecBroadcast=mRowVec,
108
+ mColVecBroadcast=mColVec,
109
+ add_to_output=add_to_output,
110
+ rounding_mode=rounding_mode,
111
+ sr_seed=fake_scalar(sr_seed_mode, dtype=Int32),
112
+ )
113
+ scheduler_args = make_fake_scheduler_args(
114
+ (is_dynamic_persistent and device_capacity[0] == 9), has_batch_idx_permute, l
115
+ )
116
+ aidx_len = m if varlen_m else (k if varlen_k else None)
117
+ varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len)
118
+ return compile_gemm_kernel(
119
+ GemmCls,
120
+ a_dtype,
121
+ tile_shape_mn,
122
+ cluster_shape_mnk,
123
+ pingpong,
124
+ persistent,
125
+ gather_A,
126
+ is_dynamic_persistent,
127
+ device_capacity,
128
+ mA,
129
+ mB,
130
+ mD,
131
+ mC,
132
+ epi_args,
133
+ scheduler_args,
134
+ varlen_args,
135
+ has_trace_ptr=has_trace_ptr,
136
+ use_tma_gather=use_tma_gather,
137
+ concat_layout=concat_layout or None,
138
+ )
139
 
140
 
141
  def gemm(
 
151
  cluster_N: int,
152
  pingpong: bool = False,
153
  persistent: bool = True,
154
+ is_dynamic_persistent: bool = False,
155
  max_swizzle_size: int = 8,
156
  rowvec_bias: Optional[Tensor] = None, # (l, n)
157
  colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
 
162
  A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
163
  batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
164
  add_to_output: bool = False,
165
+ rounding_mode: int = RoundingMode.RN,
166
+ sr_seed: int | Tensor = 0,
167
+ use_tma_gather: bool = False,
168
+ concat_layout: dict | None = None,
169
+ trace_ptr=None, # Optional Int64 from TraceSession.ptr
170
  ) -> None:
171
+ varlen_m = cu_seqlens_m is not None
172
+ varlen_k = cu_seqlens_k is not None
173
+ varlen = varlen_m or varlen_k
 
174
  gather_A = A_idx is not None
175
+ assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k"
176
  if gather_A:
177
+ assert varlen, "gather_A requires varlen"
178
  assert cluster_N == 1, "gather_A requires cluster_N=1"
179
  if varlen:
180
  assert persistent, "varlen requires persistent=True"
181
  if add_to_output:
182
+ assert not varlen_m, "Add to output not supported with varlen_m"
183
+ if varlen_m:
184
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
185
  assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
186
+ if varlen_k:
187
  assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
188
  assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
189
 
190
+ device_capacity = get_device_capacity(A.device)
191
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
192
+ if use_tma_gather:
193
+ assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110"
194
+ if rounding_mode == RoundingMode.RS:
195
+ assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
196
+ if is_dynamic_persistent and device_capacity[0] == 9:
197
+ assert tile_count_semaphore is not None, (
198
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
199
+ )
200
+
201
+ A_p, B_p, D_p, C_p = perm3d(A, B, D, C, varlen_m=varlen_m, varlen_k=varlen_k)
202
+ a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
203
+ a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
204
+
205
+ alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0)
206
+ beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0)
207
+ colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0
208
+ concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()
209
+
210
+ sr_seed_mode = (
211
+ 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
212
  )
213
+ compiled_fn = _compile_gemm(
214
+ a_dtype,
215
+ b_dtype,
216
+ d_dtype,
217
+ c_dtype,
218
+ a_major,
219
+ b_major,
220
+ d_major,
221
+ c_major,
222
+ (tile_M, tile_N),
223
+ (cluster_M, cluster_N, 1),
224
+ pingpong,
225
+ persistent,
226
+ is_dynamic_persistent,
227
+ torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None,
228
+ torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None,
229
+ colvec_ndim,
230
+ alpha_mode,
231
+ beta_mode,
232
+ add_to_output,
233
+ concat_layout,
234
+ varlen_m,
235
+ varlen_k,
236
+ gather_A,
237
+ use_tma_gather,
238
+ batch_idx_permute is not None,
239
+ device_capacity,
240
+ rounding_mode,
241
+ sr_seed_mode,
242
+ trace_ptr is not None,
243
  )
 
 
 
 
 
 
 
 
244
 
245
+ from .cache_utils import COMPILE_ONLY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
+ if COMPILE_ONLY:
248
+ return
249
 
250
+ def scalar_arg(scalar, mode, dtype=Float32):
251
+ if mode == 0:
252
+ return None
253
+ elif mode == 1:
254
+ return dtype(scalar)
255
  else:
256
+ return scalar.data_ptr()
 
257
 
258
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
259
+
260
+ epi_args = GemmDefaultEpiMixin.EpilogueArguments(
261
+ alpha=scalar_arg(alpha, alpha_mode),
262
+ beta=scalar_arg(beta, beta_mode),
263
+ mRowVecBroadcast=rowvec_bias,
264
+ mColVecBroadcast=colvec_bias,
265
+ add_to_output=None,
266
+ rounding_mode=None,
267
+ sr_seed=scalar_arg(sr_seed, sr_seed_mode, dtype=Int32),
 
 
 
 
268
  )
269
+ scheduler_args = make_scheduler_args(
270
  max_active_clusters,
271
+ max_swizzle_size,
272
  tile_count_semaphore,
273
  batch_idx_permute,
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  )
275
+ varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx)
276
 
277
+ if device_capacity[0] in [10, 11]:
278
+ compiled_fn(
279
+ A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  )
281
+ else:
282
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr)
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/quack/gemm_act.py CHANGED
@@ -1,7 +1,7 @@
1
  # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
- from typing import Tuple, Optional, Callable
 
3
  from functools import partial
4
- from dataclasses import dataclass
5
 
6
  from torch import Tensor
7
 
@@ -9,183 +9,85 @@ import cutlass
9
  import cutlass.cute as cute
10
  import cutlass.utils.hopper_helpers as sm90_utils_og
11
  import cutlass.utils.blackwell_helpers as sm100_utils
12
- from cutlass import Int32, Float32, Boolean, const_expr
13
- from cutlass.cutlass_dsl import if_generate
14
- import cutlass.torch as cutlass_torch
15
- from cutlass.cute.runtime import from_dlpack
16
-
17
- from .cute_dsl_utils import ArgumentsBase, ParamsBase
18
- from .varlen_utils import VarlenManager
 
 
 
 
 
19
  from .gemm_sm90 import GemmSm90
20
  from .gemm_sm100 import GemmSm100
 
21
  from .gemm_default_epi import GemmDefaultEpiMixin
22
- from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
23
- from .gemm_wrapper_utils import GemmWrapperBase
24
- from . import sm90_utils as sm90_utils
25
- from . import copy_utils as copy_utils
26
- from . import activation
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  class GemmActMixin(GemmDefaultEpiMixin):
30
- num_epi_tensormaps: int = 1
 
 
31
 
32
- @dataclass
33
- class EpilogueArguments(ArgumentsBase):
34
  mPostAct: cute.Tensor
35
  act_fn: cutlass.Constexpr[Optional[Callable]] = None
36
  alpha: Optional[Float32 | cute.Tensor] = None
37
  beta: Optional[Float32 | cute.Tensor] = None
38
  mRowVecBroadcast: Optional[cute.Tensor] = None
39
  mColVecBroadcast: Optional[cute.Tensor] = None
 
 
40
 
41
- @dataclass
42
- class EpilogueParams(ParamsBase):
43
- tma_atom_postact: cute.CopyAtom
44
- mPostAct_mnl: cute.Tensor
45
- epi_postact_smem_layout_staged: cute.ComposedLayout
46
- epi_tile_postact: cute.Tile
47
- act_fn: cutlass.Constexpr[Optional[Callable]] = None
48
- alpha: Optional[Float32 | cute.Tensor] = None
49
- beta: Optional[Float32 | cute.Tensor] = None
50
- mRowVecBroadcast: Optional[cute.Tensor] = None
51
- mColVecBroadcast: Optional[cute.Tensor] = None
52
 
53
- def epi_to_underlying_arguments(
54
- self, args: EpilogueArguments, *, loc=None, ip=None
55
- ) -> EpilogueParams:
56
  self.postact_dtype = args.mPostAct.element_type
57
  self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
58
-
59
  self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
60
- epi_tile_postact = self.epi_tile
61
- utils_cls = sm100_utils if self.arch == 100 else sm90_utils
62
- epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
63
- self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
64
- )
65
- tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
66
- args.mPostAct,
67
- epi_postact_smem_layout_staged,
68
- epi_tile_postact,
69
- op_type="store",
70
- )
71
- # Assume all strides are divisible by 32 bits except the last stride
72
- new_stride = lambda t: tuple(
73
- cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
74
- for s in t.stride
75
- )
76
- mRowVecBroadcast, mColVecBroadcast = [
77
- cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
78
- if t is not None
79
- else None
80
- for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
81
- ]
82
- return self.EpilogueParams(
83
- tma_atom_postact,
84
- tma_tensor_postact,
85
- epi_postact_smem_layout_staged,
86
- epi_tile_postact,
87
- args.act_fn,
88
- alpha=args.alpha,
89
- beta=args.beta,
90
- mRowVecBroadcast=mRowVecBroadcast,
91
- mColVecBroadcast=mColVecBroadcast,
92
- )
93
 
94
- def epi_get_tma_atoms(
95
- self, params: EpilogueParams, *, loc=None, ip=None
96
- ) -> list[cute.CopyAtom]:
97
- return [params.tma_atom_postact]
98
 
99
- def epi_get_tensormap_update_shapes_orders(
100
  self,
101
- params: EpilogueParams,
102
- cu_seqlens_m: Optional[cute.Tensor],
103
- batch_idx: Int32,
104
- *,
105
- loc=None,
106
- ip=None,
107
- ) -> tuple[list[Int32], list[int]]:
108
- shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
109
- orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
110
- return shapes, orders
111
-
112
- @staticmethod
113
- def epi_smem_bytes_per_stage(
114
- args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
115
- ) -> int:
116
- postact_dtype = args.mPostAct.element_type
117
- postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
118
- rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
119
- args, cta_tile_shape_mnk, epi_tile
120
- )
121
- return postact_bytes_per_stage + rowvec_colvec_bytes
122
-
123
- def epi_get_smem_struct(self, params: EpilogueParams):
124
- row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
125
- col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
126
- row_vec_dtype = (
127
- params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
128
- )
129
- col_vec_dtype = (
130
- params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
131
- )
132
-
133
- @cute.struct
134
- class EpiSharedStorage:
135
- sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
136
- sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
137
- sPostAct: cute.struct.Align[
138
- cute.struct.MemRange[
139
- self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
140
- ],
141
- self.buffer_align_bytes,
142
- ]
143
-
144
- return EpiSharedStorage
145
-
146
- def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
147
- sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
148
- sPostAct = storage.epi.sPostAct.get_tensor(
149
- params.epi_postact_smem_layout_staged.outer,
150
- swizzle=params.epi_postact_smem_layout_staged.inner,
151
- )
152
- return (sRowVec, sColVec, sPostAct)
153
-
154
- @cute.jit
155
- def epilogue(
156
- self,
157
- params: EpilogueParams,
158
- epi_smem_tensors: Tuple[cute.Tensor, ...],
159
- tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
160
- epi_pipeline: cutlass.pipeline.PipelineAsync,
161
- epi_store_pipeline: cutlass.pipeline.PipelineAsync,
162
- epi_read_state: cutlass.pipeline.PipelineState,
163
- epi_producer_state: cutlass.pipeline.PipelineState,
164
- epi_tile: cute.Tile,
165
- load_acc_subtile: Callable,
166
- tRS_rD: cute.Tensor,
167
- tRS_rC: Optional[cute.Tensor],
168
- tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
169
- tiled_copy_r2s: cute.TiledCopy,
170
- tRS_sD: cute.Tensor,
171
- tiled_copy_s2r: Optional[cute.TiledCopy],
172
- tSR_rC: Optional[cute.Tensor],
173
- tSR_sC: Optional[cute.Tensor],
174
- copy_D: Optional[Callable],
175
- copy_C: Optional[Callable],
176
- tile_coord_mnkl: cute.Coord,
177
- varlen_manager: VarlenManager,
178
- epilogue_barrier: cutlass.pipeline.NamedBarrier,
179
- tile_scheduler,
180
- tidx: Int32,
181
- is_tma_warp: Boolean,
182
- ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
183
- has_C = const_expr(tRS_rC is not None)
184
- has_D = const_expr(copy_D is not None)
185
-
186
- tma_atom_postact = params.tma_atom_postact
187
- mPostAct_mnl = params.mPostAct_mnl
188
- sRowVec, sColVec, sPostAct = epi_smem_tensors
189
  get_smem_store_op = (
190
  partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
191
  if self.arch == 100
@@ -194,131 +96,56 @@ class GemmActMixin(GemmDefaultEpiMixin):
194
  copy_atom_postact_r2s = get_smem_store_op(
195
  self.postact_layout, self.postact_dtype, self.acc_dtype
196
  )
197
- # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
198
- # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
199
  tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
200
  tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
201
- (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
202
  batch_idx = tile_coord_mnkl[3]
203
  copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
204
- tma_atom_postact,
205
- varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
206
  self.cta_tile_shape_postact_mn,
207
- params.epi_tile_postact,
208
  sPostAct,
209
  tile_coord_mnkl,
210
- tma_desc_ptr=tma_desc_postact_ptr,
211
- )
212
-
213
- # We iterate over epi tiles in the N dimension first before the M dimension
214
- epi_tile_shape = cute.zipped_divide(
215
- cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
216
- ).shape[1]
217
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
218
- epi_tile_num = cute.size(epi_tile_shape)
219
- num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
220
-
221
- epi_tensors = self.epi_begin(
222
- params,
223
- epi_smem_tensors,
224
- epi_tile,
225
- tiled_copy_t2r,
226
- tiled_copy_r2s,
227
- tile_coord_mnkl,
228
- varlen_manager,
229
- epilogue_barrier,
230
- tidx,
231
  )
 
232
 
233
- if const_expr(copy_C is not None):
234
- for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
235
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
236
- if is_tma_warp:
237
- epi_pipeline.producer_acquire(epi_producer_state)
238
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
239
- epi_pipeline.producer_commit(epi_producer_state)
240
- epi_producer_state.advance()
241
-
242
- def tma_store_fn(src_idx, dst_idx):
243
- # Fence and barrier to make sure shared memory store is visible to TMA store
244
- cute.arch.fence_proxy(
245
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
246
- )
247
- epilogue_barrier.arrive_and_wait()
248
- # Copy from shared memory to global memory
249
- if is_tma_warp:
250
- if const_expr(has_D):
251
- copy_D(src_idx=src_idx, dst_idx=dst_idx)
252
- copy_postact(src_idx=src_idx, dst_idx=dst_idx)
253
- # Can't use if statement here, epi_store_pipeline object isn't captured somehow
254
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
255
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
256
- epilogue_barrier.arrive_and_wait()
257
-
258
- delay_tma_store = True
259
-
260
- src_idx_prev, dst_idx_prev = None, None
261
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
262
- # The global memory coordinate for the current epi tile
263
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
264
- # Copy from acc to D registers
265
- load_acc_subtile(tRS_rD, epi_idx)
266
- epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
267
- if const_expr(has_C):
268
- epi_pipeline.consumer_wait(epi_read_state)
269
- cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
270
- # Fence to make sure shared memory read is visible to TMA load
271
- cute.arch.fence_proxy(
272
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
273
  )
274
- cute.arch.sync_warp()
275
- with cute.arch.elect_one():
276
- epi_pipeline.consumer_release(epi_read_state)
277
- epi_read_state.advance()
278
- if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
279
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
280
- if is_tma_warp:
281
- epi_pipeline.producer_acquire(epi_producer_state)
282
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
283
- epi_pipeline.producer_commit(epi_producer_state)
284
- epi_producer_state.advance()
285
- tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
286
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
287
- if const_expr(delay_tma_store):
288
- if const_expr(epi_idx > 0):
289
- tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
290
- src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
291
- # Copy from D registers to shared memory
292
- if const_expr(has_D):
293
- copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
294
- cute.copy(
295
- tiled_copy_postact_r2s,
296
- tiled_copy_postact_r2s.retile(tRS_rPostAct),
297
- tRS_sPostAct[None, None, None, epi_buffer],
298
  )
299
- if const_expr(not delay_tma_store):
300
- tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
301
-
302
- if const_expr(delay_tma_store):
303
- tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
304
-
305
- self.epi_end(
306
- params,
307
- epi_tensors,
308
- epi_tile,
309
- tiled_copy_t2r,
310
- tiled_copy_r2s,
311
- tile_coord_mnkl,
312
- varlen_manager,
313
- tidx,
314
- )
315
-
316
- return epi_read_state, epi_producer_state
317
 
318
  @cute.jit
319
  def epi_visit_subtile(
320
  self,
321
- params: EpilogueParams,
322
  epi_loop_tensors: Tuple[cute.Tensor, ...],
323
  tRS_rD: cute.Tensor,
324
  tRS_rC: Optional[cute.Tensor] = None,
@@ -327,7 +154,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
327
  # Apply activation function if provided
328
  # If we don't have .shape here, the compiler generates local stores and loads
329
  if const_expr(params.act_fn is not None):
330
- tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
331
  if const_expr(self.arch < 100):
332
  for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
333
  tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
@@ -338,10 +165,7 @@ class GemmActMixin(GemmDefaultEpiMixin):
338
  )
339
  else:
340
  tRS_rPostAct = tRS_rD
341
- # Type conversion
342
- tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
343
- tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
344
- return tRS_rPostAct_out
345
 
346
 
347
  class GemmActSm90(GemmActMixin, GemmSm90):
@@ -352,12 +176,202 @@ class GemmActSm100(GemmActMixin, GemmSm100):
352
  pass
353
 
354
 
355
- act_fn_map = {
356
- None: None,
357
- "relu": activation.relu,
358
- "relu_sq": activation.relu_sq,
359
- "gelu_tanh_approx": activation.gelu_tanh_approx,
360
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
 
363
  def gemm_act(
@@ -365,7 +379,7 @@ def gemm_act(
365
  B: Tensor, # (l, n, k)
366
  D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
367
  C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
368
- PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
369
  tile_count_semaphore: Optional[Tensor], # (1,)
370
  activation: Optional[str],
371
  tile_M: int,
@@ -374,137 +388,132 @@ def gemm_act(
374
  cluster_N: int,
375
  pingpong: bool = False,
376
  persistent: bool = True,
 
377
  max_swizzle_size: int = 8,
378
  rowvec_bias: Optional[Tensor] = None, # (l, n)
379
  colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
380
  cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
381
  A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
 
 
 
 
382
  ) -> None:
383
- if cu_seqlens_m is not None:
 
 
 
 
 
 
 
 
384
  assert persistent, "varlen_m requires persistent=True"
385
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
386
  if D is not None:
387
  assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
388
  assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
389
- gather_A = A_idx is not None
390
  if gather_A:
391
- assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
392
  assert cluster_N == 1, "gather_A requires cluster_N=1"
393
- assert activation in act_fn_map, f"Unsupported activation {activation}"
394
 
395
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
396
- A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
397
- )
398
- GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
399
- GemmWrapperBase.extract_dtypes(tensor_infos)
400
- major_configs = {
401
- "A": ("m", "k", "l"),
402
- "B": ("n", "k", "l"),
403
- "D": ("m", "n", "l"),
404
- "C": ("m", "n", "l"),
405
- "PostAct": ("m", "n", "l"),
406
- }
407
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
 
 
 
 
 
408
 
409
  device_capacity = get_device_capacity(A.device)
410
- assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
411
- GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90
412
-
413
- acc_dtype = Float32
414
- tile_shape_mn = (tile_M, tile_N)
415
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
416
- if not GemmCls.is_valid_dtypes(
417
- tensor_infos["A"].dtype,
418
- tensor_infos["B"].dtype,
419
- acc_dtype,
420
- tensor_infos["D"].dtype,
421
- tensor_infos["A"].major,
422
- tensor_infos["B"].major,
423
- ):
424
- raise TypeError("Skipping due to unsupported combination of types and majors")
425
 
426
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
427
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
428
- act_fn = act_fn_map[activation]
429
- epi_args = GemmCls.EpilogueArguments(
430
- tensor_infos["PostAct"].cute_tensor,
431
- act_fn,
432
- mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
433
- leading_dim=1
434
- )
435
- if rowvec_bias is not None
436
- else None,
437
- mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
438
- leading_dim=1 if cu_seqlens_m is None else 0
439
  )
440
- if colvec_bias is not None
441
- else None,
442
- )
443
- scheduler_args = GemmWrapperBase.create_scheduler_args(
444
- max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
445
- )
446
 
447
- # Create varlen arguments if needed (assumes persistent=True when varlen_m)
448
- varlen_args = GemmWrapperBase.create_varlen_args(
449
- cu_seqlens_m,
450
- None, # cu_seqlens_k
451
- A_idx,
452
- max_active_clusters,
453
- cluster_shape_mnk,
454
- tensor_infos,
455
- GemmCls.num_epi_tensormaps,
456
- pingpong,
457
  )
458
-
459
- current_stream = cutlass_torch.current_stream()
460
- compile_key = GemmWrapperBase.get_compile_key(
461
- tensor_infos,
462
- activation,
463
- tile_shape_mn,
464
- cluster_shape_mnk,
 
 
 
 
 
 
 
465
  pingpong,
466
  persistent,
467
- tile_count_semaphore is not None,
 
 
 
 
 
 
 
468
  device_capacity,
469
- max_swizzle_size,
470
- rowvec_bias.dtype if rowvec_bias is not None else None,
471
- colvec_bias.dtype if colvec_bias is not None else None,
472
- cu_seqlens_m is not None,
473
- A_idx is not None,
474
- key_tensor_names=("A", "B", "D", "PostAct", "C"),
475
  )
476
- cache = gemm_act.compile_cache
477
- if compile_key not in cache:
478
- if device_capacity[0] == 9:
479
- GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
480
- gemm_obj = GemmCls(
481
- acc_dtype,
482
- tensor_infos["A"].dtype,
483
- tile_shape_mn,
484
- cluster_shape_mnk,
485
- gather_A=gather_A,
486
- )
487
- cache[compile_key] = cute.compile(
488
- gemm_obj,
489
- tensor_infos["A"].cute_tensor,
490
- tensor_infos["B"].cute_tensor,
491
- tensor_infos["D"].cute_tensor,
492
- tensor_infos["C"].cute_tensor,
493
- epi_args,
494
- scheduler_args,
495
- varlen_args,
496
- current_stream,
497
- )
498
- cache[compile_key](
499
- tensor_infos["A"].cute_tensor,
500
- tensor_infos["B"].cute_tensor,
501
- tensor_infos["D"].cute_tensor,
502
- tensor_infos["C"].cute_tensor,
503
- epi_args,
504
- scheduler_args,
505
- varlen_args,
506
- current_stream,
507
  )
 
 
 
 
 
 
 
 
 
 
 
508
 
509
 
510
- gemm_act.compile_cache = {}
 
1
  # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from __future__ import annotations
3
+ from typing import NamedTuple, Tuple, Optional, Callable
4
  from functools import partial
 
5
 
6
  from torch import Tensor
7
 
 
9
  import cutlass.cute as cute
10
  import cutlass.utils.hopper_helpers as sm90_utils_og
11
  import cutlass.utils.blackwell_helpers as sm100_utils
12
+ from cutlass import Int32, Float32, const_expr
13
+ from cutlass.cute.runtime import make_ptr
14
+
15
+ from .compile_utils import make_fake_tensor as fake_tensor
16
+ from .cute_dsl_utils import (
17
+ ParamsBase,
18
+ mlir_namedtuple,
19
+ get_device_capacity,
20
+ get_max_active_clusters,
21
+ torch2cute_dtype_map,
22
+ )
23
+ from .epi_ops import TileStore
24
  from .gemm_sm90 import GemmSm90
25
  from .gemm_sm100 import GemmSm100
26
+ from .gemm_sm120 import GemmSm120
27
  from .gemm_default_epi import GemmDefaultEpiMixin
28
+ from .gemm_tvm_ffi_utils import (
29
+ get_major,
30
+ perm3d_single,
31
+ make_scheduler_args,
32
+ make_varlen_args,
33
+ make_fake_scheduler_args,
34
+ make_fake_varlen_args,
35
+ div_for_dtype,
36
+ make_fake_gemm_tensors,
37
+ compile_gemm_kernel,
38
+ )
39
+ from .cache_utils import jit_cache
40
+ from . import layout_utils as layout_utils
41
+ from .layout_utils import permute_gated_Cregs_b16
42
+ from .activation import act_fn_map, gate_fn_map
43
+ from .rounding import RoundingMode
44
 
45
 
46
  class GemmActMixin(GemmDefaultEpiMixin):
47
+ _epi_ops = (*GemmDefaultEpiMixin._epi_ops, TileStore("mPostAct"))
48
+ _extra_param_fields = (("act_fn", cutlass.Constexpr, None),)
49
+ _epi_param_bases = (ParamsBase,)
50
 
51
+ @mlir_namedtuple
52
+ class EpilogueArguments(NamedTuple):
53
  mPostAct: cute.Tensor
54
  act_fn: cutlass.Constexpr[Optional[Callable]] = None
55
  alpha: Optional[Float32 | cute.Tensor] = None
56
  beta: Optional[Float32 | cute.Tensor] = None
57
  mRowVecBroadcast: Optional[cute.Tensor] = None
58
  mColVecBroadcast: Optional[cute.Tensor] = None
59
+ rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
60
+ sr_seed: Optional[Int32 | cute.Tensor] = None
61
 
62
+ # EpilogueParams auto-generated from _epi_ops + _extra_param_fields
 
 
 
 
 
 
 
 
 
 
63
 
64
+ def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None):
65
+ self.rounding_mode = args.rounding_mode
 
66
  self.postact_dtype = args.mPostAct.element_type
67
  self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
 
68
  self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
69
+ d = self._epi_ops_to_params_dict(args)
70
+ d["act_fn"] = args.act_fn
71
+ for key in ("mRowVecBroadcast", "mColVecBroadcast"):
72
+ if key in self.concat_layout and key in d and d[key] is not None:
73
+ d[key] = layout_utils.concat_to_interleave(d[key], 1)
74
+ return self.EpilogueParams(**d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ # epi_get_tma_atoms, epi_smem_bytes_per_stage, epi_get_smem_struct,
77
+ # epi_get_smem_tensors are all inherited from ComposableEpiMixin via _epi_ops.
 
 
78
 
79
+ def epi_setup_postact(
80
  self,
81
+ params,
82
+ epi_smem_tensors,
83
+ tiled_copy_r2s,
84
+ tiled_copy_t2r,
85
+ tile_coord_mnkl,
86
+ varlen_manager,
87
+ tidx,
88
+ ):
89
+ """Setup postact TMA copies and partitions before the epilogue loop."""
90
+ sPostAct = epi_smem_tensors[self._epi_smem_map["mPostAct"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  get_smem_store_op = (
92
  partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
93
  if self.arch == 100
 
96
  copy_atom_postact_r2s = get_smem_store_op(
97
  self.postact_layout, self.postact_dtype, self.acc_dtype
98
  )
 
 
99
  tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
100
  tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
 
101
  batch_idx = tile_coord_mnkl[3]
102
  copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
103
+ params.tma_atom_mPostAct,
104
+ varlen_manager.offset_batch_epi(params.mPostAct, batch_idx),
105
  self.cta_tile_shape_postact_mn,
106
+ params.epi_tile_mPostAct,
107
  sPostAct,
108
  tile_coord_mnkl,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  )
110
+ return tiled_copy_postact_r2s, tRS_sPostAct, copy_postact
111
 
112
+ @cute.jit
113
+ def epi_convert_postact(
114
+ self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
115
+ ):
116
+ """Convert postact from acc_dtype to postact_dtype. Override for custom postprocessing."""
117
+ if const_expr(
118
+ self.rounding_mode == RoundingMode.RS
119
+ and tRS_rPostAct.element_type == cutlass.Float32
120
+ and self.postact_dtype == cutlass.BFloat16
121
+ ):
122
+ from .rounding import convert_f32_to_bf16_sr
123
+ from cutlass.cute.tensor import TensorSSA
124
+
125
+ # Salt with 0x9E3779B1 to avoid sharing entropy with the D output seed
126
+ seed = (
127
+ sr_seed
128
+ + 0x9E3779B1
129
+ + (
130
+ tile_coord_mnkl[0] * 65537
131
+ + tile_coord_mnkl[1] * 257
132
+ + tile_coord_mnkl[3] * 17
133
+ + (num_prev_subtiles + epi_idx) * 7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  )
136
+ tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
137
+ src_vec = tRS_rPostAct.load()
138
+ raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx)
139
+ tRS_rPostAct_out.store(TensorSSA(raw_vec, src_vec.shape, self.postact_dtype))
140
+ else:
141
+ tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype)
142
+ tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
143
+ return tRS_rPostAct_out
 
 
 
 
 
 
 
 
 
 
144
 
145
  @cute.jit
146
  def epi_visit_subtile(
147
  self,
148
+ params,
149
  epi_loop_tensors: Tuple[cute.Tensor, ...],
150
  tRS_rD: cute.Tensor,
151
  tRS_rC: Optional[cute.Tensor] = None,
 
154
  # Apply activation function if provided
155
  # If we don't have .shape here, the compiler generates local stores and loads
156
  if const_expr(params.act_fn is not None):
157
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
158
  if const_expr(self.arch < 100):
159
  for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
160
  tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
 
165
  )
166
  else:
167
  tRS_rPostAct = tRS_rD
168
+ return tRS_rPostAct
 
 
 
169
 
170
 
171
  class GemmActSm90(GemmActMixin, GemmSm90):
 
176
  pass
177
 
178
 
179
+ class GemmActSm120(GemmActMixin, GemmSm120):
180
+ pass
181
+
182
+
183
+ def _gated_epi_tile_fn(gemm, epi_tile):
184
+ """Halve the N dimension of the epi_tile for gated postact."""
185
+ if isinstance(epi_tile[1], cute.Layout):
186
+ return (epi_tile[0], cute.recast_layout(2, 1, epi_tile[1]))
187
+ return (epi_tile[0], epi_tile[1] // 2)
188
+
189
+
190
+ class GemmGatedMixin(GemmActMixin):
191
+ _epi_ops = (
192
+ *GemmDefaultEpiMixin._epi_ops,
193
+ TileStore("mPostAct", epi_tile_fn=_gated_epi_tile_fn),
194
+ )
195
+
196
+ def epi_to_underlying_arguments(
197
+ self, args: GemmActMixin.EpilogueArguments, *, loc=None, ip=None
198
+ ) -> GemmActMixin.EpilogueParams:
199
+ assert args.mPostAct.element_type.width == 16, (
200
+ "GemmGated only supports 16bit postact for now"
201
+ )
202
+ assert self.d_layout is None or self.d_layout.is_n_major_c()
203
+ assert cutlass.utils.LayoutEnum.from_tensor(args.mPostAct).is_n_major_c()
204
+ if self.arch == 90:
205
+ assert self.cta_tile_shape_mnk[1] % 32 == 0, (
206
+ "GemmGatedSm90 requires tileN to be divisible by 32"
207
+ )
208
+ self.rounding_mode = args.rounding_mode
209
+ self.postact_dtype = args.mPostAct.element_type
210
+ self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
211
+ self.cta_tile_shape_postact_mn = (
212
+ self.cta_tile_shape_mnk[0],
213
+ self.cta_tile_shape_mnk[1] // 2,
214
+ )
215
+ d = self._epi_ops_to_params_dict(args)
216
+ d["act_fn"] = args.act_fn
217
+ for key in ("mRowVecBroadcast", "mColVecBroadcast"):
218
+ if key in self.concat_layout and key in d and d[key] is not None:
219
+ d[key] = layout_utils.concat_to_interleave(d[key], 1)
220
+ return self.EpilogueParams(**d)
221
+
222
+ @cute.jit
223
+ def epi_visit_subtile(
224
+ self,
225
+ params: GemmActMixin.EpilogueParams,
226
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
227
+ tRS_rD: cute.Tensor,
228
+ tRS_rC: Optional[cute.Tensor] = None,
229
+ ) -> Optional[cute.Tensor]:
230
+ GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
231
+ tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
232
+ # If we don't have .shape here, the compiler generates local stores and loads
233
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
234
+ if const_expr(self.arch < 100):
235
+ for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
236
+ tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
237
+ else:
238
+ for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
239
+ tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
240
+ (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3])
241
+ )
242
+ return tRS_rPostAct
243
+
244
+ @cute.jit
245
+ def epi_convert_postact(
246
+ self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
247
+ ):
248
+ tRS_rPostAct_out = GemmActMixin.epi_convert_postact(
249
+ self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
250
+ )
251
+ if const_expr(self.arch == 90):
252
+ # Only need this if we're using STSM
253
+ permute_gated_Cregs_b16(tRS_rPostAct_out)
254
+ return tRS_rPostAct_out
255
+
256
+
257
+ class GemmGatedSm90(GemmGatedMixin, GemmSm90):
258
+ pass
259
+
260
+
261
+ class GemmGatedSm100(GemmGatedMixin, GemmSm100):
262
+ pass
263
+
264
+
265
+ class GemmGatedSm120(GemmGatedMixin, GemmSm120):
266
+ pass
267
+
268
+
269
+ @jit_cache
270
+ def _compile_gemm_act(
271
+ a_dtype,
272
+ b_dtype,
273
+ d_dtype,
274
+ c_dtype,
275
+ postact_dtype,
276
+ a_major,
277
+ b_major,
278
+ d_major,
279
+ c_major,
280
+ postact_major,
281
+ tile_shape_mn,
282
+ cluster_shape_mnk,
283
+ pingpong,
284
+ persistent,
285
+ is_dynamic_persistent,
286
+ activation,
287
+ rowvec_dtype,
288
+ colvec_dtype,
289
+ colvec_ndim,
290
+ varlen_m,
291
+ gather_A,
292
+ concat_layout,
293
+ device_capacity,
294
+ gemm_cls_name,
295
+ rounding_mode=RoundingMode.RN,
296
+ sr_seed_mode=0,
297
+ use_tma_gather=False,
298
+ ):
299
+ sm_to_cls = {
300
+ "act": {9: GemmActSm90, 10: GemmActSm100, 11: GemmActSm100, 12: GemmActSm120},
301
+ "gated": {9: GemmGatedSm90, 10: GemmGatedSm100, 11: GemmGatedSm100, 12: GemmGatedSm120},
302
+ }
303
+ if device_capacity[0] == 12 and gemm_cls_name == "act":
304
+ raise NotImplementedError("SM120 non-gated activation GEMM epilogue is not yet supported")
305
+ GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
306
+ pa_leading = 1 if postact_major == "n" else 0
307
+ mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
308
+ a_dtype,
309
+ b_dtype,
310
+ d_dtype,
311
+ c_dtype,
312
+ a_major,
313
+ b_major,
314
+ d_major,
315
+ c_major,
316
+ varlen_m=varlen_m,
317
+ gather_A=gather_A,
318
+ )
319
+ pa_n = cute.sym_int() if gemm_cls_name == "gated" else n
320
+ div_pa = div_for_dtype(postact_dtype)
321
+ pa_leading_dim = 1 if gemm_cls_name == "gated" else pa_leading
322
+ pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l)
323
+ mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa)
324
+
325
+ mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
326
+ if colvec_ndim == 2:
327
+ mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
328
+ elif colvec_ndim == 1:
329
+ mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
330
+ else:
331
+ mColVec = None
332
+
333
+ act_fn = act_fn_map[activation] if gemm_cls_name == "act" else gate_fn_map[activation]
334
+
335
+ def fake_scalar(mode, dtype=Int32):
336
+ if mode == 0:
337
+ return None
338
+ elif mode == 1:
339
+ return dtype(0)
340
+ else:
341
+ return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
342
+
343
+ epi_args = GemmCls.EpilogueArguments(
344
+ mPostAct,
345
+ act_fn,
346
+ mRowVecBroadcast=mRowVec,
347
+ mColVecBroadcast=mColVec,
348
+ rounding_mode=rounding_mode,
349
+ sr_seed=fake_scalar(sr_seed_mode),
350
+ )
351
+ scheduler_args = make_fake_scheduler_args(
352
+ (is_dynamic_persistent and device_capacity[0] == 9), False, l
353
+ )
354
+ varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
355
+ return compile_gemm_kernel(
356
+ GemmCls,
357
+ a_dtype,
358
+ tile_shape_mn,
359
+ cluster_shape_mnk,
360
+ pingpong,
361
+ persistent,
362
+ gather_A,
363
+ is_dynamic_persistent,
364
+ device_capacity,
365
+ mA,
366
+ mB,
367
+ mD,
368
+ mC,
369
+ epi_args,
370
+ scheduler_args,
371
+ varlen_args,
372
+ use_tma_gather=use_tma_gather,
373
+ concat_layout=concat_layout or None,
374
+ )
375
 
376
 
377
  def gemm_act(
 
379
  B: Tensor, # (l, n, k)
380
  D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
381
  C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
382
+ PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated
383
  tile_count_semaphore: Optional[Tensor], # (1,)
384
  activation: Optional[str],
385
  tile_M: int,
 
388
  cluster_N: int,
389
  pingpong: bool = False,
390
  persistent: bool = True,
391
+ is_dynamic_persistent: bool = False,
392
  max_swizzle_size: int = 8,
393
  rowvec_bias: Optional[Tensor] = None, # (l, n)
394
  colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
395
  cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
396
  A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
397
+ rounding_mode: int = RoundingMode.RN,
398
+ sr_seed: int | Tensor = 0,
399
+ use_tma_gather: bool = False,
400
+ concat_layout: tuple | None = None,
401
  ) -> None:
402
+ if activation in gate_fn_map:
403
+ gemm_cls_name = "gated"
404
+ else:
405
+ assert activation in act_fn_map, f"Unsupported activation {activation}"
406
+ gemm_cls_name = "act"
407
+
408
+ varlen_m = cu_seqlens_m is not None
409
+ gather_A = A_idx is not None
410
+ if varlen_m:
411
  assert persistent, "varlen_m requires persistent=True"
412
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
413
  if D is not None:
414
  assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
415
  assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
 
416
  if gather_A:
417
+ assert cu_seqlens_m is not None, "gather_A requires varlen"
418
  assert cluster_N == 1, "gather_A requires cluster_N=1"
 
419
 
420
+ A_p = perm3d_single(A, varlen_m)
421
+ B_p = perm3d_single(B)
422
+ D_p = perm3d_single(D, varlen_m)
423
+ C_p = perm3d_single(C, varlen_m)
424
+ PostAct_p = perm3d_single(PostAct, varlen_m)
425
+
426
+ a_major = get_major(A_p, "m", "k")
427
+ b_major = get_major(B_p, "n", "k")
428
+ d_major = get_major(D_p, "m", "n") if D_p is not None else None
429
+ c_major = get_major(C_p, "m", "n") if C_p is not None else None
430
+ postact_major = get_major(PostAct_p, "m", "n")
431
+
432
+ a_dtype = torch2cute_dtype_map[A.dtype]
433
+ b_dtype = torch2cute_dtype_map[B.dtype]
434
+ d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None
435
+ c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
436
+ postact_dtype = torch2cute_dtype_map[PostAct.dtype]
437
+ colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0
438
 
439
  device_capacity = get_device_capacity(A.device)
440
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
441
+ if rounding_mode == RoundingMode.RS:
442
+ assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
+ if is_dynamic_persistent and device_capacity[0] == 9:
445
+ assert tile_count_semaphore is not None, (
446
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
 
 
 
 
 
 
 
 
 
 
447
  )
 
 
 
 
 
 
448
 
449
+ sr_seed_mode = (
450
+ 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
 
 
 
 
 
 
 
 
451
  )
452
+ concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()
453
+ compiled_fn = _compile_gemm_act(
454
+ a_dtype,
455
+ b_dtype,
456
+ d_dtype,
457
+ c_dtype,
458
+ postact_dtype,
459
+ a_major,
460
+ b_major,
461
+ d_major,
462
+ c_major,
463
+ postact_major,
464
+ (tile_M, tile_N),
465
+ (cluster_M, cluster_N, 1),
466
  pingpong,
467
  persistent,
468
+ is_dynamic_persistent,
469
+ activation,
470
+ torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None,
471
+ torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None,
472
+ colvec_ndim,
473
+ varlen_m,
474
+ gather_A,
475
+ concat_layout,
476
  device_capacity,
477
+ gemm_cls_name,
478
+ rounding_mode=rounding_mode,
479
+ sr_seed_mode=sr_seed_mode,
480
+ use_tma_gather=use_tma_gather,
 
 
481
  )
482
+
483
+ from .cache_utils import COMPILE_ONLY
484
+
485
+ if COMPILE_ONLY:
486
+ return
487
+
488
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
489
+
490
+ def scalar_arg(scalar, mode, dtype=Int32):
491
+ if mode == 0:
492
+ return None
493
+ elif mode == 1:
494
+ return dtype(scalar)
495
+ else:
496
+ return scalar.data_ptr()
497
+
498
+ epi_args = GemmActMixin.EpilogueArguments(
499
+ PostAct_p,
500
+ None, # act_fn is Constexpr, pass None at call time
501
+ mRowVecBroadcast=rowvec_bias,
502
+ mColVecBroadcast=colvec_bias,
503
+ rounding_mode=None, # Constexpr, pass None at call time
504
+ sr_seed=scalar_arg(sr_seed, sr_seed_mode),
 
 
 
 
 
 
 
 
505
  )
506
+ scheduler_args = make_scheduler_args(
507
+ max_active_clusters,
508
+ max_swizzle_size,
509
+ tile_count_semaphore,
510
+ )
511
+ varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
512
+
513
+ if device_capacity[0] in [10, 11]:
514
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
515
+ else:
516
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
517
 
518
 
519
+ gemm_gated = gemm_act
build/torch-cuda/quack/gemm_blockscaled_interface.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026, Tri Dao.
2
+ """PyTorch-friendly interface for the SM100 MXFP8 blockscaled GEMM.
3
+
4
+ Shape / layout conventions (matches torch.matmul, torch._scaled_mm, cuBLAS):
5
+ A: (M, K) or (L, M, K) dtype float8_e4m3fn, K-contiguous (row-major)
6
+ B: (K, N) or (L, K, N) dtype float8_e4m3fn, K-contiguous (col-major)
7
+ A_scale: (M, K/32) or (L, M, K/32) dtype float8_e8m0fnu, K-contiguous
8
+ B_scale: (K/32, N) or (L, K/32, N) dtype float8_e8m0fnu, K-contiguous
9
+ out: (M, N) or (L, M, N) dtype bfloat16/float16, contiguous
10
+
11
+ "K-contiguous" means stride 1 on the K axis. This matches how torchao/cuBLAS
12
+ use `torch._scaled_mm(a, b.t(), ...)`:
13
+ - you store a weight as nn.Linear-style `W` of shape `(N, K)` row-major
14
+ - you pass `W.mT` (a zero-copy view of shape (K, N) with K-contig) as B
15
+ The interface applies `.mT` internally to reach the `(N, K) K-major` layout
16
+ the quack kernel consumes. No data is copied.
17
+ """
18
+
19
+ from functools import lru_cache
20
+ from typing import Optional, Tuple
21
+
22
+ import torch
23
+ from torch import Tensor
24
+
25
+ import cutlass
26
+
27
+ from .blockscaled_gemm_utils import (
28
+ ceil_div,
29
+ compile_blockscaled_gemm_tvm_ffi,
30
+ pack_scale_2d_to_blocked_contig,
31
+ scale_blocked_for_cublas,
32
+ scale_view_for_kernel,
33
+ )
34
+ from .gemm_default_epi import GemmDefaultSm100
35
+ from .mx_utils import to_mx
36
+
37
+ _SF_VEC_SIZE = 32
38
+ _TORCH_TO_CUTLASS_D = {
39
+ torch.bfloat16: cutlass.BFloat16,
40
+ torch.float16: cutlass.Float16,
41
+ torch.float32: cutlass.Float32,
42
+ }
43
+
44
+
45
+ def _default_tiler_cluster(m: int, n: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
46
+ """Pick a reasonable default (mma_tiler_mn, cluster_shape_mn)."""
47
+ if m >= 512 and n >= 128:
48
+ return (256, 128), (2, 1)
49
+ return (128, 128), (1, 1)
50
+
51
+
52
+ @lru_cache(maxsize=64)
53
+ def _compile_cached(
54
+ m: int,
55
+ n: int,
56
+ k: int,
57
+ l: int,
58
+ mma_tiler_mn: Tuple[int, int],
59
+ cluster_shape_mn: Tuple[int, int],
60
+ out_torch_dtype,
61
+ ab_dtype_cutlass,
62
+ sf_dtype_cutlass,
63
+ ):
64
+ """Compile kernel for a given (shape, dtype, tiler, cluster) and cache it."""
65
+ dev = torch.device("cuda")
66
+ rm = ceil_div(m, 128)
67
+ rn = ceil_div(n, 128)
68
+ rk = ceil_div(k // _SF_VEC_SIZE, 4)
69
+ # K-major: (l, m, k) contiguous, viewed as (m, k, l) strides (k, 1, m*k)
70
+ fake_mA = torch.empty(l, m, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0)
71
+ fake_mB = torch.empty(l, n, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0)
72
+ # N-major: (l, m, n) contiguous, viewed as (m, n, l) strides (n, 1, m*n)
73
+ fake_mD = torch.empty(l, m, n, dtype=out_torch_dtype, device=dev).permute(1, 2, 0)
74
+ fake_sc_A = torch.empty(l, rm, rk, 512, dtype=torch.float8_e8m0fnu, device=dev)
75
+ fake_sc_B = torch.empty(l, rn, rk, 512, dtype=torch.float8_e8m0fnu, device=dev)
76
+ fake_mSFA = scale_view_for_kernel(fake_sc_A, m, k // _SF_VEC_SIZE, l)
77
+ fake_mSFB = scale_view_for_kernel(fake_sc_B, n, k // _SF_VEC_SIZE, l)
78
+ return compile_blockscaled_gemm_tvm_ffi(
79
+ ab_dtype_cutlass,
80
+ sf_dtype_cutlass,
81
+ _SF_VEC_SIZE,
82
+ _TORCH_TO_CUTLASS_D[out_torch_dtype],
83
+ mma_tiler_mn,
84
+ cluster_shape_mn,
85
+ fake_mA,
86
+ fake_mB,
87
+ fake_mD,
88
+ fake_mSFA,
89
+ fake_mSFB,
90
+ )
91
+
92
+
93
+ def _as_3d(x: Tensor, ndim_in: int) -> Tensor:
94
+ """Add a leading batch dim if input is 2D. Returns a view."""
95
+ if ndim_in == 2:
96
+ return x.unsqueeze(0)
97
+ return x
98
+
99
+
100
+ def _to_kernel_layout(
101
+ A: Tensor,
102
+ B: Tensor,
103
+ A_scale: Tensor,
104
+ B_scale: Tensor,
105
+ ) -> Tuple[int, int, int, int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool]:
106
+ """Normalize shapes/strides, validate, and repack scales. Returns
107
+ (m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d).
108
+
109
+ A: (M,K) or (L,M,K) K-contig. B: (K,N) or (L,K,N) K-contig.
110
+ A_scale: (M,K/32) or (L,M,K/32) K-contig. B_scale: (K/32,N) or (L,K/32,N) K-contig.
111
+ """
112
+ assert A.dtype == torch.float8_e4m3fn, f"A dtype must be float8_e4m3fn, got {A.dtype}"
113
+ assert B.dtype == torch.float8_e4m3fn, f"B dtype must be float8_e4m3fn, got {B.dtype}"
114
+ assert A_scale.dtype == torch.float8_e8m0fnu
115
+ assert B_scale.dtype == torch.float8_e8m0fnu
116
+ was_2d = A.dim() == 2
117
+ # Flip B from (K,N) to (N,K) via .mT (zero-copy). User's B K-contig → .mT K-contig.
118
+ A3 = _as_3d(A, A.dim()) # (l, m, k) K-contig row-major expected
119
+ B3 = _as_3d(B, B.dim()).mT # (l, n, k) K-contig (view) from (l, k, n)
120
+ l, m, k = A3.shape
121
+ l2, n, k2 = B3.shape
122
+ assert l == l2, f"batch mismatch: A={l}, B={l2}"
123
+ assert k == k2, f"K mismatch: A K={k}, B K={k2}"
124
+ assert k % _SF_VEC_SIZE == 0, f"K ({k}) must be divisible by {_SF_VEC_SIZE}"
125
+ assert A3.stride(-1) == 1, "A must be K-contiguous (stride 1 on K)"
126
+ assert B3.stride(-1) == 1, (
127
+ "B must be K-contiguous on its K axis (pass .mT of an (N,K) row-major tensor)"
128
+ )
129
+ sf_k = k // _SF_VEC_SIZE
130
+ as3 = _as_3d(A_scale, A_scale.dim()) # expected (l, m, sf_k) K-contig row-major
131
+ bs3 = _as_3d(B_scale, B_scale.dim()).mT # (l, n, sf_k) K-contig (view) from (l, sf_k, n)
132
+ assert as3.stride(-1) == 1, "A_scale must be K-contiguous"
133
+ assert bs3.stride(-1) == 1, (
134
+ "B_scale must be K-contiguous on its K axis (pass .mT of an (N, K/32) row-major tensor)"
135
+ )
136
+ assert as3.shape == (l, m, sf_k), (
137
+ f"A_scale shape: expected (l={l},m={m},sf_k={sf_k}) K-contig, got {tuple(as3.shape)}"
138
+ )
139
+ assert bs3.shape == (l, n, sf_k), (
140
+ f"B_scale shape: expected .mT of (l={l},sf_k={sf_k},n={n}) -> ({l},{n},{sf_k}), got {tuple(bs3.shape)}"
141
+ )
142
+ # Force row-major contiguous for packer/kernel consumption.
143
+ # A3 / B3 are views — .contiguous() materializes (l,m,k) / (l,n,k) row-major.
144
+ A3_c = A3.contiguous()
145
+ B3_c = B3.contiguous()
146
+ # (l, m, k) -> (m, k, l) K-major view (no copy; strides (k, 1, m*k))
147
+ mA_mkl = A3_c.permute(1, 2, 0)
148
+ mB_nkl = B3_c.permute(1, 2, 0)
149
+ sc_contig_A = pack_scale_2d_to_blocked_contig(as3.contiguous())
150
+ sc_contig_B = pack_scale_2d_to_blocked_contig(bs3.contiguous())
151
+ sfa_view = scale_view_for_kernel(sc_contig_A, m, sf_k, l)
152
+ sfb_view = scale_view_for_kernel(sc_contig_B, n, sf_k, l)
153
+ return m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d
154
+
155
+
156
+ def mxfp8_gemm_out(
157
+ A: Tensor,
158
+ B: Tensor,
159
+ A_scale: Tensor,
160
+ B_scale: Tensor,
161
+ out: Tensor,
162
+ *,
163
+ mma_tiler_mn: Optional[Tuple[int, int]] = None,
164
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
165
+ ) -> None:
166
+ """MXFP8 blockscaled GEMM with pre-allocated output. See module doc for shape conventions."""
167
+ m, n, k, l, mA, mB, _scA, _scB, sfa, sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale)
168
+ out_dtype = out.dtype
169
+ assert out_dtype in _TORCH_TO_CUTLASS_D, f"unsupported out dtype: {out_dtype}"
170
+ expected_out_shape = (m, n) if was_2d else (l, m, n)
171
+ assert tuple(out.shape) == expected_out_shape, (
172
+ f"out shape {tuple(out.shape)} != expected {expected_out_shape}"
173
+ )
174
+ assert out.is_contiguous(), "out must be contiguous"
175
+ # View caller's contiguous (M,N) or (L,M,N) as (M,N,L) N-major strided view, no copy.
176
+ out_3d = out.unsqueeze(0) if was_2d else out # (l, m, n)
177
+ mD = out_3d.permute(1, 2, 0) # (m, n, l), strides (n, 1, m*n)
178
+ if mma_tiler_mn is None or cluster_shape_mn is None:
179
+ tlr, clu = _default_tiler_cluster(m, n)
180
+ mma_tiler_mn = mma_tiler_mn or tlr
181
+ cluster_shape_mn = cluster_shape_mn or clu
182
+ if not GemmDefaultSm100.can_implement_blockscaled(
183
+ cutlass.Float8E4M3FN,
184
+ cutlass.Float8E8M0FNU,
185
+ _SF_VEC_SIZE,
186
+ _TORCH_TO_CUTLASS_D[out_dtype],
187
+ mma_tiler_mn,
188
+ cluster_shape_mn,
189
+ m,
190
+ n,
191
+ k,
192
+ l,
193
+ "k",
194
+ "k",
195
+ "n",
196
+ ):
197
+ raise ValueError(
198
+ f"unsupported config: m={m}, n={n}, k={k}, l={l}, "
199
+ f"tiler={mma_tiler_mn}, cluster={cluster_shape_mn}"
200
+ )
201
+ runner = _compile_cached(
202
+ m,
203
+ n,
204
+ k,
205
+ l,
206
+ mma_tiler_mn,
207
+ cluster_shape_mn,
208
+ out_dtype,
209
+ cutlass.Float8E4M3FN,
210
+ cutlass.Float8E8M0FNU,
211
+ )
212
+ runner(mA, mB, mD, sfa, sfb)
213
+
214
+
215
+ def mxfp8_gemm(
216
+ A: Tensor,
217
+ B: Tensor,
218
+ A_scale: Tensor,
219
+ B_scale: Tensor,
220
+ out: Optional[Tensor] = None,
221
+ out_dtype: torch.dtype = torch.bfloat16,
222
+ *,
223
+ mma_tiler_mn: Optional[Tuple[int, int]] = None,
224
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
225
+ ) -> Tensor:
226
+ """MXFP8 blockscaled GEMM. Allocates output if not provided."""
227
+ if out is None:
228
+ # A: (M,K) or (L,M,K); B: (K,N) or (L,K,N); out: (M,N) or (L,M,N)
229
+ if A.dim() == 2:
230
+ out_shape = (A.shape[0], B.shape[1])
231
+ else:
232
+ out_shape = (A.shape[0], A.shape[1], B.shape[2])
233
+ out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
234
+ mxfp8_gemm_out(
235
+ A,
236
+ B,
237
+ A_scale,
238
+ B_scale,
239
+ out,
240
+ mma_tiler_mn=mma_tiler_mn,
241
+ cluster_shape_mn=cluster_shape_mn,
242
+ )
243
+ return out
244
+
245
+
246
+ def mxfp8_quantize(x: Tensor) -> Tuple[Tensor, Tensor]:
247
+ """Quantize a (..., K) bf16/fp32 tensor to MXFP8. Returns (qdata, scale_2d)
248
+ in torchao-convention layout. Last dim (K) must be divisible by 32."""
249
+ assert x.shape[-1] % _SF_VEC_SIZE == 0, (
250
+ f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE}"
251
+ )
252
+ return to_mx(x.contiguous(), _SF_VEC_SIZE)
253
+
254
+
255
+ def mxfp8_gemm_quantize(
256
+ A: Tensor,
257
+ B: Tensor,
258
+ out: Optional[Tensor] = None,
259
+ out_dtype: torch.dtype = torch.bfloat16,
260
+ *,
261
+ mma_tiler_mn: Optional[Tuple[int, int]] = None,
262
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
263
+ ) -> Tensor:
264
+ """High-level: quantize bf16 A, B_as_NK to MXFP8, then run C = A @ B_as_NK.mT.
265
+ Inputs: A=(M,K)/(L,M,K), B_as_NK=(N,K)/(L,N,K) bf16/fp32. Quantization
266
+ scales along the last (K) dim. Returned output has shape (M,N)/(L,M,N)."""
267
+ A_q, A_sc = mxfp8_quantize(A)
268
+ B_q, B_sc = mxfp8_quantize(B)
269
+ # B_q, B_sc are (..., N, K) / (..., N, K/32). Flip to (..., K, N) / (..., K/32, N)
270
+ # K-contig zero-copy views to match the interface convention.
271
+ return mxfp8_gemm(
272
+ A_q,
273
+ B_q.mT,
274
+ A_sc,
275
+ B_sc.mT,
276
+ out=out,
277
+ out_dtype=out_dtype,
278
+ mma_tiler_mn=mma_tiler_mn,
279
+ cluster_shape_mn=cluster_shape_mn,
280
+ )
281
+
282
+
283
+ def mxfp8_gemm_cublas(
284
+ A: Tensor,
285
+ B: Tensor,
286
+ A_scale: Tensor,
287
+ B_scale: Tensor,
288
+ out_dtype: torch.dtype = torch.bfloat16,
289
+ ) -> Tensor:
290
+ """Reference path via torch._scaled_mm. Requires l=1 (or 2D inputs)."""
291
+ m, n, k, l, _mA, _mB, sc_A, sc_B, _sfa, _sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale)
292
+ assert l == 1, "torch._scaled_mm MXFP8 path is 2D only; pass 2D inputs or l=1"
293
+ # torch._scaled_mm: A=(M,K) row-major, B=(K,N) col-major (both K-contig) -- same layout user gave us.
294
+ a2d = A if A.dim() == 2 else A.squeeze(0)
295
+ b2d = B if B.dim() == 2 else B.squeeze(0)
296
+ sca = scale_blocked_for_cublas(sc_A, m, k // _SF_VEC_SIZE, 0)
297
+ scb = scale_blocked_for_cublas(sc_B, n, k // _SF_VEC_SIZE, 0)
298
+ out = torch._scaled_mm(
299
+ a2d,
300
+ b2d,
301
+ scale_a=sca,
302
+ scale_b=scb,
303
+ out_dtype=out_dtype,
304
+ )
305
+ return out if was_2d else out.unsqueeze(0)
306
+
307
+
308
+ def mxfp8_gemm_ref(
309
+ A: Tensor,
310
+ B: Tensor,
311
+ A_scale: Tensor,
312
+ B_scale: Tensor,
313
+ out_dtype: torch.dtype = torch.bfloat16,
314
+ ) -> Tensor:
315
+ """Dequantize + plain matmul reference. A=(M,K), B=(K,N)."""
316
+ was_2d = A.dim() == 2
317
+ # (l, m, k)
318
+ A3 = _as_3d(A, A.dim()).float()
319
+ # B is (K, N)/(L, K, N); flip to (l, n, k) for dequant by last-dim
320
+ B3 = _as_3d(B, B.dim()).mT.contiguous().float()
321
+ as3 = _as_3d(A_scale, A_scale.dim()).float()
322
+ bs3 = _as_3d(B_scale, B_scale.dim()).mT.contiguous().float()
323
+ a_dq = A3 * as3.repeat_interleave(_SF_VEC_SIZE, dim=-1)
324
+ b_dq = B3 * bs3.repeat_interleave(_SF_VEC_SIZE, dim=-1)
325
+ out3 = torch.einsum("lmk,lnk->lmn", a_dq, b_dq).to(out_dtype)
326
+ return out3.squeeze(0) if was_2d else out3
build/torch-cuda/quack/gemm_config.py CHANGED
@@ -1,6 +1,6 @@
1
  # Copyright (C) 2025, Fri Dao.
2
  import itertools
3
- from typing import Optional, List, Literal
4
  from functools import partial
5
  from dataclasses import dataclass
6
 
@@ -10,86 +10,145 @@ class GemmConfig:
10
  tile_m: int = 128
11
  tile_n: int = 192
12
  pingpong: bool = True
 
 
13
  cluster_m: int = 2
14
  cluster_n: int = 1
15
  swap_ab: bool = False
16
  # raster_order: int = 1
17
  max_swizzle_size: int = 8
 
 
 
18
 
19
 
20
- def get_all_configs(
21
- device_capacity: Literal[9, 10] = 9,
22
  epilogue: Optional[str] = None,
23
  tune_coop: bool = True,
24
- # tune_raster_order=True,
25
  ) -> List[GemmConfig]:
26
- assert device_capacity in [9, 10]
27
- if device_capacity == 9:
28
- tile_n_vals = [128, 144, 160, 176, 192, 208]
29
- tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
30
- (128, 224),
31
- (128, 256),
32
- # (192, 256), # Getting IOT instruction (core dumped) in the bwd
33
- ]
34
- tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
35
- if epilogue in ["gated"]:
36
- tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
37
- tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
38
- elif epilogue in ["lse"]:
39
- tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
40
- tile_mn_vals = []
41
- if tune_coop:
42
- tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
43
- tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
 
44
  cluster = [(1, 2), (2, 1)]
45
- # cluster = [(1, 1), (1, 2), (2, 1)]
46
- if epilogue in ["lse"]:
47
- cluster = [(1, 2), (2, 1)]
48
- swap_ab_vals = [False, True]
49
- if epilogue in ["lse", "gated"]:
50
- swap_ab_vals = [False]
51
- # raster_swizzle = (
52
- # [(0, 1)]
53
- # if not tune_raster_order
54
- # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
55
- # )
56
- return [
57
- GemmConfig(
58
- tile_m=tile_m,
59
- tile_n=tile_n,
60
- pingpong=pingpong,
61
- cluster_m=cluster_m,
62
- cluster_n=cluster_n,
63
- swap_ab=swap_ab,
64
- # raster_order=raster_order,
65
- # max_swizzle_size=max_swizzle_size,
66
- )
67
- for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
68
- tile_mn_vals,
69
- cluster,
70
- swap_ab_vals,
71
- # raster_swizzle,
72
- )
73
- ]
74
- elif device_capacity == 10:
75
- tile_n_vals = [128, 160, 192, 224, 256]
76
- tile_n_64_vals = [128, 192, 256]
77
- tile_mn_cluster_vals = (
78
- [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
79
- # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals]
80
- + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
81
- + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  )
83
- swap_ab_vals = [False, True]
84
- if epilogue in ["lse", "gated"]:
85
- swap_ab_vals = [False]
86
- max_swizzle_size_vals = [4, 8, 16]
87
- GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100
88
- return [
89
- GemmConfigCls(
90
- tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
91
- )
92
- for (m, n, (cm, cn)), sab, ms in itertools.product(
93
- tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
94
- )
95
- ]
 
 
 
 
 
 
 
1
  # Copyright (C) 2025, Fri Dao.
2
  import itertools
3
+ from typing import Optional, List
4
  from functools import partial
5
  from dataclasses import dataclass
6
 
 
10
  tile_m: int = 128
11
  tile_n: int = 192
12
  pingpong: bool = True
13
+ # by default, we use dynamic persistent tile scheduler on SM100 but not on SM90
14
+ is_dynamic_persistent: bool = True
15
  cluster_m: int = 2
16
  cluster_n: int = 1
17
  swap_ab: bool = False
18
  # raster_order: int = 1
19
  max_swizzle_size: int = 8
20
+ device_capacity: int = 9
21
+ # whether to use TMA gather (vs normal cp.async) for gather_A on SM100
22
+ use_tma_gather: bool = False
23
 
24
 
25
+ def _get_sm90_configs(
 
26
  epilogue: Optional[str] = None,
27
  tune_coop: bool = True,
 
28
  ) -> List[GemmConfig]:
29
+ tile_n_vals = [128, 160, 192, 208]
30
+ tile_mn_vals_coop = [(256, tile_n) for tile_n in tile_n_vals] + [
31
+ (128, 224),
32
+ (128, 256),
33
+ # (192, 256), # Getting IOT instruction (core dumped) in the bwd
34
+ ]
35
+ tile_mn_vals_pingpong = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
36
+ if epilogue in ["gated"]:
37
+ tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if n % 32 == 0 and m != 192]
38
+ tile_mn_vals_pingpong = [(m, n) for m, n in tile_mn_vals_pingpong if n % 32 == 0]
39
+ elif epilogue in ["lse"]:
40
+ tile_mn_vals_coop = [(m, n) for m, n in tile_mn_vals_coop if m != 192]
41
+ tile_mn_vals = []
42
+ if tune_coop:
43
+ tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop]
44
+ tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong]
45
+ cluster = [(1, 2), (2, 1)]
46
+ # cluster = [(1, 1), (1, 2), (2, 1)]
47
+ if epilogue in ["lse"]:
48
  cluster = [(1, 2), (2, 1)]
49
+ swap_ab_vals = [False, True]
50
+ if epilogue in ["lse", "gated"]:
51
+ swap_ab_vals = [False]
52
+
53
+ return [
54
+ GemmConfig(
55
+ tile_m=tile_m,
56
+ tile_n=tile_n,
57
+ pingpong=pingpong,
58
+ cluster_m=cluster_m,
59
+ cluster_n=cluster_n,
60
+ swap_ab=swap_ab,
61
+ device_capacity=9,
62
+ is_dynamic_persistent=False, # default to not use dynamic persistent on SM90
63
+ use_tma_gather=False, # TMA gather not supported on SM90
64
+ )
65
+ for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
66
+ tile_mn_vals,
67
+ cluster,
68
+ swap_ab_vals,
69
+ )
70
+ ]
71
+
72
+
73
+ def _get_sm100_configs(
74
+ epilogue: Optional[str] = None,
75
+ ) -> List[GemmConfig]:
76
+ tile_n_vals = [64, 128, 160, 192, 224, 256]
77
+ tile_mn_cluster_vals = (
78
+ [(128, tile_n, (1, 1)) for tile_n in tile_n_vals]
79
+ + [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
80
+ + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
81
+ + [(128, tile_n, (2, 2)) for tile_n in tile_n_vals]
82
+ + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
83
+ + [(256, tile_n, (2, 2)) for tile_n in tile_n_vals]
84
+ + [(256, 512, (2, 1))]
85
+ )
86
+ swap_ab_vals = [False, True]
87
+ if epilogue in ["lse", "gated"]:
88
+ swap_ab_vals = [False]
89
+ GemmConfigCls = partial(
90
+ GemmConfig, pingpong=False, device_capacity=10
91
+ ) # There's no pingpong on Sm100
92
+ use_clc_vals = [True, False]
93
+ use_tma_gather_vals = [True, False]
94
+ return [
95
+ GemmConfigCls(
96
+ tile_m=m,
97
+ tile_n=n,
98
+ cluster_m=cm,
99
+ cluster_n=cn,
100
+ swap_ab=sab,
101
+ max_swizzle_size=8,
102
+ is_dynamic_persistent=use_clc,
103
+ use_tma_gather=use_tma_gather,
104
+ )
105
+ for (m, n, (cm, cn)), sab, use_clc, use_tma_gather in itertools.product(
106
+ tile_mn_cluster_vals, swap_ab_vals, use_clc_vals, use_tma_gather_vals
107
+ )
108
+ ]
109
+
110
+
111
+ def _get_sm120_configs(
112
+ epilogue: Optional[str] = None,
113
+ tune_coop: bool = True,
114
+ ) -> List[GemmConfig]:
115
+ tile_mn_vals_coop = [(128, 128), (128, 64), (64, 128), (128, 160), (128, 192)]
116
+ tile_mn_vals_pingpong = [(128, 128), (128, 64), (64, 128), (128, 160)]
117
+ tile_mn_vals = []
118
+ if tune_coop:
119
+ tile_mn_vals += [(m, n, False) for m, n in tile_mn_vals_coop]
120
+ tile_mn_vals += [(m, n, True) for m, n in tile_mn_vals_pingpong]
121
+ swap_ab_vals = [False, True]
122
+ if epilogue in ["lse", "gated"]:
123
+ swap_ab_vals = [False]
124
+ return [
125
+ GemmConfig(
126
+ tile_m=tile_m,
127
+ tile_n=tile_n,
128
+ pingpong=pingpong,
129
+ cluster_m=1,
130
+ cluster_n=1,
131
+ swap_ab=swap_ab,
132
+ device_capacity=12,
133
+ is_dynamic_persistent=True,
134
+ use_tma_gather=False, # TMA gather not supported on SM120
135
  )
136
+ for (tile_m, tile_n, pingpong), swap_ab in itertools.product(tile_mn_vals, swap_ab_vals)
137
+ ]
138
+
139
+
140
+ def get_all_configs(
141
+ epilogue: Optional[str] = None,
142
+ tune_coop: bool = True,
143
+ ) -> List[GemmConfig]:
144
+ """Return autotuning configs for all supported device capabilities (sm90 + sm100 + sm120).
145
+
146
+ Each GemmConfig is tagged with its target device_capacity, so the caller can
147
+ filter at runtime based on the actual device. This avoids querying the device
148
+ (and initializing a CUDA context) at import time.
149
+ """
150
+ return (
151
+ _get_sm90_configs(epilogue, tune_coop)
152
+ + _get_sm100_configs(epilogue)
153
+ + _get_sm120_configs(epilogue, tune_coop)
154
+ )
build/torch-cuda/quack/gemm_dact.py CHANGED
@@ -1,33 +1,53 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- from typing import Optional, Tuple
3
- from functools import partial
4
 
 
5
  from torch import Tensor
6
 
7
  import cutlass
8
  import cutlass.cute as cute
9
- from cutlass import Float32, const_expr
10
- import cutlass.torch as cutlass_torch
11
-
12
  from .gemm_sm90 import GemmSm90
13
  from .gemm_sm100 import GemmSm100
 
14
  from .gemm_default_epi import GemmDefaultEpiMixin
15
  from .gemm_act import GemmActMixin
16
- from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
17
- from .gemm_wrapper_utils import GemmWrapperBase
18
- from . import activation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  class GemmDActMixin(GemmActMixin):
22
  # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
23
  # and return 2 arguments (dx, out)
24
  EpilogueArguments = GemmActMixin.EpilogueArguments
25
- EpilogueParams = GemmActMixin.EpilogueParams
26
 
27
  @cute.jit
28
  def epi_visit_subtile(
29
  self,
30
- params: EpilogueParams,
31
  epi_loop_tensors: Tuple[cute.Tensor, ...],
32
  tRS_rD: cute.Tensor,
33
  tRS_rC: Optional[cute.Tensor] = None,
@@ -35,11 +55,11 @@ class GemmDActMixin(GemmActMixin):
35
  assert tRS_rC is not None
36
  # We don't add C to the accumulator
37
  GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
38
- tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
39
  tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
40
  # If we don't have .shape here, the compiler generates local stores and loads
41
  if const_expr(params.act_fn is not None):
42
- tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
43
  if const_expr(self.arch < 100):
44
  for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
45
  tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
@@ -54,10 +74,7 @@ class GemmDActMixin(GemmActMixin):
54
  )
55
  else:
56
  tRS_rPostAct = tRS_rC_acc
57
- # Type conversion
58
- tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
59
- tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
60
- return tRS_rPostAct_out
61
 
62
 
63
  class GemmDActSm90(GemmDActMixin, GemmSm90):
@@ -68,19 +85,283 @@ class GemmDActSm100(GemmDActMixin, GemmSm100):
68
  pass
69
 
70
 
71
- dact_fn_map = {
72
- None: None,
73
- "relu": activation.drelu,
74
- "relu_sq": activation.drelu_sq,
75
- "gelu_tanh_approx": activation.dgelu_tanh_approx,
76
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  def gemm_dact(
80
  A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
81
  B: Tensor, # (l, n, k)
82
- Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
83
- PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
84
  PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
85
  tile_count_semaphore: Optional[Tensor], # (1,)
86
  activation: Optional[str],
@@ -90,126 +371,138 @@ def gemm_dact(
90
  cluster_N: int,
91
  pingpong: bool = True,
92
  persistent: bool = True,
 
93
  max_swizzle_size: int = 8,
 
 
 
94
  cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
95
  A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
 
96
  ) -> None:
97
- if cu_seqlens_m is not None:
 
 
 
 
 
 
 
 
 
98
  assert persistent, "varlen_m requires persistent=True"
99
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
100
  assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
101
  assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
102
  assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
103
- gather_A = A_idx is not None
104
  if gather_A:
105
- assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
106
  assert cluster_N == 1, "gather_A requires cluster_N=1"
107
- assert activation in dact_fn_map, f"Unsupported activation {activation}"
108
-
109
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
110
- A,
111
- B,
112
- Out,
113
- PreAct,
114
- additional_tensors={"PostAct": PostAct},
115
- cu_seqlens_m=cu_seqlens_m,
116
- A_idx=A_idx,
117
- )
118
- GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
119
- GemmWrapperBase.extract_dtypes(tensor_infos)
120
- major_configs = {
121
- "A": ("m", "k", "l"),
122
- "B": ("n", "k", "l"),
123
- "D": ("m", "n", "l"),
124
- "C": ("m", "n", "l"),
125
- "PostAct": ("m", "n", "l"),
126
- }
127
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
128
 
129
- device_capacity = get_device_capacity(A.device)
130
- assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
131
- GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
132
-
133
- acc_dtype = Float32
134
- tile_shape_mn = (tile_M, tile_N)
135
- cluster_shape_mnk = (cluster_M, cluster_N, 1)
136
- if not GemmCls.is_valid_dtypes(
137
- tensor_infos["A"].dtype,
138
- tensor_infos["B"].dtype,
139
- acc_dtype,
140
- tensor_infos["D"].dtype,
141
- tensor_infos["A"].major,
142
- tensor_infos["B"].major,
143
- ):
144
- raise TypeError("Skipping due to unsupported combination of types and majors")
145
 
146
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
147
- GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
148
- act_fn = dact_fn_map[activation]
149
- epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
150
- scheduler_args = GemmWrapperBase.create_scheduler_args(
151
- max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
152
- )
153
 
154
- # Create varlen arguments if needed (assumes persistent=True when varlen_m)
155
- varlen_args = GemmWrapperBase.create_varlen_args(
156
- cu_seqlens_m,
157
- None, # cu_seqlens_k
158
- A_idx,
159
- max_active_clusters,
160
- cluster_shape_mnk,
161
- tensor_infos,
162
- GemmCls.num_epi_tensormaps,
163
- pingpong,
164
- )
165
 
166
- current_stream = cutlass_torch.current_stream()
167
- compile_key = GemmWrapperBase.get_compile_key(
168
- tensor_infos,
169
- activation,
170
- tile_shape_mn,
171
- cluster_shape_mnk,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  pingpong,
173
  persistent,
174
- tile_count_semaphore is not None,
 
 
 
 
 
 
 
175
  device_capacity,
176
- max_swizzle_size,
177
- cu_seqlens_m is not None,
178
- A_idx is not None,
179
- key_tensor_names=("A", "B", "D", "PostAct", "C"),
180
  )
181
- cache = gemm_dact.compile_cache
182
- if compile_key not in cache:
183
- if device_capacity[0] == 9:
184
- GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
185
- gemm = GemmCls(
186
- acc_dtype,
187
- tensor_infos["A"].dtype,
188
- tile_shape_mn,
189
- cluster_shape_mnk,
190
- gather_A=gather_A,
 
 
 
 
 
191
  )
192
- cache[compile_key] = cute.compile(
193
- gemm,
194
- tensor_infos["A"].cute_tensor,
195
- tensor_infos["B"].cute_tensor,
196
- tensor_infos["D"].cute_tensor,
197
- tensor_infos["C"].cute_tensor,
198
- epi_args,
199
- scheduler_args,
200
- varlen_args,
201
- current_stream,
202
  )
203
- cache[compile_key](
204
- tensor_infos["A"].cute_tensor,
205
- tensor_infos["B"].cute_tensor,
206
- tensor_infos["D"].cute_tensor,
207
- tensor_infos["C"].cute_tensor,
208
- epi_args,
209
- scheduler_args,
210
- varlen_args,
211
- current_stream,
212
  )
 
 
 
 
 
 
 
 
213
 
214
 
215
- gemm_dact.compile_cache = {}
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ from __future__ import annotations
3
+ from typing import NamedTuple, Optional, Tuple, Callable
4
 
5
+ import torch
6
  from torch import Tensor
7
 
8
  import cutlass
9
  import cutlass.cute as cute
10
+ from cutlass import Int32, Float32, const_expr
 
 
11
  from .gemm_sm90 import GemmSm90
12
  from .gemm_sm100 import GemmSm100
13
+ from .gemm_sm120 import GemmSm120
14
  from .gemm_default_epi import GemmDefaultEpiMixin
15
  from .gemm_act import GemmActMixin
16
+ from .epi_ops import ColVecReduce, colvec_reduce_accumulate
17
+ from .compile_utils import make_fake_tensor as fake_tensor
18
+ from .cute_dsl_utils import (
19
+ ParamsBase,
20
+ mlir_namedtuple,
21
+ torch2cute_dtype_map,
22
+ get_device_capacity,
23
+ get_max_active_clusters,
24
+ )
25
+ from .gemm_tvm_ffi_utils import (
26
+ get_major,
27
+ perm3d_single,
28
+ make_scheduler_args,
29
+ make_varlen_args,
30
+ make_fake_scheduler_args,
31
+ make_fake_varlen_args,
32
+ div_for_dtype,
33
+ make_fake_gemm_tensors,
34
+ compile_gemm_kernel,
35
+ )
36
+ from .cache_utils import jit_cache
37
+ from .rounding import RoundingMode
38
+ from . import layout_utils as layout_utils
39
+ from .activation import dact_fn_map, dgate_fn_map
40
 
41
 
42
  class GemmDActMixin(GemmActMixin):
43
  # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
44
  # and return 2 arguments (dx, out)
45
  EpilogueArguments = GemmActMixin.EpilogueArguments
 
46
 
47
  @cute.jit
48
  def epi_visit_subtile(
49
  self,
50
+ params,
51
  epi_loop_tensors: Tuple[cute.Tensor, ...],
52
  tRS_rD: cute.Tensor,
53
  tRS_rC: Optional[cute.Tensor] = None,
 
55
  assert tRS_rC is not None
56
  # We don't add C to the accumulator
57
  GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
58
+ tRS_rC_acc = cute.make_rmem_tensor_like(tRS_rC, self.acc_dtype)
59
  tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
60
  # If we don't have .shape here, the compiler generates local stores and loads
61
  if const_expr(params.act_fn is not None):
62
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
63
  if const_expr(self.arch < 100):
64
  for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
65
  tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
 
74
  )
75
  else:
76
  tRS_rPostAct = tRS_rC_acc
77
+ return tRS_rPostAct
 
 
 
78
 
79
 
80
  class GemmDActSm90(GemmDActMixin, GemmSm90):
 
85
  pass
86
 
87
 
88
+ class GemmDActSm120(GemmDActMixin, GemmSm120):
89
+ pass
90
+
91
+
92
+ class GemmDGatedMixin(GemmActMixin):
93
+ # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout)
94
+ # and return 3 arguments (dx, dy, out)
95
+ _epi_ops = (*GemmActMixin._epi_ops, ColVecReduce("mColVecReduce"))
96
+ _extra_param_fields = (("act_bwd_fn", cutlass.Constexpr, None),)
97
+ _epi_param_bases = (ParamsBase,)
98
+
99
+ @mlir_namedtuple
100
+ class EpilogueArguments(NamedTuple):
101
+ mPostAct: cute.Tensor
102
+ act_bwd_fn: cutlass.Constexpr[Callable] = None
103
+ alpha: Optional[Float32 | cute.Tensor] = None
104
+ beta: Optional[Float32 | cute.Tensor] = None
105
+ mRowVecBroadcast: Optional[cute.Tensor] = None
106
+ mColVecBroadcast: Optional[cute.Tensor] = None
107
+ mColVecReduce: Optional[cute.Tensor] = None
108
+ rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
109
+ sr_seed: Optional[Int32 | cute.Tensor] = None
110
+
111
+ # EpilogueParams auto-generated from _epi_ops + _extra_param_fields
112
+
113
+ def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None):
114
+ # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose
115
+ # for reusing the existing load/store code.
116
+ assert self.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now"
117
+ assert self.d_dtype.width == 32, "D storage type must be 32 bit"
118
+ assert self.c_dtype.width == 32, "C storage type must be 32 bit"
119
+ self.rounding_mode = args.rounding_mode
120
+ self.postact_dtype = args.mPostAct.element_type
121
+ self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
122
+ self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
123
+ d = self._epi_ops_to_params_dict(args)
124
+ d["act_bwd_fn"] = args.act_bwd_fn
125
+ return self.EpilogueParams(**d)
126
+
127
+ # epi_begin, epi_begin_loop, epi_end are inherited from ComposableEpiMixin via _epi_ops.
128
+
129
+ @cute.jit
130
+ def epi_visit_subtile(
131
+ self,
132
+ params,
133
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
134
+ tRS_rD: cute.Tensor,
135
+ tRS_rC: Optional[cute.Tensor] = None,
136
+ ) -> Optional[cute.Tensor]:
137
+ alpha = epi_loop_tensors["alpha"]
138
+ beta = epi_loop_tensors["beta"]
139
+ tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
140
+ tDrColVec = epi_loop_tensors["mColVecBroadcast"]
141
+ tDrColVecReduce = epi_loop_tensors["mColVecReduce"]
142
+ assert alpha is None and beta is None and tDrRowVec is None # We don't use these for now
143
+ assert tRS_rC is not None
144
+ implicit_dtype = self.implicit_dtype
145
+ assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now"
146
+ tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype)
147
+ tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32)
148
+ tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32))
149
+ tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32)
150
+ tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32)
151
+ tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD)
152
+ if const_expr(tDrColVec is not None): # Scale D by colvec
153
+ if const_expr(self.arch < 100):
154
+ tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type))
155
+ else:
156
+ tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
157
+ tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout)
158
+ tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(
159
+ tRS_rD_scaled, tDrColVec.layout
160
+ )
161
+ for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
162
+ for n in cutlass.range(
163
+ cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
164
+ ):
165
+ (
166
+ tRS_rD_scaled_mn[m, 2 * n],
167
+ tRS_rD_scaled_mn[m, 2 * n + 1],
168
+ ) = cute.arch.mul_packed_f32x2(
169
+ (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]),
170
+ (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
171
+ )
172
+ else:
173
+ tRS_rD_scaled.store(tRS_rD.load())
174
+ if const_expr(self.arch < 100):
175
+ for i in cutlass.range(cute.size(tRS_rD)):
176
+ (
177
+ tRS_rdXY_f32x2[2 * i],
178
+ tRS_rdXY_f32x2[2 * i + 1],
179
+ tRS_rOut[i],
180
+ ) = params.act_bwd_fn(
181
+ tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]
182
+ )
183
+ else:
184
+ for i in cutlass.range(cute.size(tRS_rD) // 2):
185
+ (
186
+ (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]),
187
+ (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]),
188
+ (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]),
189
+ ) = params.act_bwd_fn(
190
+ (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]),
191
+ (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]),
192
+ (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]),
193
+ )
194
+ if const_expr(tDrColVecReduce is not None):
195
+ # Accumulate postact * dout before D is scaled by colvec_scale
196
+ colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rOut, rScale=tRS_rD)
197
+
198
+ if const_expr(tDrColVec is not None): # Scale Out by colvec
199
+ if const_expr(self.arch < 100):
200
+ tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type))
201
+ else:
202
+ tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
203
+ tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout)
204
+ for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
205
+ for n in cutlass.range(
206
+ cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
207
+ ):
208
+ tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = (
209
+ cute.arch.mul_packed_f32x2(
210
+ (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]),
211
+ (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]),
212
+ )
213
+ )
214
+ # Type conversion
215
+ tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype)
216
+ tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype))
217
+ tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load())
218
+ return tRS_rOut
219
+
220
+ # epi_end is inherited from ComposableEpiMixin → delegates to ColVecReduce.end()
221
+
222
+
223
+ class GemmDGatedSm90(GemmDGatedMixin, GemmSm90):
224
+ pass
225
+
226
+
227
+ class GemmDGatedSm100(GemmDGatedMixin, GemmSm100):
228
+ pass
229
+
230
+
231
+ class GemmDGatedSm120(GemmDGatedMixin, GemmSm120):
232
+ pass
233
+
234
+
235
+ @jit_cache
236
+ def _compile_gemm_dact(
237
+ a_dtype,
238
+ b_dtype,
239
+ d_dtype,
240
+ c_dtype,
241
+ postact_dtype,
242
+ implicit_dtype,
243
+ a_major,
244
+ b_major,
245
+ d_major,
246
+ c_major,
247
+ postact_major,
248
+ tile_shape_mn,
249
+ cluster_shape_mnk,
250
+ pingpong,
251
+ persistent,
252
+ is_dynamic_persistent,
253
+ activation,
254
+ colvec_scale_dtype,
255
+ colvec_scale_ndim,
256
+ colvec_reduce_dtype,
257
+ colvec_reduce_ndim,
258
+ varlen_m,
259
+ gather_A,
260
+ device_capacity,
261
+ gemm_cls_name,
262
+ use_tma_gather=False,
263
+ ):
264
+ is_dgated = gemm_cls_name == "dgated"
265
+ sm_to_cls = {
266
+ "dact": {9: GemmDActSm90, 10: GemmDActSm100, 11: GemmDActSm100, 12: GemmDActSm120},
267
+ "dgated": {
268
+ 9: GemmDGatedSm90,
269
+ 10: GemmDGatedSm100,
270
+ 11: GemmDGatedSm100,
271
+ 12: GemmDGatedSm120,
272
+ },
273
+ }
274
+ if device_capacity[0] == 12 and gemm_cls_name == "dact":
275
+ raise NotImplementedError("SM120 non-gated dactivation GEMM epilogue is not yet supported")
276
+ GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
277
+ mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
278
+ a_dtype,
279
+ b_dtype,
280
+ d_dtype,
281
+ c_dtype,
282
+ a_major,
283
+ b_major,
284
+ d_major,
285
+ c_major,
286
+ varlen_m=varlen_m,
287
+ gather_A=gather_A,
288
+ )
289
+ div_pa = div_for_dtype(postact_dtype)
290
+ pa_leading = 1 if postact_major == "n" else 0
291
+ pa_shape = (m, n) if varlen_m else (m, n, l)
292
+ mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading, divisibility=div_pa)
293
+
294
+ if is_dgated:
295
+ act_fn = dgate_fn_map[activation]
296
+
297
+ mColVec = None
298
+ if colvec_scale_ndim == 2:
299
+ mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4)
300
+ elif colvec_scale_ndim == 1:
301
+ mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4)
302
+ mColVecReduce = None
303
+ n_tiles = cute.sym_int()
304
+ if colvec_reduce_ndim == 3:
305
+ mColVecReduce = fake_tensor(
306
+ colvec_reduce_dtype,
307
+ (l, m, n_tiles),
308
+ leading_dim=2,
309
+ divisibility=1,
310
+ )
311
+ elif colvec_reduce_ndim == 2:
312
+ mColVecReduce = fake_tensor(
313
+ colvec_reduce_dtype,
314
+ (m, n_tiles),
315
+ leading_dim=1,
316
+ divisibility=1,
317
+ )
318
+ epi_args = GemmCls.EpilogueArguments(
319
+ mPostAct,
320
+ act_fn,
321
+ mColVecBroadcast=mColVec,
322
+ mColVecReduce=mColVecReduce,
323
+ )
324
+
325
+ def _set_implicit_dtype(gemm_obj):
326
+ gemm_obj.implicit_dtype = implicit_dtype
327
+
328
+ post_init = _set_implicit_dtype
329
+ else:
330
+ act_fn = dact_fn_map[activation]
331
+ epi_args = GemmCls.EpilogueArguments(mPostAct, act_fn)
332
+ post_init = None
333
+
334
+ scheduler_args = make_fake_scheduler_args(
335
+ (is_dynamic_persistent and device_capacity[0] == 9), False, l
336
+ )
337
+ varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
338
+ return compile_gemm_kernel(
339
+ GemmCls,
340
+ a_dtype,
341
+ tile_shape_mn,
342
+ cluster_shape_mnk,
343
+ pingpong,
344
+ persistent,
345
+ gather_A,
346
+ is_dynamic_persistent,
347
+ device_capacity,
348
+ mA,
349
+ mB,
350
+ mD,
351
+ mC,
352
+ epi_args,
353
+ scheduler_args,
354
+ varlen_args,
355
+ post_init=post_init,
356
+ use_tma_gather=use_tma_gather,
357
+ )
358
 
359
 
360
  def gemm_dact(
361
  A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
362
  B: Tensor, # (l, n, k)
363
+ Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m; or (l, m, 2*n)/(total_m, 2*n) if dgated
364
+ PreAct: Tensor, # same shape as Out
365
  PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
366
  tile_count_semaphore: Optional[Tensor], # (1,)
367
  activation: Optional[str],
 
371
  cluster_N: int,
372
  pingpong: bool = True,
373
  persistent: bool = True,
374
+ is_dynamic_persistent: bool = False,
375
  max_swizzle_size: int = 8,
376
+ colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m (dgated only)
377
+ # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m (dgated only)
378
+ colvec_reduce: Optional[Tensor] = None,
379
  cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
380
  A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
381
+ use_tma_gather: bool = False,
382
  ) -> None:
383
+ is_dgated = activation in dgate_fn_map
384
+ if not is_dgated:
385
+ assert activation in dact_fn_map, f"Unsupported activation {activation}"
386
+ assert colvec_scale is None, "colvec_scale is only supported for gated activations"
387
+ assert colvec_reduce is None, "colvec_reduce is only supported for gated activations"
388
+ gemm_cls_name = "dgated" if is_dgated else "dact"
389
+
390
+ varlen_m = cu_seqlens_m is not None
391
+ gather_A = A_idx is not None
392
+ if varlen_m:
393
  assert persistent, "varlen_m requires persistent=True"
394
  assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
395
  assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
396
  assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
397
  assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
 
398
  if gather_A:
399
+ assert cu_seqlens_m is not None, "gather_A requires varlen"
400
  assert cluster_N == 1, "gather_A requires cluster_N=1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
+ # For dgated, capture implicit_dtype before viewing Out/PreAct as f32
403
+ implicit_dtype = None
404
+ if is_dgated:
405
+ AB_swapped = Out.stride(-1) != 1
406
+ implicit_dtype = torch2cute_dtype_map[Out.dtype]
407
+ assert Out.element_size() == 2, "Out dtype must be fp16 or bf16"
408
+ assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16"
409
+ if varlen_m or not AB_swapped:
410
+ Out = Out.view(torch.float32)
411
+ PreAct = PreAct.view(torch.float32)
412
+ else:
413
+ Out = Out.mT.view(torch.float32).mT
414
+ PreAct = PreAct.mT.view(torch.float32).mT
 
 
 
415
 
416
+ A_p = perm3d_single(A, varlen_m)
417
+ B_p = perm3d_single(B)
418
+ Out_p = perm3d_single(Out, varlen_m)
419
+ PreAct_p = perm3d_single(PreAct, varlen_m)
420
+ PostAct_p = perm3d_single(PostAct, varlen_m)
 
 
421
 
422
+ a_major = get_major(A_p, "m", "k")
423
+ b_major = get_major(B_p, "n", "k")
424
+ d_major = get_major(Out_p, "m", "n")
425
+ c_major = get_major(PreAct_p, "m", "n")
426
+ postact_major = get_major(PostAct_p, "m", "n")
 
 
 
 
 
 
427
 
428
+ a_dtype = torch2cute_dtype_map[A.dtype]
429
+ b_dtype = torch2cute_dtype_map[B.dtype]
430
+ d_dtype = torch2cute_dtype_map[Out.dtype]
431
+ c_dtype = torch2cute_dtype_map[PreAct.dtype]
432
+ postact_dtype = torch2cute_dtype_map[PostAct.dtype]
433
+
434
+ device_capacity = get_device_capacity(A.device)
435
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
436
+
437
+ if is_dynamic_persistent and device_capacity[0] == 9:
438
+ assert tile_count_semaphore is not None, (
439
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
440
+ )
441
+
442
+ compiled_fn = _compile_gemm_dact(
443
+ a_dtype,
444
+ b_dtype,
445
+ d_dtype,
446
+ c_dtype,
447
+ postact_dtype,
448
+ implicit_dtype,
449
+ a_major,
450
+ b_major,
451
+ d_major,
452
+ c_major,
453
+ postact_major,
454
+ (tile_M, tile_N),
455
+ (cluster_M, cluster_N, 1),
456
  pingpong,
457
  persistent,
458
+ is_dynamic_persistent,
459
+ activation,
460
+ torch2cute_dtype_map[colvec_scale.dtype] if colvec_scale is not None else None,
461
+ colvec_scale.ndim if colvec_scale is not None else 0,
462
+ torch2cute_dtype_map[colvec_reduce.dtype] if colvec_reduce is not None else None,
463
+ colvec_reduce.ndim if colvec_reduce is not None else 0,
464
+ varlen_m,
465
+ gather_A,
466
  device_capacity,
467
+ gemm_cls_name,
468
+ use_tma_gather=use_tma_gather,
 
 
469
  )
470
+
471
+ from .cache_utils import COMPILE_ONLY
472
+
473
+ if COMPILE_ONLY:
474
+ return
475
+
476
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
477
+ if is_dgated:
478
+ epi_args = GemmDGatedMixin.EpilogueArguments(
479
+ PostAct_p,
480
+ None, # act_bwd_fn is Constexpr
481
+ mColVecBroadcast=colvec_scale,
482
+ mColVecReduce=colvec_reduce,
483
+ rounding_mode=None,
484
+ sr_seed=None,
485
  )
486
+ else:
487
+ epi_args = GemmDActMixin.EpilogueArguments(
488
+ PostAct_p,
489
+ None,
490
+ rounding_mode=None,
491
+ sr_seed=None,
 
 
 
 
492
  )
493
+ scheduler_args = make_scheduler_args(
494
+ max_active_clusters,
495
+ max_swizzle_size,
496
+ tile_count_semaphore,
 
 
 
 
 
497
  )
498
+ varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
499
+
500
+ if device_capacity[0] in [10, 11]:
501
+ compiled_fn(
502
+ A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None
503
+ )
504
+ else:
505
+ compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None)
506
 
507
 
508
+ gemm_dgated = gemm_dact
build/torch-cuda/quack/gemm_default_epi.py CHANGED
@@ -1,189 +1,62 @@
1
  # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
- from typing import Optional, Tuple
3
- from functools import partial
4
- from dataclasses import dataclass
5
-
6
 
7
  import cutlass
8
  import cutlass.cute as cute
9
- from cutlass import Int32, Float32, Boolean, const_expr
10
 
11
- from .cute_dsl_utils import ArgumentsBase, ParamsBase
 
 
12
  from .gemm_sm90 import GemmSm90
13
  from .gemm_sm100 import GemmSm100
14
- from .sm90_utils import partition_for_epilogue
 
 
15
  from . import utils as utils
16
- from . import copy_utils as copy_utils
17
- from .varlen_utils import VarlenManager
18
 
19
 
20
- class GemmDefaultEpiMixin:
21
- num_epi_tensormaps: int = 0
 
 
 
 
 
 
22
 
23
- @dataclass
24
- class EpilogueArguments(ArgumentsBase):
25
  alpha: Optional[Float32 | cute.Tensor] = None
26
  beta: Optional[Float32 | cute.Tensor] = None
27
  mRowVecBroadcast: Optional[cute.Tensor] = None
28
  mColVecBroadcast: Optional[cute.Tensor] = None
29
- add_to_output: bool = False
30
-
31
- @dataclass
32
- class EpilogueParams(ParamsBase):
33
- alpha: Optional[Float32 | cute.Tensor] = None
34
- beta: Optional[Float32 | cute.Tensor] = None
35
- mRowVecBroadcast: Optional[cute.Tensor] = None
36
- mColVecBroadcast: Optional[cute.Tensor] = None
37
-
38
- def epi_to_underlying_arguments(
39
- self, args: EpilogueArguments, *, loc=None, ip=None
40
- ) -> EpilogueParams:
41
- # Assume all strides are divisible by 32 bits except the last stride
42
- new_stride = lambda t: tuple(
43
- cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
44
- for s in t.stride
45
- )
46
- mRowVecBroadcast, mColVecBroadcast = [
47
- cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
48
- if t is not None
49
- else None
50
- for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
51
- ]
52
- return self.EpilogueParams(
53
- alpha=args.alpha,
54
- beta=args.beta,
55
- mRowVecBroadcast=mRowVecBroadcast,
56
- mColVecBroadcast=mColVecBroadcast,
57
- )
58
-
59
- @cute.jit
60
- def epi_begin(
61
- self,
62
- params: EpilogueParams,
63
- epi_smem_tensors: Tuple[cute.Tensor, ...],
64
- epi_tile: cute.Tile,
65
- tiled_copy_t2r: Optional[cute.TiledCopy],
66
- tiled_copy_r2s: cute.TiledCopy,
67
- tile_coord_mnkl: cute.Coord,
68
- varlen_manager: VarlenManager,
69
- epilogue_barrier: cutlass.pipeline.NamedBarrier,
70
- tidx: Int32,
71
- ):
72
- alpha, beta = None, None
73
- if const_expr(hasattr(params, "alpha") and params.alpha is not None):
74
- alpha = utils.load_scalar_or_pointer(params.alpha)
75
- if const_expr(hasattr(params, "beta") and params.beta is not None):
76
- beta = utils.load_scalar_or_pointer(params.beta)
77
- sRowVec, sColVec, *rest = epi_smem_tensors
78
- tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]
79
- batch_idx = tile_coord_mnkl[3]
80
- num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
81
- # Don't need sync as we assume the previous epilogue has finished
82
-
83
- partition_for_epilogue_fn = partial(
84
- partition_for_epilogue,
85
- epi_tile=epi_tile,
86
- tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s,
87
- tidx=tidx,
88
- reference_src=tiled_copy_t2r is None,
89
- )
90
 
91
- tDsRowVec = None
92
- if const_expr(params.mRowVecBroadcast is not None):
93
- rowvec_dtype = params.mRowVecBroadcast.element_type
94
- num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width
95
- thr_copy_RV = copy_utils.tiled_copy_1d(
96
- params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
97
- ).get_slice(tidx)
98
- mRowVec = params.mRowVecBroadcast[batch_idx, None]
99
- gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],))
100
- tRVgRV = thr_copy_RV.partition_S(gRowVec)
101
- tRVsRV = thr_copy_RV.partition_D(sRowVec)
102
- tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N))
103
- limit_n = min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N)
104
- tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean)
105
- for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True):
106
- tRVpRV[0, m] = tRVcRV[0, m] < limit_n
107
- cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV)
108
- # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
109
- tDsRowVec = partition_for_epilogue_fn(
110
- cute.make_tensor(
111
- sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1))
112
- )
113
- )
114
- if const_expr(tiled_copy_t2r is not None):
115
- tDsRowVec = tiled_copy_r2s.retile(tDsRowVec)
116
 
117
- tDsColVec = None
118
- if const_expr(params.mColVecBroadcast is not None):
119
- colvec_dtype = params.mColVecBroadcast.element_type
120
- num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width
121
- thr_copy_CV = copy_utils.tiled_copy_1d(
122
- params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True
123
- ).get_slice(tidx)
124
- if const_expr(not varlen_manager.varlen_m):
125
- mColVec = params.mColVecBroadcast[batch_idx, None]
126
- else:
127
- mColVec = cute.domain_offset(
128
- (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast
129
- )
130
- gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],))
131
- tCVgCV = thr_copy_CV.partition_S(gColVec)
132
- tCVsCV = thr_copy_CV.partition_D(sColVec)
133
- tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M))
134
- limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M)
135
- tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean)
136
- for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True):
137
- tCVpCV[0, m] = tCVcCV[0, m] < limit_m
138
- cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV)
139
- tDsColVec = partition_for_epilogue_fn(
140
- cute.make_tensor(
141
- sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0))
142
- )
143
- )
144
- if const_expr(tiled_copy_t2r is not None):
145
- tDsColVec = tiled_copy_r2s.retile(tDsColVec)
146
-
147
- if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None):
148
- cute.arch.cp_async_commit_group()
149
- cute.arch.cp_async_wait_group(0)
150
- epilogue_barrier.arrive_and_wait()
151
- return alpha, beta, tDsRowVec, tDsColVec
152
-
153
- def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord):
154
- alpha, beta, tDsRowVec, tDsColVec = epi_tensors
155
- tDrRowVec_cvt = None
156
- if const_expr(tDsRowVec is not None):
157
- tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[
158
- None, None, None, epi_coord
159
- ]
160
- # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
161
- tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type)
162
- cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec))
163
- tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype)
164
- tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype))
165
- tDrColVec_cvt = None
166
- if const_expr(tDsColVec is not None):
167
- tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[
168
- None, None, None, epi_coord
169
- ]
170
- # This somehow doesn't work, some dim with stride 0 turns to non-zero stride
171
- # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur)
172
- tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type)
173
- cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec))
174
- tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype)
175
- tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype))
176
- return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt
177
 
178
  @cute.jit
179
  def epi_visit_subtile(
180
  self,
181
- params: EpilogueParams,
182
- epi_loop_tensors: Tuple[cute.Tensor, ...],
183
  tRS_rD: cute.Tensor,
184
  tRS_rC: Optional[cute.Tensor] = None,
185
  ) -> Optional[cute.Tensor]:
186
- alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors
 
 
 
187
  rD = tRS_rD.load()
188
  # Apply alpha scaling to accumulator if alpha is provided (not None)
189
  if const_expr(hasattr(params, "alpha") and params.alpha is not None):
@@ -206,49 +79,25 @@ class GemmDefaultEpiMixin:
206
  tRS_rD[i] += tDrColVec[i]
207
  return None
208
 
209
- @staticmethod
210
- def epi_smem_bytes_per_stage(
211
- args: Optional[EpilogueArguments],
212
- cta_tile_shape_mnk: Tuple[int, int, int],
213
- epi_tile: cute.Tile,
214
- ) -> int:
215
- row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1]
216
- col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0]
217
- row_vec_dtype = (
218
- args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32
219
- )
220
- col_vec_dtype = (
221
- args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32
222
- )
223
- return (
224
- row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width
225
- ) // 8
226
-
227
- def epi_get_smem_struct(self, params: EpilogueParams):
228
- row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
229
- col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
230
- row_vec_dtype = (
231
- params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
232
- )
233
- col_vec_dtype = (
234
- params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
235
- )
236
-
237
- @cute.struct
238
- class EpiSharedStorage:
239
- sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
240
- sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
241
-
242
- return EpiSharedStorage
243
 
244
- def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
245
- sRowVec = None
246
- if const_expr(params.mRowVecBroadcast is not None):
247
- sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1]))
248
- sColVec = None
249
- if const_expr(params.mColVecBroadcast is not None):
250
- sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0]))
251
- return (sRowVec, sColVec)
252
 
253
 
254
  class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
@@ -257,3 +106,7 @@ class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
257
 
258
  class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
259
  pass
 
 
 
 
 
1
  # Copyright (c) 2025, Wentao Guo, Tri Dao.
2
+ from typing import NamedTuple, Optional
 
 
 
3
 
4
  import cutlass
5
  import cutlass.cute as cute
6
+ from cutlass import Int32, Float32, const_expr
7
 
8
+ from .cute_dsl_utils import mlir_namedtuple
9
+ from .epi_composable import ComposableEpiMixin
10
+ from .epi_ops import Scalar, RowVecLoad, ColVecLoad
11
  from .gemm_sm90 import GemmSm90
12
  from .gemm_sm100 import GemmSm100
13
+ from .gemm_sm120 import GemmSm120
14
+ from .rounding import RoundingMode
15
+ from . import layout_utils as layout_utils
16
  from . import utils as utils
 
 
17
 
18
 
19
+ class GemmDefaultEpiMixin(ComposableEpiMixin):
20
+ _epi_ops = (
21
+ Scalar("alpha"),
22
+ Scalar("beta"),
23
+ Scalar("sr_seed", dtype=Int32),
24
+ RowVecLoad("mRowVecBroadcast"),
25
+ ColVecLoad("mColVecBroadcast"),
26
+ )
27
 
28
+ @mlir_namedtuple
29
+ class EpilogueArguments(NamedTuple):
30
  alpha: Optional[Float32 | cute.Tensor] = None
31
  beta: Optional[Float32 | cute.Tensor] = None
32
  mRowVecBroadcast: Optional[cute.Tensor] = None
33
  mColVecBroadcast: Optional[cute.Tensor] = None
34
+ add_to_output: cutlass.Constexpr[bool] = False
35
+ rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
36
+ sr_seed: Optional[Int32 | cute.Tensor] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # EpilogueParams auto-generated from _epi_ops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ def epi_to_underlying_arguments(self, args, *, loc=None, ip=None):
41
+ self.rounding_mode = args.rounding_mode
42
+ d = self._epi_ops_to_params_dict(args)
43
+ for key in ("mRowVecBroadcast", "mColVecBroadcast"):
44
+ if key in self.concat_layout and key in d and d[key] is not None:
45
+ d[key] = layout_utils.concat_to_interleave(d[key], 1)
46
+ return self.EpilogueParams(**d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  @cute.jit
49
  def epi_visit_subtile(
50
  self,
51
+ params,
52
+ epi_loop_tensors,
53
  tRS_rD: cute.Tensor,
54
  tRS_rC: Optional[cute.Tensor] = None,
55
  ) -> Optional[cute.Tensor]:
56
+ alpha = epi_loop_tensors["alpha"]
57
+ beta = epi_loop_tensors["beta"]
58
+ tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
59
+ tDrColVec = epi_loop_tensors["mColVecBroadcast"]
60
  rD = tRS_rD.load()
61
  # Apply alpha scaling to accumulator if alpha is provided (not None)
62
  if const_expr(hasattr(params, "alpha") and params.alpha is not None):
 
79
  tRS_rD[i] += tDrColVec[i]
80
  return None
81
 
82
+ def epi_setup_postact(
83
+ self,
84
+ params,
85
+ epi_smem_tensors,
86
+ tiled_copy_r2s,
87
+ tiled_copy_t2r,
88
+ tile_coord_mnkl,
89
+ varlen_manager,
90
+ tidx,
91
+ ):
92
+ """Returns None — default epilogue has no postact output."""
93
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ @cute.jit
96
+ def epi_convert_postact(
97
+ self, tRS_rPostAct, sr_seed, tidx, tile_coord_mnkl, num_prev_subtiles, epi_idx
98
+ ):
99
+ """Convert postact from acc_dtype to output dtype. Override for custom postprocessing."""
100
+ return tRS_rPostAct
 
 
101
 
102
 
103
  class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
 
106
 
107
  class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
108
  pass
109
+
110
+
111
+ class GemmDefaultSm120(GemmDefaultEpiMixin, GemmSm120):
112
+ pass
build/torch-cuda/quack/gemm_interface.py CHANGED
@@ -3,18 +3,22 @@ from typing import Optional, Tuple, Literal
3
  from functools import partial
4
 
5
  import torch
 
6
  import torch.nn.functional as F
7
  from torch import Tensor
8
- from ._ops_compat import add_quack_op_namespace_prefix
9
 
10
  from .gemm_config import GemmConfig, get_all_configs
11
 
12
  from .autotuner import autotune, AutotuneConfig
13
  from .cute_dsl_utils import get_device_capacity
14
- from .gemm import gemm as gemm_sm90_sm100
15
- from .gemm_act import gemm_act as gemm_act_sm90_sm100
16
- from .gemm_dact import gemm_dact as gemm_dact_sm90_sm100
17
- from .gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
 
 
 
 
18
 
19
 
20
  # Dictionary mapping activation names to PyTorch functions
@@ -37,54 +41,100 @@ gated_to_pytorch_fn_map = {
37
  }
38
 
39
 
40
- def _get_default_device_capacity():
41
- if not torch.cuda.is_available():
42
- return (9, 0)
43
- cap = get_device_capacity(torch.device("cuda"))
44
- if cap[0] not in (9, 10):
45
- return (9, 0)
46
- return cap
 
 
 
 
 
 
47
 
48
 
49
- class _LazyDeviceCapacity:
50
- """Defer torch.cuda.get_device_capability until first access so the
51
- module can be imported in environments without a GPU (e.g. nix build)."""
52
- _value = None
53
- def __getitem__(self, idx):
54
- if self._value is None:
55
- self._value = _get_default_device_capacity()
56
- return self._value[idx]
57
 
58
 
59
- default_device_capacity = _LazyDeviceCapacity()
 
 
 
60
 
61
 
62
  def default_config(device):
63
- if get_device_capacity(device)[0] != 10:
64
- return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  else:
66
- return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
70
  kwargs = named_args | kwargs
 
 
71
  gather_A = kwargs.get("A_idx", None) is not None
72
  varlen_m = kwargs.get("cu_seqlens_m", None) is not None
73
  if varlen_m or gather_A: # Doesn't support swap_ab
74
  configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
75
  if gather_A:
76
- if get_device_capacity(kwargs["A"].device)[0] == 9:
77
- # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
78
- configs = [
79
- conf
80
- for conf in configs
81
- if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
82
- ]
83
  return configs
84
 
85
 
86
  @autotune(
87
- configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
88
  key=["dynamic_scheduler"],
89
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
90
  )
@@ -104,9 +154,25 @@ def gemm_tuned(
104
  add_to_output: bool = False,
105
  dynamic_scheduler: bool = False,
106
  config: Optional[GemmConfig] = None,
 
 
 
107
  ) -> None:
108
  if config is None:
109
- config = default_config(A.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  varlen_m = cu_seqlens_m is not None
111
  varlen_k = cu_seqlens_k is not None
112
  varlen = varlen_m or varlen_k
@@ -135,10 +201,31 @@ def gemm_tuned(
135
  else:
136
  out_shape = (batch_size, A.shape[-2], B.shape[-2])
137
  assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
 
138
  tile_count_semaphore = (
139
- torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  )
141
- gemm_sm90_sm100(
142
  A if not config.swap_ab else B,
143
  B if not config.swap_ab else A,
144
  out if not config.swap_ab else out.mT,
@@ -150,6 +237,7 @@ def gemm_tuned(
150
  config.cluster_n,
151
  config.pingpong,
152
  persistent=True,
 
153
  max_swizzle_size=config.max_swizzle_size,
154
  rowvec_bias=bias if not config.swap_ab else None,
155
  colvec_bias=bias if config.swap_ab else None,
@@ -160,11 +248,15 @@ def gemm_tuned(
160
  A_idx=A_idx,
161
  batch_idx_permute=batch_idx_permute,
162
  add_to_output=add_to_output,
 
 
 
 
163
  )
164
 
165
 
166
  @autotune(
167
- configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
168
  key=["activation", "dynamic_scheduler"],
169
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
170
  )
@@ -177,7 +269,7 @@ def gemm_act_tuned(
177
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
178
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
179
  bias: Optional[Tensor] = None, # (N,) or (L, N)
180
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
181
  cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
182
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
183
  dynamic_scheduler: bool = False,
@@ -205,10 +297,13 @@ def gemm_act_tuned(
205
  PostAct = postact_out
206
  if bias is not None and bias.ndim == 1:
207
  bias = bias.unsqueeze(0) # (L, N)
 
208
  tile_count_semaphore = (
209
- torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
 
 
210
  )
211
- gemm_act_sm90_sm100(
212
  A if not config.swap_ab else B,
213
  B if not config.swap_ab else A,
214
  (D if not config.swap_ab else D.mT) if D is not None else None,
@@ -222,16 +317,18 @@ def gemm_act_tuned(
222
  config.cluster_n,
223
  config.pingpong,
224
  persistent=True,
 
225
  max_swizzle_size=config.max_swizzle_size,
226
  rowvec_bias=bias if not config.swap_ab else None,
227
  colvec_bias=bias if config.swap_ab else None,
228
  cu_seqlens_m=cu_seqlens_m,
229
  A_idx=A_idx,
 
230
  )
231
 
232
 
233
  @autotune(
234
- configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
235
  key=["activation", "dynamic_scheduler"],
236
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
237
  )
@@ -242,7 +339,7 @@ def gemm_dact_tuned(
242
  PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
243
  dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
244
  postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m
245
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
246
  cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
247
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
248
  dynamic_scheduler: bool = True,
@@ -268,10 +365,13 @@ def gemm_dact_tuned(
268
  PostAct = postact_out.unsqueeze(0)
269
  else:
270
  PostAct = postact_out
 
271
  tile_count_semaphore = (
272
- torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
 
 
273
  )
274
- gemm_dact_sm90_sm100(
275
  A if not config.swap_ab else B,
276
  B if not config.swap_ab else A,
277
  D if not config.swap_ab else D.mT,
@@ -285,9 +385,11 @@ def gemm_dact_tuned(
285
  config.cluster_n,
286
  config.pingpong,
287
  persistent=True,
 
288
  max_swizzle_size=config.max_swizzle_size,
289
  cu_seqlens_m=cu_seqlens_m,
290
  A_idx=A_idx,
 
291
  )
292
 
293
 
@@ -305,6 +407,9 @@ def gemm(
305
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
306
  dynamic_scheduler: bool = False,
307
  tuned: bool = True,
 
 
 
308
  ) -> Tensor:
309
  """GEMM with optional output tensor and tuning control."""
310
  if out is None:
@@ -325,6 +430,9 @@ def gemm(
325
  out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
326
  alpha_tensor = alpha if not isinstance(alpha, float) else None
327
  alpha = alpha if isinstance(alpha, float) else 1.0
 
 
 
328
  gemm_out(
329
  A,
330
  B,
@@ -338,6 +446,10 @@ def gemm(
338
  batch_idx_permute=batch_idx_permute,
339
  dynamic_scheduler=dynamic_scheduler,
340
  tuned=tuned,
 
 
 
 
341
  )
342
  return out
343
 
@@ -364,10 +476,15 @@ def gemm_out(
364
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
365
  dynamic_scheduler: bool = False,
366
  tuned: bool = True,
 
 
 
 
367
  ) -> None:
368
  """GEMM with pre-allocated output tensor."""
369
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
370
  alpha = alpha_tensor if alpha_tensor is not None else alpha
 
371
  fn(
372
  A,
373
  B,
@@ -380,6 +497,9 @@ def gemm_out(
380
  A_idx=A_idx,
381
  batch_idx_permute=batch_idx_permute,
382
  dynamic_scheduler=dynamic_scheduler,
 
 
 
383
  )
384
 
385
 
@@ -394,10 +514,18 @@ def gemm_ref(
394
  cu_seqlens_k: Optional[Tensor] = None,
395
  A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
396
  out_dtype: Optional[torch.dtype] = None,
 
397
  ) -> Tensor:
398
  """Reference implementation for GEMM with pre-allocated output."""
399
  # The out_dtype argument requires torch >= 2.8
400
  out_dtype = A.dtype if out_dtype is None else out_dtype
 
 
 
 
 
 
 
401
  if cu_seqlens_m is None and cu_seqlens_k is None:
402
  fn = torch.bmm if A.ndim == 3 else torch.mm
403
  out = fn(A, B, out_dtype=out_dtype, out=out)
@@ -438,6 +566,9 @@ def gemm_ref(
438
  out *= alpha
439
  if bias is not None:
440
  out += bias
 
 
 
441
  return out
442
 
443
 
@@ -456,6 +587,7 @@ def gemm_add(
456
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
457
  dynamic_scheduler: bool = False,
458
  tuned: bool = True,
 
459
  ) -> Tensor:
460
  """GEMM with addition and optional output tensor."""
461
  if out is None:
@@ -480,23 +612,43 @@ def gemm_add(
480
  alpha = alpha if isinstance(alpha, float) else 1.0
481
  beta_tensor = beta if not isinstance(beta, float) else None
482
  beta = beta if isinstance(beta, float) else 1.0
483
- gemm_add_out(
484
- A,
485
- B,
486
- C if not add_to_output else None,
487
- out,
488
- alpha,
489
- beta,
490
- alpha_tensor,
491
- beta_tensor,
492
- cu_seqlens_m=cu_seqlens_m,
493
- cu_seqlens_k=cu_seqlens_k,
494
- A_idx=A_idx,
495
- batch_idx_permute=batch_idx_permute,
496
- add_to_output=add_to_output,
497
- dynamic_scheduler=dynamic_scheduler,
498
- tuned=tuned,
499
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  return out
501
 
502
 
@@ -525,6 +677,7 @@ def gemm_add_out(
525
  add_to_output: bool = False,
526
  dynamic_scheduler: bool = False,
527
  tuned: bool = True,
 
528
  ) -> None:
529
  """GEMM with addition and pre-allocated output tensor."""
530
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
@@ -543,6 +696,7 @@ def gemm_add_out(
543
  batch_idx_permute=batch_idx_permute,
544
  add_to_output=add_to_output,
545
  dynamic_scheduler=dynamic_scheduler,
 
546
  )
547
 
548
 
@@ -559,8 +713,18 @@ def gemm_add_ref(
559
  cu_seqlens_k: Optional[Tensor] = None,
560
  A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
561
  out_dtype: Optional[torch.dtype] = None,
 
562
  ) -> Tensor:
563
  """Reference implementation for GEMM with addition and pre-allocated output."""
 
 
 
 
 
 
 
 
 
564
  if cu_seqlens_m is None and cu_seqlens_k is None:
565
  if isinstance(alpha, float) and isinstance(beta, float):
566
  out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
@@ -571,6 +735,8 @@ def gemm_add_ref(
571
  result = (alpha * (A @ B) + beta * C).to(out_dtype)
572
  if out is not None:
573
  out.copy_(result)
 
 
574
  if bias is not None:
575
  bias = bias if A.ndim == 2 else bias.unsqueeze(1)
576
  out += bias
@@ -610,6 +776,8 @@ def gemm_add_ref(
610
  out[i].copy_(result)
611
  if bias is not None:
612
  out += bias
 
 
613
  return out
614
 
615
 
@@ -626,6 +794,7 @@ def gemm_add_inplace(
626
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
627
  dynamic_scheduler: bool = False,
628
  tuned: bool = True,
 
629
  ) -> None:
630
  """In-place GEMM with addition: out = alpha * A @ B + beta * out.
631
  Args:
@@ -657,6 +826,9 @@ def gemm_add_inplace(
657
  batch_idx_permute=batch_idx_permute,
658
  dynamic_scheduler=dynamic_scheduler,
659
  tuned=tuned,
 
 
 
660
  )
661
 
662
 
@@ -683,6 +855,7 @@ def gemm_add_inplace_op(
683
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
684
  dynamic_scheduler: bool = False,
685
  tuned: bool = True,
 
686
  ) -> None:
687
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
688
  alpha = alpha_tensor if alpha_tensor is not None else alpha
@@ -702,6 +875,7 @@ def gemm_add_inplace_op(
702
  batch_idx_permute=batch_idx_permute,
703
  add_to_output=add_to_output,
704
  dynamic_scheduler=dynamic_scheduler,
 
705
  )
706
 
707
 
@@ -710,7 +884,7 @@ def gemm_act(
710
  B: Tensor, # (K, N) or (L, K, N)
711
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
712
  bias: Optional[Tensor] = None, # (N,) or (L, N)
713
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
714
  preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
715
  postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
716
  out_dtype: Optional[torch.dtype] = None,
@@ -720,8 +894,10 @@ def gemm_act(
720
  store_preact: bool = True,
721
  dynamic_scheduler: bool = False,
722
  tuned: bool = True,
 
723
  ) -> Tuple[Optional[Tensor], Tensor]:
724
- """GEMM with activation and optional output tensors."""
 
725
  out_dtype = A.dtype if out_dtype is None else out_dtype
726
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
727
  varlen_m = cu_seqlens_m is not None
@@ -733,26 +909,47 @@ def gemm_act(
733
  out_shape = (A.shape[0], B.shape[-1])
734
  else:
735
  out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
 
736
  if preact_out is None and store_preact:
737
  preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
738
  if postact_out is None:
739
- postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
740
- gemm_act_out(
741
- A,
742
- B,
743
- preact_out,
744
- postact_out,
745
- C,
746
- bias,
747
- activation,
748
- cu_seqlens_m,
749
- A_idx,
750
- dynamic_scheduler,
751
- tuned,
752
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
  return preact_out, postact_out
754
 
755
 
 
 
 
756
  @torch.library.custom_op(
757
  add_quack_op_namespace_prefix("gemm_act_out"),
758
  mutates_args=("preact_out", "postact_out"),
@@ -766,7 +963,7 @@ def gemm_act_out(
766
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
767
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
768
  bias: Optional[Tensor] = None, # (N,) or (L, N)
769
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
770
  cu_seqlens_m: Optional[Tensor] = None,
771
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
772
  dynamic_scheduler: bool = False,
@@ -782,57 +979,111 @@ def gemm_act_ref(
782
  B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
783
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
784
  bias: Optional[Tensor] = None, # (N,) or (L, N)
785
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
786
  cu_seqlens_m: Optional[Tensor] = None,
787
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
788
  out_dtype: Optional[torch.dtype] = None,
789
  postact_dtype: Optional[torch.dtype] = None,
790
  store_preact: bool = True,
 
791
  ) -> Tuple[Optional[Tensor], Tensor]:
 
792
  out_dtype = A.dtype if out_dtype is None else out_dtype
793
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
794
  if C is None:
795
- out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
 
 
796
  else:
797
- out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
798
- postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
799
- return out.to(out_dtype) if store_preact else None, postact
 
 
 
 
 
 
 
 
 
 
 
 
800
 
801
 
802
  def gemm_dact(
803
  A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
804
  B: Tensor, # (K, N) or (L, K, N)
805
- PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
806
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
807
- dx_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
 
 
808
  postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
809
  out_dtype: Optional[torch.dtype] = None,
810
  postact_dtype: Optional[torch.dtype] = None,
 
 
811
  cu_seqlens_m: Optional[Tensor] = None,
812
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
813
  dynamic_scheduler: bool = True,
814
  tuned: bool = True,
815
- ) -> Tuple[Tensor, Tensor]:
816
- """GEMM with activation gradient and optional output tensors."""
 
817
  out_dtype = A.dtype if out_dtype is None else out_dtype
818
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
819
  varlen_m = cu_seqlens_m is not None
820
- # Determine output shape based on gather_A
821
  if varlen_m:
822
  total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
823
- out_shape = (total_m, B.shape[-1])
824
  elif A.ndim == 2:
825
- out_shape = (A.shape[0], B.shape[-1])
826
  else:
827
- out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
 
 
828
  if dx_out is None:
829
  dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
830
  if postact_out is None:
831
- postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
832
- gemm_dact_out(
833
- A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
834
- )
835
- return dx_out, postact_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836
 
837
 
838
  @torch.library.custom_op(
@@ -847,7 +1098,7 @@ def gemm_dact_out(
847
  PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
848
  dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
849
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
850
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
851
  cu_seqlens_m: Optional[Tensor] = None,
852
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
853
  dynamic_scheduler: bool = True,
@@ -859,115 +1110,46 @@ def gemm_dact_out(
859
 
860
 
861
  def gemm_dact_ref(
862
- A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
863
- B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
864
- PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
865
- activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
866
  cu_seqlens_m: Optional[Tensor] = None,
867
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
868
  out_dtype: Optional[torch.dtype] = None,
869
  postact_dtype: Optional[torch.dtype] = None,
870
  ) -> Tuple[Tensor, Tensor]:
871
- """Reference implementation for GEMM with activation gradient."""
 
872
  out_dtype = A.dtype if out_dtype is None else out_dtype
873
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
874
  dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
875
- postact = act_to_pytorch_fn_map[activation](PreAct)
876
- # Compute gradient using autograd
877
- if activation is None:
878
- dx = dout
879
- else:
880
- PreAct_requires_grad = PreAct.requires_grad
881
- PreAct.requires_grad_(True)
882
- postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
883
- dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
884
- PreAct.requires_grad_(PreAct_requires_grad)
885
- return dx.to(out_dtype), postact.to(postact_dtype)
886
-
887
-
888
- def gemm_gated_ref(
889
- A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
890
- B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
891
- C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
892
- bias: Optional[Tensor] = None, # (N,) or (L, N)
893
- activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
894
- cu_seqlens_m: Optional[Tensor] = None,
895
- A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
896
- out_dtype: Optional[torch.dtype] = None,
897
- postact_dtype: Optional[torch.dtype] = None,
898
- store_preact: bool = True,
899
- ) -> Tuple[Optional[Tensor], Tensor]:
900
- """Reference implementation for GEMM with gated activation forward.
901
-
902
- Args:
903
- A: (M, K) - input tensor
904
- B: (K, N) - weight tensor with gate and up projections
905
- C: (M, N) - optional bias tensor
906
- activation: Type of gated activation
907
- out_dtype: Output dtype for preact
908
- postact_dtype: Output dtype for postact
909
- store_preact: Whether to return the pre-activation
910
-
911
- Returns:
912
- (preact, postact) where:
913
- - preact: (M, N) pre-activation (if store_preact=True, else None)
914
- - postact: (M, N // 2) post-activation output
915
- """
916
- out_dtype = A.dtype if out_dtype is None else out_dtype
917
- postact_dtype = A.dtype if postact_dtype is None else postact_dtype
918
- if C is None:
919
- preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
920
  else:
921
- preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
922
- # Split preact into gate and up projections
923
- gate = preact[..., ::2] # (M, N//2)
924
- up = preact[..., 1::2] # (M, N//2)
925
- postact = gated_to_pytorch_fn_map[activation](gate, up)
926
- return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
927
-
 
 
 
928
 
929
- def gemm_dgated_ref(
930
- A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
931
- B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
932
- PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
933
- activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
934
- cu_seqlens_m: Optional[Tensor] = None,
935
- A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
936
- out_dtype: Optional[torch.dtype] = None,
937
- postact_dtype: Optional[torch.dtype] = None,
938
- ) -> Tuple[Tensor, Tensor]:
939
- """Reference implementation for GEMM with gated activation gradient.
940
 
941
- Args:
942
- A: (M, K) - dout input tensor
943
- B: (K, N) - weight tensor
944
- PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
945
- activation: Type of gated activation
946
- out_dtype: Output dtype for dx
947
- postact_dtype: Output dtype for postact
948
-
949
- Returns:
950
- (dx, postact) where:
951
- - dx: (M, 2*N) gradient w.r.t. PreAct
952
- - postact: (M, N) post-activation output
953
- """
954
- out_dtype = A.dtype if out_dtype is None else out_dtype
955
- postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
956
- dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
957
- # Split PreAct into gate and up projections
958
- gate = PreAct[..., ::2] # (M, N)
959
- up = PreAct[..., 1::2] # (M, N)
960
- # Use autograd to compute gradients w.r.t. gate and up
961
- gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
962
- gate.requires_grad_(True)
963
- up.requires_grad_(True)
964
- postact = gated_to_pytorch_fn_map[activation](gate, up)
965
- dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
966
- gate.requires_grad_(gate_requires_grad)
967
- up.requires_grad_(up_requires_grad)
968
- # Interleave gradients back
969
- dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
970
- return dx.to(out_dtype), postact.to(postact_dtype)
971
 
972
 
973
  @torch.library.custom_op(
@@ -1000,18 +1182,27 @@ def gemm_symmetric_out(
1000
  tile_count_semaphore = (
1001
  torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
1002
  )
1003
- gemm_symmetric_sm90_sm100(
 
 
 
 
 
 
 
 
1004
  A,
1005
  B,
1006
  out if out is not None else None,
1007
  C if C is not None else None,
1008
  tile_count_semaphore,
1009
- tile_M=128,
1010
- tile_N=256,
1011
- cluster_M=2,
1012
  cluster_N=1,
1013
- pingpong=False,
1014
  persistent=True,
 
1015
  max_swizzle_size=8,
1016
  alpha=alpha,
1017
  beta=beta,
@@ -1047,6 +1238,933 @@ def gemm_symmetric(
1047
  return out
1048
 
1049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1050
  # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
1051
  # try:
1052
  # from torch._inductor.fx_passes.reinplace import InplaceableOp
 
3
  from functools import partial
4
 
5
  import torch
6
+ from ._ops_compat import add_quack_op_namespace_prefix
7
  import torch.nn.functional as F
8
  from torch import Tensor
 
9
 
10
  from .gemm_config import GemmConfig, get_all_configs
11
 
12
  from .autotuner import autotune, AutotuneConfig
13
  from .cute_dsl_utils import get_device_capacity
14
+ from .gemm import gemm as gemm_dispatch
15
+ from .gemm_act import gemm_act as gemm_act_dispatch
16
+ from .gemm_dact import gemm_dact as gemm_dact_dispatch
17
+ from .gemm_symmetric import gemm_symmetric as gemm_symmetric_dispatch
18
+ from .gemm_sq_reduce import gemm_sq_reduce as gemm_sq_reduce_dispatch
19
+ from .gemm_norm_act import gemm_norm_act_fn as gemm_norm_act_dispatch
20
+ from .rms_final_reduce import rms_final_reduce
21
+ from .rounding import RoundingMode
22
 
23
 
24
  # Dictionary mapping activation names to PyTorch functions
 
41
  }
42
 
43
 
44
+ ActActivation = Literal[None, "relu", "relu_sq", "gelu_tanh_approx"]
45
+ GatedActivation = Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"]
46
+ Activation = Literal[
47
+ None,
48
+ "relu",
49
+ "relu_sq",
50
+ "gelu_tanh_approx",
51
+ "swiglu",
52
+ "swiglu_oai",
53
+ "reglu",
54
+ "geglu",
55
+ "glu",
56
+ ]
57
 
58
 
59
+ def _concat_interleave(t):
60
+ """Interleave halves along non-contiguous dim: [first; second] [f0, s0, f1, ...]"""
61
+ dim = -2 if t.stride(-1) == 1 else -1
62
+ return t.unflatten(dim, (2, t.shape[dim] // 2)).transpose(dim - 1, dim).flatten(dim - 1, dim)
 
 
 
 
63
 
64
 
65
+ def _concat_interleave_bias(t):
66
+ """Interleave [gate; up] along last dim for bias vectors."""
67
+ half = t.shape[-1] // 2
68
+ return t.unflatten(-1, (2, half)).transpose(-2, -1).flatten(-2, -1)
69
 
70
 
71
  def default_config(device):
72
+ cap = get_device_capacity(device)[0]
73
+ if cap in [10, 11]:
74
+ return GemmConfig(
75
+ tile_m=256,
76
+ tile_n=256,
77
+ cluster_m=2,
78
+ cluster_n=1,
79
+ pingpong=False,
80
+ is_dynamic_persistent=True,
81
+ device_capacity=10,
82
+ )
83
+ elif cap == 12:
84
+ return GemmConfig(
85
+ tile_m=128,
86
+ tile_n=128,
87
+ cluster_m=1,
88
+ cluster_n=1,
89
+ pingpong=True,
90
+ is_dynamic_persistent=True,
91
+ device_capacity=12,
92
+ )
93
  else:
94
+ return GemmConfig(
95
+ tile_m=128,
96
+ tile_n=192,
97
+ cluster_m=2,
98
+ cluster_n=1,
99
+ pingpong=True,
100
+ is_dynamic_persistent=False,
101
+ )
102
+
103
+
104
+ def nvmmh_config(A, B, device_capacity):
105
+ """Use nvMatmulHeuristics to pick a config for pure GEMM (no varlen/gather/epilogue).
106
+
107
+ Returns None if unavailable, caller should fall back to default_config.
108
+ """
109
+ try:
110
+ from .nvmmh_heuristic import nvmmh_default_config
111
+
112
+ return nvmmh_default_config(A, B, device_capacity)
113
+ except Exception:
114
+ return None
115
 
116
 
117
  def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
118
  kwargs = named_args | kwargs
119
+ device_capacity = get_device_capacity(kwargs["A"].device)[0]
120
+ configs = [conf for conf in configs if conf.kwargs["config"].device_capacity == device_capacity]
121
  gather_A = kwargs.get("A_idx", None) is not None
122
  varlen_m = kwargs.get("cu_seqlens_m", None) is not None
123
  if varlen_m or gather_A: # Doesn't support swap_ab
124
  configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
125
  if gather_A:
126
+ configs = [conf for conf in configs if conf.kwargs["config"].cluster_n == 1]
127
+ if device_capacity == 9:
128
+ configs = [conf for conf in configs if conf.kwargs["config"].tile_n != 208]
129
+ configs = [conf for conf in configs if not conf.kwargs["config"].is_dynamic_persistent]
130
+ # use_tma_gather only valid when gather_A is active on SM100/SM110
131
+ if not gather_A or device_capacity not in [10, 11]:
132
+ configs = [conf for conf in configs if not conf.kwargs["config"].use_tma_gather]
133
  return configs
134
 
135
 
136
  @autotune(
137
+ configs=[AutotuneConfig(config=c) for c in get_all_configs()],
138
  key=["dynamic_scheduler"],
139
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
140
  )
 
154
  add_to_output: bool = False,
155
  dynamic_scheduler: bool = False,
156
  config: Optional[GemmConfig] = None,
157
+ rounding_mode: int = RoundingMode.RN,
158
+ sr_seed: int | Tensor = 0,
159
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
160
  ) -> None:
161
  if config is None:
162
+ # Use nvMMH heuristic for pure GEMM (no varlen, no gather, no epilogue)
163
+ is_pure_gemm = (
164
+ cu_seqlens_m is None
165
+ and cu_seqlens_k is None
166
+ and A_idx is None
167
+ and C is None
168
+ and bias is None
169
+ and not add_to_output
170
+ )
171
+ if is_pure_gemm:
172
+ device_capacity = get_device_capacity(A.device)[0]
173
+ config = nvmmh_config(A, B, device_capacity)
174
+ if config is None:
175
+ config = default_config(A.device)
176
  varlen_m = cu_seqlens_m is not None
177
  varlen_k = cu_seqlens_k is not None
178
  varlen = varlen_m or varlen_k
 
201
  else:
202
  out_shape = (batch_size, A.shape[-2], B.shape[-2])
203
  assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
204
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
205
  tile_count_semaphore = (
206
+ torch.zeros(1, dtype=torch.int32, device=A.device)
207
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
208
+ else None
209
+ )
210
+ # Handle bias concat layout: transform "bias" key to kernel-level key or permute data.
211
+ if concat_layout and "bias" in concat_layout:
212
+ if bias is not None and bias.dtype.itemsize >= 4:
213
+ # fp32: kernel permutes via layout; replace "bias" with the kernel-level key
214
+ concat_layout = tuple("mRowVecBroadcast" if k == "bias" else k for k in concat_layout)
215
+ else:
216
+ # No bias or sub-fp32: strip "bias" from concat_layout; permute data if needed
217
+ concat_layout = tuple(k for k in concat_layout if k != "bias")
218
+ if bias is not None:
219
+ bias = _concat_interleave_bias(bias)
220
+ # When swap_ab, A↔B (out/C stay, but .mT flips their strides so the kernel
221
+ # auto-detects the correct non-contiguous dim).
222
+ _swap_map = {"A": "B", "B": "A", "out": "out", "C": "C", "mRowVecBroadcast": "mColVecBroadcast"}
223
+ swapped_concat = (
224
+ tuple(_swap_map.get(k, k) for k in concat_layout)
225
+ if config.swap_ab and concat_layout
226
+ else concat_layout
227
  )
228
+ gemm_dispatch(
229
  A if not config.swap_ab else B,
230
  B if not config.swap_ab else A,
231
  out if not config.swap_ab else out.mT,
 
237
  config.cluster_n,
238
  config.pingpong,
239
  persistent=True,
240
+ is_dynamic_persistent=dynamic_scheduler,
241
  max_swizzle_size=config.max_swizzle_size,
242
  rowvec_bias=bias if not config.swap_ab else None,
243
  colvec_bias=bias if config.swap_ab else None,
 
248
  A_idx=A_idx,
249
  batch_idx_permute=batch_idx_permute,
250
  add_to_output=add_to_output,
251
+ rounding_mode=rounding_mode,
252
+ sr_seed=sr_seed,
253
+ use_tma_gather=config.use_tma_gather,
254
+ concat_layout=swapped_concat,
255
  )
256
 
257
 
258
  @autotune(
259
+ configs=[AutotuneConfig(config=c) for c in get_all_configs()],
260
  key=["activation", "dynamic_scheduler"],
261
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
262
  )
 
269
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
270
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
271
  bias: Optional[Tensor] = None, # (N,) or (L, N)
272
+ activation: ActActivation = None,
273
  cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
274
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
275
  dynamic_scheduler: bool = False,
 
297
  PostAct = postact_out
298
  if bias is not None and bias.ndim == 1:
299
  bias = bias.unsqueeze(0) # (L, N)
300
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
301
  tile_count_semaphore = (
302
+ torch.zeros(1, dtype=torch.int32, device=A.device)
303
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
304
+ else None
305
  )
306
+ gemm_act_dispatch(
307
  A if not config.swap_ab else B,
308
  B if not config.swap_ab else A,
309
  (D if not config.swap_ab else D.mT) if D is not None else None,
 
317
  config.cluster_n,
318
  config.pingpong,
319
  persistent=True,
320
+ is_dynamic_persistent=dynamic_scheduler,
321
  max_swizzle_size=config.max_swizzle_size,
322
  rowvec_bias=bias if not config.swap_ab else None,
323
  colvec_bias=bias if config.swap_ab else None,
324
  cu_seqlens_m=cu_seqlens_m,
325
  A_idx=A_idx,
326
+ use_tma_gather=config.use_tma_gather,
327
  )
328
 
329
 
330
  @autotune(
331
+ configs=[AutotuneConfig(config=c) for c in get_all_configs()],
332
  key=["activation", "dynamic_scheduler"],
333
  prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
334
  )
 
339
  PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
340
  dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
341
  postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m
342
+ activation: ActActivation = None,
343
  cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
344
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
345
  dynamic_scheduler: bool = True,
 
365
  PostAct = postact_out.unsqueeze(0)
366
  else:
367
  PostAct = postact_out
368
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
369
  tile_count_semaphore = (
370
+ torch.zeros(1, dtype=torch.int32, device=A.device)
371
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
372
+ else None
373
  )
374
+ gemm_dact_dispatch(
375
  A if not config.swap_ab else B,
376
  B if not config.swap_ab else A,
377
  D if not config.swap_ab else D.mT,
 
385
  config.cluster_n,
386
  config.pingpong,
387
  persistent=True,
388
+ is_dynamic_persistent=dynamic_scheduler,
389
  max_swizzle_size=config.max_swizzle_size,
390
  cu_seqlens_m=cu_seqlens_m,
391
  A_idx=A_idx,
392
+ use_tma_gather=config.use_tma_gather,
393
  )
394
 
395
 
 
407
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
408
  dynamic_scheduler: bool = False,
409
  tuned: bool = True,
410
+ rounding_mode: int = RoundingMode.RN,
411
+ sr_seed: int | Tensor = 0,
412
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
413
  ) -> Tensor:
414
  """GEMM with optional output tensor and tuning control."""
415
  if out is None:
 
430
  out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
431
  alpha_tensor = alpha if not isinstance(alpha, float) else None
432
  alpha = alpha if isinstance(alpha, float) else 1.0
433
+ sr_seed_tensor = sr_seed if isinstance(sr_seed, Tensor) else None
434
+ sr_seed_int = sr_seed if isinstance(sr_seed, int) else 0
435
+ concat_str = ",".join(concat_layout) if concat_layout else None
436
  gemm_out(
437
  A,
438
  B,
 
446
  batch_idx_permute=batch_idx_permute,
447
  dynamic_scheduler=dynamic_scheduler,
448
  tuned=tuned,
449
+ rounding_mode=rounding_mode,
450
+ sr_seed=sr_seed_int,
451
+ sr_seed_tensor=sr_seed_tensor,
452
+ concat_layout=concat_str,
453
  )
454
  return out
455
 
 
476
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
477
  dynamic_scheduler: bool = False,
478
  tuned: bool = True,
479
+ rounding_mode: int = RoundingMode.RN,
480
+ sr_seed: int = 0,
481
+ sr_seed_tensor: Optional[Tensor] = None,
482
+ concat_layout: Optional[str] = None,
483
  ) -> None:
484
  """GEMM with pre-allocated output tensor."""
485
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
486
  alpha = alpha_tensor if alpha_tensor is not None else alpha
487
+ sr_seed_arg = sr_seed_tensor if sr_seed_tensor is not None else sr_seed
488
  fn(
489
  A,
490
  B,
 
497
  A_idx=A_idx,
498
  batch_idx_permute=batch_idx_permute,
499
  dynamic_scheduler=dynamic_scheduler,
500
+ rounding_mode=rounding_mode,
501
+ sr_seed=sr_seed_arg,
502
+ concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
503
  )
504
 
505
 
 
514
  cu_seqlens_k: Optional[Tensor] = None,
515
  A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
516
  out_dtype: Optional[torch.dtype] = None,
517
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
518
  ) -> Tensor:
519
  """Reference implementation for GEMM with pre-allocated output."""
520
  # The out_dtype argument requires torch >= 2.8
521
  out_dtype = A.dtype if out_dtype is None else out_dtype
522
+ if concat_layout:
523
+ if "A" in concat_layout:
524
+ A = _concat_interleave(A)
525
+ if "B" in concat_layout:
526
+ B = _concat_interleave(B)
527
+ if "bias" in concat_layout and bias is not None:
528
+ bias = _concat_interleave_bias(bias)
529
  if cu_seqlens_m is None and cu_seqlens_k is None:
530
  fn = torch.bmm if A.ndim == 3 else torch.mm
531
  out = fn(A, B, out_dtype=out_dtype, out=out)
 
566
  out *= alpha
567
  if bias is not None:
568
  out += bias
569
+ if concat_layout and "out" in concat_layout:
570
+ # out is n-major (ref allocates contiguous). Split rows (non-contiguous dim).
571
+ out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2)
572
  return out
573
 
574
 
 
587
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
588
  dynamic_scheduler: bool = False,
589
  tuned: bool = True,
590
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
591
  ) -> Tensor:
592
  """GEMM with addition and optional output tensor."""
593
  if out is None:
 
612
  alpha = alpha if isinstance(alpha, float) else 1.0
613
  beta_tensor = beta if not isinstance(beta, float) else None
614
  beta = beta if isinstance(beta, float) else 1.0
615
+ alpha_arg = alpha_tensor if alpha_tensor is not None else alpha
616
+ beta_arg = beta_tensor if beta_tensor is not None else beta
617
+ concat_str = ",".join(concat_layout) if concat_layout else None
618
+ if add_to_output:
619
+ gemm_add_inplace(
620
+ A,
621
+ B,
622
+ out,
623
+ alpha=alpha_arg,
624
+ beta=beta_arg,
625
+ cu_seqlens_m=cu_seqlens_m,
626
+ cu_seqlens_k=cu_seqlens_k,
627
+ A_idx=A_idx,
628
+ batch_idx_permute=batch_idx_permute,
629
+ dynamic_scheduler=dynamic_scheduler,
630
+ tuned=tuned,
631
+ concat_layout=concat_str,
632
+ )
633
+ else:
634
+ gemm_add_out(
635
+ A,
636
+ B,
637
+ C,
638
+ out,
639
+ alpha,
640
+ beta,
641
+ alpha_tensor,
642
+ beta_tensor,
643
+ cu_seqlens_m=cu_seqlens_m,
644
+ cu_seqlens_k=cu_seqlens_k,
645
+ A_idx=A_idx,
646
+ batch_idx_permute=batch_idx_permute,
647
+ add_to_output=add_to_output,
648
+ dynamic_scheduler=dynamic_scheduler,
649
+ tuned=tuned,
650
+ concat_layout=concat_str,
651
+ )
652
  return out
653
 
654
 
 
677
  add_to_output: bool = False,
678
  dynamic_scheduler: bool = False,
679
  tuned: bool = True,
680
+ concat_layout: Optional[str] = None,
681
  ) -> None:
682
  """GEMM with addition and pre-allocated output tensor."""
683
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
 
696
  batch_idx_permute=batch_idx_permute,
697
  add_to_output=add_to_output,
698
  dynamic_scheduler=dynamic_scheduler,
699
+ concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
700
  )
701
 
702
 
 
713
  cu_seqlens_k: Optional[Tensor] = None,
714
  A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
715
  out_dtype: Optional[torch.dtype] = None,
716
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
717
  ) -> Tensor:
718
  """Reference implementation for GEMM with addition and pre-allocated output."""
719
+ if concat_layout:
720
+ if "A" in concat_layout:
721
+ A = _concat_interleave(A)
722
+ if "B" in concat_layout:
723
+ B = _concat_interleave(B)
724
+ if "bias" in concat_layout and bias is not None:
725
+ bias = _concat_interleave_bias(bias)
726
+ if "C" in concat_layout:
727
+ C = _concat_interleave(C)
728
  if cu_seqlens_m is None and cu_seqlens_k is None:
729
  if isinstance(alpha, float) and isinstance(beta, float):
730
  out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
 
735
  result = (alpha * (A @ B) + beta * C).to(out_dtype)
736
  if out is not None:
737
  out.copy_(result)
738
+ else:
739
+ out = result
740
  if bias is not None:
741
  bias = bias if A.ndim == 2 else bias.unsqueeze(1)
742
  out += bias
 
776
  out[i].copy_(result)
777
  if bias is not None:
778
  out += bias
779
+ if concat_layout and "out" in concat_layout:
780
+ out = torch.cat([out[..., ::2, :], out[..., 1::2, :]], dim=-2)
781
  return out
782
 
783
 
 
794
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
795
  dynamic_scheduler: bool = False,
796
  tuned: bool = True,
797
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
798
  ) -> None:
799
  """In-place GEMM with addition: out = alpha * A @ B + beta * out.
800
  Args:
 
826
  batch_idx_permute=batch_idx_permute,
827
  dynamic_scheduler=dynamic_scheduler,
828
  tuned=tuned,
829
+ concat_layout=",".join(concat_layout)
830
+ if isinstance(concat_layout, tuple)
831
+ else concat_layout,
832
  )
833
 
834
 
 
855
  batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
856
  dynamic_scheduler: bool = False,
857
  tuned: bool = True,
858
+ concat_layout: Optional[str] = None,
859
  ) -> None:
860
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
861
  alpha = alpha_tensor if alpha_tensor is not None else alpha
 
875
  batch_idx_permute=batch_idx_permute,
876
  add_to_output=add_to_output,
877
  dynamic_scheduler=dynamic_scheduler,
878
+ concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
879
  )
880
 
881
 
 
884
  B: Tensor, # (K, N) or (L, K, N)
885
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
886
  bias: Optional[Tensor] = None, # (N,) or (L, N)
887
+ activation: Activation = None,
888
  preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
889
  postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
890
  out_dtype: Optional[torch.dtype] = None,
 
894
  store_preact: bool = True,
895
  dynamic_scheduler: bool = False,
896
  tuned: bool = True,
897
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
898
  ) -> Tuple[Optional[Tensor], Tensor]:
899
+ """GEMM with activation (or gated activation) and optional output tensors."""
900
+ is_gated = activation in gated_to_pytorch_fn_map
901
  out_dtype = A.dtype if out_dtype is None else out_dtype
902
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
903
  varlen_m = cu_seqlens_m is not None
 
909
  out_shape = (A.shape[0], B.shape[-1])
910
  else:
911
  out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
912
+ postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape
913
  if preact_out is None and store_preact:
914
  preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
915
  if postact_out is None:
916
+ postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
917
+ concat_str = ",".join(concat_layout) if concat_layout else None
918
+ if is_gated:
919
+ gemm_gated_out(
920
+ A,
921
+ B,
922
+ preact_out,
923
+ postact_out,
924
+ C,
925
+ bias,
926
+ activation,
927
+ cu_seqlens_m,
928
+ A_idx,
929
+ dynamic_scheduler,
930
+ tuned,
931
+ concat_layout=concat_str,
932
+ )
933
+ else:
934
+ gemm_act_out(
935
+ A,
936
+ B,
937
+ preact_out,
938
+ postact_out,
939
+ C,
940
+ bias,
941
+ activation,
942
+ cu_seqlens_m,
943
+ A_idx,
944
+ dynamic_scheduler,
945
+ tuned,
946
+ )
947
  return preact_out, postact_out
948
 
949
 
950
+ gemm_gated = gemm_act
951
+
952
+
953
  @torch.library.custom_op(
954
  add_quack_op_namespace_prefix("gemm_act_out"),
955
  mutates_args=("preact_out", "postact_out"),
 
963
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
964
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
965
  bias: Optional[Tensor] = None, # (N,) or (L, N)
966
+ activation: ActActivation = None,
967
  cu_seqlens_m: Optional[Tensor] = None,
968
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
969
  dynamic_scheduler: bool = False,
 
979
  B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
980
  C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
981
  bias: Optional[Tensor] = None, # (N,) or (L, N)
982
+ activation: Activation = None,
983
  cu_seqlens_m: Optional[Tensor] = None,
984
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
985
  out_dtype: Optional[torch.dtype] = None,
986
  postact_dtype: Optional[torch.dtype] = None,
987
  store_preact: bool = True,
988
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
989
  ) -> Tuple[Optional[Tensor], Tensor]:
990
+ is_gated = activation in gated_to_pytorch_fn_map
991
  out_dtype = A.dtype if out_dtype is None else out_dtype
992
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
993
  if C is None:
994
+ preact = gemm_ref(
995
+ A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout
996
+ )
997
  else:
998
+ preact = gemm_add_ref(
999
+ A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, concat_layout=concat_layout
1000
+ )
1001
+ if is_gated:
1002
+ # With concat=("B",), gemm_ref already interleaves the output columns,
1003
+ # so we always use the interleaved gate/up split.
1004
+ gate = preact[..., ::2]
1005
+ up = preact[..., 1::2]
1006
+ postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype)
1007
+ else:
1008
+ postact = act_to_pytorch_fn_map[activation](preact).to(postact_dtype)
1009
+ return preact.to(out_dtype) if store_preact else None, postact
1010
+
1011
+
1012
+ gemm_gated_ref = gemm_act_ref
1013
 
1014
 
1015
  def gemm_dact(
1016
  A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
1017
  B: Tensor, # (K, N) or (L, K, N)
1018
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; or (M, 2*N) for dgated
1019
+ activation: Activation = None,
1020
+ dx_out: Optional[
1021
+ Tensor
1022
+ ] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m; double for gated
1023
  postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1024
  out_dtype: Optional[torch.dtype] = None,
1025
  postact_dtype: Optional[torch.dtype] = None,
1026
+ colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m (dgated only)
1027
+ colvec_reduce: bool = False, # dgated only
1028
  cu_seqlens_m: Optional[Tensor] = None,
1029
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1030
  dynamic_scheduler: bool = True,
1031
  tuned: bool = True,
1032
+ ):
1033
+ """GEMM with activation (or gated activation) gradient and optional output tensors."""
1034
+ is_dgated = activation in gated_to_pytorch_fn_map
1035
  out_dtype = A.dtype if out_dtype is None else out_dtype
1036
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
1037
  varlen_m = cu_seqlens_m is not None
 
1038
  if varlen_m:
1039
  total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
1040
+ out_shape = (total_m, B.shape[-1] * 2) if is_dgated else (total_m, B.shape[-1])
1041
  elif A.ndim == 2:
1042
+ out_shape = (A.shape[0], B.shape[-1] * 2) if is_dgated else (A.shape[0], B.shape[-1])
1043
  else:
1044
+ n = B.shape[-1] * 2 if is_dgated else B.shape[-1]
1045
+ out_shape = (A.shape[0], A.shape[-2], n)
1046
+ postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_dgated else out_shape
1047
  if dx_out is None:
1048
  dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
1049
  if postact_out is None:
1050
+ postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
1051
+ if is_dgated:
1052
+ colvec_reduce_final = gemm_dgated_out(
1053
+ A,
1054
+ B,
1055
+ PreAct,
1056
+ dx_out,
1057
+ postact_out,
1058
+ colvec_scale,
1059
+ activation,
1060
+ colvec_reduce,
1061
+ cu_seqlens_m,
1062
+ A_idx,
1063
+ dynamic_scheduler,
1064
+ tuned,
1065
+ )
1066
+ if not colvec_reduce:
1067
+ return dx_out, postact_out
1068
+ else:
1069
+ return dx_out, postact_out, colvec_reduce_final
1070
+ else:
1071
+ gemm_dact_out(
1072
+ A,
1073
+ B,
1074
+ PreAct,
1075
+ dx_out,
1076
+ postact_out,
1077
+ activation,
1078
+ cu_seqlens_m,
1079
+ A_idx,
1080
+ dynamic_scheduler,
1081
+ tuned,
1082
+ )
1083
+ return dx_out, postact_out
1084
+
1085
+
1086
+ gemm_dgated = gemm_dact
1087
 
1088
 
1089
  @torch.library.custom_op(
 
1098
  PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1099
  dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1100
  postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1101
+ activation: ActActivation = None,
1102
  cu_seqlens_m: Optional[Tensor] = None,
1103
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1104
  dynamic_scheduler: bool = True,
 
1110
 
1111
 
1112
  def gemm_dact_ref(
1113
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A
1114
+ B: Tensor, # (K, N) or (L, K, N)
1115
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N); or (M, 2*N) for dgated
1116
+ activation: Activation = None,
1117
  cu_seqlens_m: Optional[Tensor] = None,
1118
  A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1119
  out_dtype: Optional[torch.dtype] = None,
1120
  postact_dtype: Optional[torch.dtype] = None,
1121
  ) -> Tuple[Tensor, Tensor]:
1122
+ """Reference implementation for GEMM with activation (or gated activation) gradient."""
1123
+ is_dgated = activation in gated_to_pytorch_fn_map
1124
  out_dtype = A.dtype if out_dtype is None else out_dtype
1125
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
1126
  dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
1127
+ if is_dgated:
1128
+ gate = PreAct[..., ::2]
1129
+ up = PreAct[..., 1::2]
1130
+ gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
1131
+ gate.requires_grad_(True)
1132
+ up.requires_grad_(True)
1133
+ postact = gated_to_pytorch_fn_map[activation](gate, up)
1134
+ dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
1135
+ gate.requires_grad_(gate_requires_grad)
1136
+ up.requires_grad_(up_requires_grad)
1137
+ dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
1138
+ return dx.to(out_dtype), postact.to(postact_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1139
  else:
1140
+ postact = act_to_pytorch_fn_map[activation](PreAct)
1141
+ if activation is None:
1142
+ dx = dout
1143
+ else:
1144
+ PreAct_requires_grad = PreAct.requires_grad
1145
+ PreAct.requires_grad_(True)
1146
+ postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
1147
+ dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
1148
+ PreAct.requires_grad_(PreAct_requires_grad)
1149
+ return dx.to(out_dtype), postact.to(postact_dtype)
1150
 
 
 
 
 
 
 
 
 
 
 
 
1151
 
1152
+ gemm_dgated_ref = gemm_dact_ref
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1153
 
1154
 
1155
  @torch.library.custom_op(
 
1182
  tile_count_semaphore = (
1183
  torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
1184
  )
1185
+ sm = get_device_capacity(A.device)[0]
1186
+ # We want square tile per cluster
1187
+ tile_m, tile_n, cluster_m, pingpong = {
1188
+ 9: (128, 256, 2, False),
1189
+ 10: (256, 256, 2, False),
1190
+ 11: (256, 256, 2, False),
1191
+ 12: (128, 128, 1, True),
1192
+ }[sm]
1193
+ gemm_symmetric_dispatch(
1194
  A,
1195
  B,
1196
  out if out is not None else None,
1197
  C if C is not None else None,
1198
  tile_count_semaphore,
1199
+ tile_M=tile_m,
1200
+ tile_N=tile_n,
1201
+ cluster_M=cluster_m,
1202
  cluster_N=1,
1203
+ pingpong=pingpong,
1204
  persistent=True,
1205
+ is_dynamic_persistent=sm >= 10,
1206
  max_swizzle_size=8,
1207
  alpha=alpha,
1208
  beta=beta,
 
1238
  return out
1239
 
1240
 
1241
+ @autotune(
1242
+ configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")],
1243
+ key=["activation", "dynamic_scheduler"],
1244
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
1245
+ )
1246
+ def gemm_gated_tuned(
1247
+ # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
1248
+ A: Tensor,
1249
+ B: Tensor, # (K, N) or (L, K, N)
1250
+ # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
1251
+ preact_out: Optional[Tensor],
1252
+ postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
1253
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1254
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
1255
+ activation: GatedActivation = "swiglu",
1256
+ cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
1257
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1258
+ dynamic_scheduler: bool = False,
1259
+ config: Optional[GemmConfig] = None,
1260
+ concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up]
1261
+ ) -> None:
1262
+ if config is None:
1263
+ config = default_config(A.device)
1264
+ varlen_m = cu_seqlens_m is not None
1265
+ if varlen_m:
1266
+ assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
1267
+ if A.ndim == 2 and not varlen_m:
1268
+ A = A.unsqueeze(0) # (1, M, K)
1269
+ B = B.mT # (N, K) or (L, N, K)
1270
+ if B.ndim == 2:
1271
+ B = B.unsqueeze(0) # (1, N, K)
1272
+ if C is not None and C.ndim == 2 and not varlen_m:
1273
+ C = C.unsqueeze(0) # (1, M, N)
1274
+ if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
1275
+ D = preact_out.unsqueeze(0)
1276
+ else:
1277
+ D = preact_out
1278
+ if postact_out.ndim == 2 and not varlen_m:
1279
+ PostAct = postact_out.unsqueeze(0)
1280
+ else:
1281
+ PostAct = postact_out
1282
+ if bias is not None and bias.ndim == 1:
1283
+ bias = bias.unsqueeze(0) # (L, N)
1284
+ if concat_layout and "bias" in concat_layout:
1285
+ if bias is not None and bias.dtype.itemsize >= 4:
1286
+ bias_key = "mColVecBroadcast" if config.swap_ab else "mRowVecBroadcast"
1287
+ concat_layout = tuple(bias_key if k == "bias" else k for k in concat_layout)
1288
+ else:
1289
+ concat_layout = tuple(k for k in concat_layout if k != "bias")
1290
+ if bias is not None:
1291
+ bias = _concat_interleave_bias(bias)
1292
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
1293
+ tile_count_semaphore = (
1294
+ torch.zeros(1, dtype=torch.int32, device=A.device)
1295
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
1296
+ else None
1297
+ )
1298
+ gemm_act_dispatch(
1299
+ A if not config.swap_ab else B,
1300
+ B if not config.swap_ab else A,
1301
+ (D if not config.swap_ab else D.mT) if D is not None else None,
1302
+ (C if not config.swap_ab else C.mT) if C is not None else None,
1303
+ PostAct if not config.swap_ab else PostAct.mT,
1304
+ tile_count_semaphore,
1305
+ activation,
1306
+ config.tile_m,
1307
+ config.tile_n,
1308
+ config.cluster_m,
1309
+ config.cluster_n,
1310
+ config.pingpong,
1311
+ persistent=True,
1312
+ is_dynamic_persistent=dynamic_scheduler,
1313
+ max_swizzle_size=config.max_swizzle_size,
1314
+ rowvec_bias=bias if not config.swap_ab else None,
1315
+ colvec_bias=bias if config.swap_ab else None,
1316
+ cu_seqlens_m=cu_seqlens_m,
1317
+ A_idx=A_idx,
1318
+ use_tma_gather=config.use_tma_gather,
1319
+ concat_layout=concat_layout,
1320
+ )
1321
+
1322
+
1323
+ def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs):
1324
+ kwargs = named_args | kwargs
1325
+ # if there's colvec_scale or colvec_reduce, don't swap_AB
1326
+ if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False):
1327
+ configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
1328
+ return prune_invalid_gemm_configs(configs, named_args, **kwargs)
1329
+
1330
+
1331
+ @autotune(
1332
+ configs=[AutotuneConfig(config=c) for c in get_all_configs("dgated")],
1333
+ key=["activation", "colvec_reduce", "dynamic_scheduler"],
1334
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs},
1335
+ )
1336
+ def gemm_dgated_tuned(
1337
+ # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
1338
+ A: Tensor,
1339
+ B: Tensor, # (K, N) or (L, K, N)
1340
+ PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
1341
+ dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
1342
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1343
+ colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
1344
+ activation: GatedActivation = "swiglu",
1345
+ # whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m
1346
+ colvec_reduce: bool = False,
1347
+ cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
1348
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1349
+ dynamic_scheduler: bool = True,
1350
+ config: Optional[GemmConfig] = None,
1351
+ ) -> Optional[Tensor]:
1352
+ if config is None:
1353
+ config = default_config(A.device)
1354
+ varlen_m = cu_seqlens_m is not None
1355
+ if varlen_m:
1356
+ assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
1357
+ og_ndim_2 = A.ndim == 2 and not varlen_m
1358
+ if A.ndim == 2 and not varlen_m:
1359
+ A = A.unsqueeze(0) # (1, M, K)
1360
+ B = B.mT # (N, K) or (L, N, K)
1361
+ if B.ndim == 2:
1362
+ B = B.unsqueeze(0) # (1, N, K)
1363
+ if PreAct.ndim == 2 and not varlen_m:
1364
+ PreAct = PreAct.unsqueeze(0) # (1, M, 2*N)
1365
+ if dx_out.ndim == 2 and not varlen_m:
1366
+ D = dx_out.unsqueeze(0)
1367
+ else:
1368
+ D = dx_out
1369
+ if postact_out.ndim == 2 and not varlen_m:
1370
+ PostAct = postact_out.unsqueeze(0)
1371
+ else:
1372
+ PostAct = postact_out
1373
+ if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m:
1374
+ colvec_scale = colvec_scale.unsqueeze(0) # (L, N)
1375
+ if colvec_scale is not None:
1376
+ assert not config.swap_ab, "colvec_scale not supported with swap_ab"
1377
+ if colvec_reduce:
1378
+ tile_n = config.tile_n
1379
+ shape_n = (B.shape[-2] + tile_n - 1) // tile_n
1380
+ if varlen_m:
1381
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
1382
+ colvec_shape = (total_m, shape_n)
1383
+ else:
1384
+ colvec_shape = (A.shape[0], A.shape[-2], shape_n)
1385
+ colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device)
1386
+ else:
1387
+ colvec_reduce_partial = None
1388
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
1389
+ tile_count_semaphore = (
1390
+ torch.zeros(1, dtype=torch.int32, device=A.device)
1391
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
1392
+ else None
1393
+ )
1394
+ gemm_dact_dispatch(
1395
+ A if not config.swap_ab else B,
1396
+ B if not config.swap_ab else A,
1397
+ D if not config.swap_ab else D.mT,
1398
+ PreAct if not config.swap_ab else PreAct.mT,
1399
+ PostAct if not config.swap_ab else PostAct.mT,
1400
+ tile_count_semaphore,
1401
+ activation,
1402
+ config.tile_m,
1403
+ config.tile_n,
1404
+ config.cluster_m,
1405
+ config.cluster_n,
1406
+ config.pingpong,
1407
+ persistent=True,
1408
+ is_dynamic_persistent=dynamic_scheduler,
1409
+ max_swizzle_size=config.max_swizzle_size,
1410
+ colvec_scale=colvec_scale,
1411
+ colvec_reduce=colvec_reduce_partial,
1412
+ cu_seqlens_m=cu_seqlens_m,
1413
+ A_idx=A_idx,
1414
+ use_tma_gather=config.use_tma_gather,
1415
+ )
1416
+ if colvec_reduce:
1417
+ colvec_reduce_final = colvec_reduce_partial.sum(dim=-1)
1418
+ if og_ndim_2:
1419
+ colvec_reduce_final = colvec_reduce_final.squeeze(0)
1420
+ else:
1421
+ colvec_reduce_final = None
1422
+ return colvec_reduce_final
1423
+
1424
+
1425
+ @torch.library.custom_op(
1426
+ add_quack_op_namespace_prefix("gemm_gated_out"),
1427
+ mutates_args=("preact_out", "postact_out"),
1428
+ device_types="cuda",
1429
+ schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True, str? concat_layout=None) -> ()",
1430
+ )
1431
+ def gemm_gated_out(
1432
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
1433
+ B: Tensor, # (K, N) or (L, K, N)
1434
+ preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1435
+ postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
1436
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1437
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
1438
+ activation: GatedActivation = "swiglu",
1439
+ cu_seqlens_m: Optional[Tensor] = None,
1440
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1441
+ dynamic_scheduler: bool = False,
1442
+ tuned: bool = True,
1443
+ concat_layout: Optional[str] = None,
1444
+ ) -> None:
1445
+ """GEMM with gated activation and pre-allocated output tensors."""
1446
+ fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None)
1447
+ fn(
1448
+ A,
1449
+ B,
1450
+ preact_out,
1451
+ postact_out,
1452
+ C,
1453
+ bias,
1454
+ activation,
1455
+ cu_seqlens_m,
1456
+ A_idx,
1457
+ dynamic_scheduler,
1458
+ concat_layout=tuple(concat_layout.split(",")) if concat_layout else None,
1459
+ )
1460
+
1461
+
1462
+ @torch.library.custom_op(
1463
+ add_quack_op_namespace_prefix("gemm_dgated_out"),
1464
+ mutates_args=("dx_out", "postact_out"),
1465
+ device_types="cuda",
1466
+ schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a!) dx_out, Tensor(b!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor",
1467
+ )
1468
+ def gemm_dgated_out(
1469
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
1470
+ B: Tensor, # (K, N) or (L, K, N)
1471
+ PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
1472
+ dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
1473
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
1474
+ colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
1475
+ activation: GatedActivation = "swiglu",
1476
+ colvec_reduce: bool = False,
1477
+ cu_seqlens_m: Optional[Tensor] = None,
1478
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
1479
+ dynamic_scheduler: bool = True,
1480
+ tuned: bool = True,
1481
+ ) -> Tensor:
1482
+ """GEMM with gated activation gradient and pre-allocated output tensors."""
1483
+ fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None)
1484
+ result = fn(
1485
+ A,
1486
+ B,
1487
+ PreAct,
1488
+ dx_out,
1489
+ postact_out,
1490
+ colvec_scale,
1491
+ activation,
1492
+ colvec_reduce,
1493
+ cu_seqlens_m,
1494
+ A_idx,
1495
+ dynamic_scheduler,
1496
+ )
1497
+ if result is None: # Have to return a tensor, not None, to make torch compile happy
1498
+ return torch.empty(0, device=A.device, dtype=torch.float32)
1499
+ return result
1500
+
1501
+
1502
+ @torch.library.register_fake(add_quack_op_namespace_prefix("gemm_dgated_out"))
1503
+ def gemm_dgated_out_fake(
1504
+ A: Tensor,
1505
+ B: Tensor,
1506
+ PreAct: Tensor,
1507
+ dx_out: Tensor,
1508
+ postact_out: Tensor,
1509
+ colvec_scale: Optional[Tensor] = None,
1510
+ activation: str = "swiglu",
1511
+ colvec_reduce: bool = False,
1512
+ cu_seqlens_m: Optional[Tensor] = None,
1513
+ A_idx: Optional[Tensor] = None,
1514
+ dynamic_scheduler: bool = True,
1515
+ tuned: bool = True,
1516
+ ) -> Tensor:
1517
+ _precompile_default_config(
1518
+ gemm_dgated_tuned,
1519
+ A,
1520
+ B,
1521
+ PreAct,
1522
+ dx_out,
1523
+ postact_out,
1524
+ colvec_scale=colvec_scale,
1525
+ activation=activation,
1526
+ colvec_reduce=colvec_reduce,
1527
+ cu_seqlens_m=cu_seqlens_m,
1528
+ A_idx=A_idx,
1529
+ dynamic_scheduler=dynamic_scheduler,
1530
+ )
1531
+ if not colvec_reduce:
1532
+ return torch.empty(0, dtype=torch.float32, device=A.device)
1533
+ else:
1534
+ if cu_seqlens_m is not None:
1535
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
1536
+ out_shape = (total_m,)
1537
+ elif A.ndim == 2:
1538
+ out_shape = (A.shape[0],)
1539
+ else:
1540
+ out_shape = (A.shape[0], A.shape[-2])
1541
+ return torch.empty(out_shape, dtype=torch.float32, device=A.device)
1542
+
1543
+
1544
+ def _precompile_default_config(autotuned_fn, *args, **kwargs):
1545
+ """Compile the default config in COMPILE_ONLY mode.
1546
+
1547
+ Checks COMPILE_ONLY flag and SymInt guard, then calls the unwrapped function with
1548
+ config=None (which selects the default config), triggering compilation (exports .o)
1549
+ without benchmarking or kernel launch.
1550
+ Tests use tuned=False which also selects the default config, so this is sufficient.
1551
+ """
1552
+ from .cache_utils import COMPILE_ONLY
1553
+
1554
+ A = args[0] if args else kwargs.get("A")
1555
+ if not COMPILE_ONLY or A is None or isinstance(A.shape[0], torch.SymInt):
1556
+ return
1557
+ try:
1558
+ autotuned_fn.fn(*args, config=None, **kwargs)
1559
+ except Exception:
1560
+ pass
1561
+
1562
+
1563
+ @gemm_add_inplace_op.register_fake
1564
+ def gemm_add_inplace_fake(
1565
+ A: Tensor,
1566
+ B: Tensor,
1567
+ out: Tensor,
1568
+ alpha: float = 1.0,
1569
+ beta: float = 1.0,
1570
+ alpha_tensor: Optional[Tensor] = None,
1571
+ beta_tensor: Optional[Tensor] = None,
1572
+ cu_seqlens_m: Optional[Tensor] = None,
1573
+ cu_seqlens_k: Optional[Tensor] = None,
1574
+ A_idx: Optional[Tensor] = None,
1575
+ batch_idx_permute: Optional[Tensor] = None,
1576
+ dynamic_scheduler: bool = False,
1577
+ tuned: bool = True,
1578
+ ) -> None:
1579
+ alpha_val = alpha_tensor if alpha_tensor is not None else alpha
1580
+ beta_val = beta_tensor if beta_tensor is not None else beta
1581
+ add_to_output = isinstance(beta_val, float) and beta_val == 1.0 and cu_seqlens_m is None
1582
+ _precompile_default_config(
1583
+ gemm_tuned,
1584
+ A,
1585
+ B,
1586
+ out,
1587
+ out if not add_to_output else None,
1588
+ alpha=alpha_val,
1589
+ beta=beta_val,
1590
+ cu_seqlens_m=cu_seqlens_m,
1591
+ cu_seqlens_k=cu_seqlens_k,
1592
+ A_idx=A_idx,
1593
+ batch_idx_permute=batch_idx_permute,
1594
+ add_to_output=add_to_output,
1595
+ dynamic_scheduler=dynamic_scheduler,
1596
+ )
1597
+
1598
+
1599
+ def _register_precompile_fake(custom_op, autotuned_fn, rewrite=None):
1600
+ """Register a fake that precompiles the default config in COMPILE_ONLY mode.
1601
+
1602
+ For custom_ops that forward args to their autotuned fn. Binds all args by name,
1603
+ strips 'tuned', applies optional rewrite(kw), then calls _precompile_default_config.
1604
+ PyTorch normalizes all custom_op args to positional, so we use inspect.signature
1605
+ to recover keyword names.
1606
+ """
1607
+ import inspect
1608
+
1609
+ sig = inspect.signature(custom_op._init_fn)
1610
+
1611
+ @custom_op.register_fake
1612
+ def _fake(*args, **kwargs):
1613
+ bound = sig.bind(*args, **kwargs)
1614
+ bound.apply_defaults()
1615
+ kw = dict(bound.arguments)
1616
+ kw.pop("tuned", None)
1617
+ if rewrite is not None:
1618
+ rewrite(kw)
1619
+ _precompile_default_config(autotuned_fn, **kw)
1620
+
1621
+
1622
+ def _rewrite_merge_alpha(kwargs):
1623
+ """Merge alpha_tensor into alpha for gemm_tuned; add C=None."""
1624
+ at = kwargs.pop("alpha_tensor", None)
1625
+ if at is not None:
1626
+ kwargs["alpha"] = at
1627
+ kwargs.setdefault("C", None)
1628
+
1629
+
1630
+ def _rewrite_merge_alpha_beta(kwargs):
1631
+ """Merge alpha_tensor/beta_tensor into alpha/beta for gemm_tuned."""
1632
+ at = kwargs.pop("alpha_tensor", None)
1633
+ if at is not None:
1634
+ kwargs["alpha"] = at
1635
+ bt = kwargs.pop("beta_tensor", None)
1636
+ if bt is not None:
1637
+ kwargs["beta"] = bt
1638
+
1639
+
1640
+ _register_precompile_fake(gemm_out, gemm_tuned, rewrite=_rewrite_merge_alpha)
1641
+ _register_precompile_fake(gemm_add_out, gemm_tuned, rewrite=_rewrite_merge_alpha_beta)
1642
+ _register_precompile_fake(gemm_act_out, gemm_act_tuned)
1643
+ _register_precompile_fake(gemm_dact_out, gemm_dact_tuned)
1644
+ _register_precompile_fake(gemm_gated_out, gemm_gated_tuned)
1645
+
1646
+
1647
+ @gemm_symmetric_out.register_fake
1648
+ def gemm_symmetric_out_fake(
1649
+ A: Tensor,
1650
+ B: Tensor,
1651
+ out: Tensor,
1652
+ C: Optional[Tensor] = None,
1653
+ dynamic_scheduler: bool = False,
1654
+ alpha: float = 1.0,
1655
+ beta: float = 1.0,
1656
+ ) -> None:
1657
+ from .cache_utils import COMPILE_ONLY
1658
+
1659
+ if not COMPILE_ONLY or isinstance(A.shape[0], torch.SymInt):
1660
+ return
1661
+ # gemm_symmetric is not autotuned, compile the single fixed config directly
1662
+ sm = get_device_capacity(A.device)[0]
1663
+ tile_m = 256 if sm == 10 else 128
1664
+ tile_n = 128 if sm == 12 else 256
1665
+ cluster_m = 1 if sm == 12 else 2
1666
+ try:
1667
+ gemm_symmetric_dispatch(
1668
+ A.unsqueeze(0) if A.ndim == 2 else A,
1669
+ (B.mT.unsqueeze(0) if B.ndim == 2 else B.mT),
1670
+ out.unsqueeze(0) if out.ndim == 2 else out,
1671
+ (C.unsqueeze(0) if C.ndim == 2 else C) if C is not None else None,
1672
+ torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None,
1673
+ tile_M=tile_m,
1674
+ tile_N=tile_n,
1675
+ cluster_M=cluster_m,
1676
+ cluster_N=1,
1677
+ pingpong=False,
1678
+ persistent=True,
1679
+ max_swizzle_size=8,
1680
+ alpha=alpha,
1681
+ beta=beta,
1682
+ )
1683
+ except Exception:
1684
+ pass
1685
+
1686
+
1687
+ ## ── gemm_rms ────────────────────────────────────────────────────────────────
1688
+
1689
+
1690
+ def _prune_gemm_rms_configs(configs, named_args: dict, **kwargs):
1691
+ """ColVecReduce requires no swap_ab."""
1692
+ configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
1693
+ return prune_invalid_gemm_configs(configs, named_args | kwargs)
1694
+
1695
+
1696
+ @autotune(
1697
+ configs=[AutotuneConfig(config=c) for c in get_all_configs()],
1698
+ key=["dynamic_scheduler"],
1699
+ prune_configs_by={"early_config_prune": _prune_gemm_rms_configs},
1700
+ )
1701
+ def _gemm_rms_tuned(
1702
+ A: Tensor, # (M, K) or (L, M, K)
1703
+ B: Tensor, # (K, N) or (L, K, N)
1704
+ out: Tensor, # (M, N) or (L, M, N)
1705
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N)
1706
+ norm_weight: Optional[Tensor] = None, # (N,) or (L, N)
1707
+ eps: float = 1e-6,
1708
+ dynamic_scheduler: bool = False,
1709
+ config: Optional[GemmConfig] = None,
1710
+ ) -> Tensor:
1711
+ if config is None:
1712
+ config = default_config(A.device)
1713
+ og_ndim_2 = A.ndim == 2
1714
+ N = B.shape[-1]
1715
+ if A.ndim == 2:
1716
+ A = A.unsqueeze(0)
1717
+ B = B.mT
1718
+ if B.ndim == 2:
1719
+ B = B.unsqueeze(0)
1720
+ if out.ndim == 2:
1721
+ out = out.unsqueeze(0)
1722
+ if C is not None and C.ndim == 2:
1723
+ C = C.unsqueeze(0)
1724
+ if norm_weight is not None and norm_weight.ndim == 1:
1725
+ norm_weight = norm_weight.unsqueeze(0) # (L, N)
1726
+ # Allocate partial reduction buffer
1727
+ tile_n = config.tile_n
1728
+ n_tiles = (N + tile_n - 1) // tile_n
1729
+ colvec_reduce = torch.empty(
1730
+ (A.shape[0], A.shape[1], n_tiles), dtype=torch.float32, device=A.device
1731
+ )
1732
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
1733
+ tile_count_semaphore = (
1734
+ torch.zeros(1, dtype=torch.int32, device=A.device)
1735
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
1736
+ else None
1737
+ )
1738
+ gemm_sq_reduce_dispatch(
1739
+ A,
1740
+ B,
1741
+ out,
1742
+ C,
1743
+ colvec_reduce,
1744
+ tile_count_semaphore,
1745
+ config.tile_m,
1746
+ config.tile_n,
1747
+ config.cluster_m,
1748
+ config.cluster_n,
1749
+ config.pingpong,
1750
+ persistent=True,
1751
+ is_dynamic_persistent=dynamic_scheduler,
1752
+ max_swizzle_size=config.max_swizzle_size,
1753
+ rowvec=norm_weight,
1754
+ )
1755
+ # Final reduction: rstd = rsqrt(sum(partials) / N + eps)
1756
+ scale = 1.0 / N
1757
+ flat_reduce = colvec_reduce.reshape(-1, n_tiles)
1758
+ rstd_flat = rms_final_reduce(flat_reduce, scale=scale, eps=eps)
1759
+ rstd = rstd_flat.reshape(A.shape[:-1])
1760
+ if og_ndim_2:
1761
+ rstd = rstd.squeeze(0)
1762
+ return rstd
1763
+
1764
+
1765
+ @torch.library.custom_op(
1766
+ add_quack_op_namespace_prefix("gemm_rms_out"),
1767
+ mutates_args=("out",),
1768
+ device_types="cuda",
1769
+ schema="(Tensor A, Tensor B, Tensor(a!) out, Tensor? C=None, Tensor? norm_weight=None, float eps=1e-6, bool dynamic_scheduler=False, bool tuned=True) -> Tensor",
1770
+ )
1771
+ def _gemm_rms_out(
1772
+ A: Tensor,
1773
+ B: Tensor,
1774
+ out: Tensor,
1775
+ C: Optional[Tensor] = None,
1776
+ norm_weight: Optional[Tensor] = None,
1777
+ eps: float = 1e-6,
1778
+ dynamic_scheduler: bool = False,
1779
+ tuned: bool = True,
1780
+ ) -> Tensor:
1781
+ """GEMM + RMS + optional rowvec scaling.
1782
+
1783
+ D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight.
1784
+ """
1785
+ fn = _gemm_rms_tuned if tuned else partial(_gemm_rms_tuned.fn, config=None)
1786
+ return fn(
1787
+ A,
1788
+ B,
1789
+ out,
1790
+ C=C,
1791
+ norm_weight=norm_weight,
1792
+ eps=eps,
1793
+ dynamic_scheduler=dynamic_scheduler,
1794
+ )
1795
+
1796
+
1797
+ @torch.library.register_fake(add_quack_op_namespace_prefix("gemm_rms_out"))
1798
+ def _gemm_rms_out_fake(
1799
+ A: Tensor,
1800
+ B: Tensor,
1801
+ out: Tensor,
1802
+ C: Optional[Tensor] = None,
1803
+ norm_weight: Optional[Tensor] = None,
1804
+ eps: float = 1e-6,
1805
+ dynamic_scheduler: bool = False,
1806
+ tuned: bool = True,
1807
+ ) -> Tensor:
1808
+ _precompile_default_config(
1809
+ _gemm_rms_tuned,
1810
+ A,
1811
+ B,
1812
+ out,
1813
+ C=C,
1814
+ norm_weight=norm_weight,
1815
+ eps=eps,
1816
+ dynamic_scheduler=dynamic_scheduler,
1817
+ )
1818
+ rstd_shape = A.shape[:-1]
1819
+ return torch.empty(rstd_shape, dtype=torch.float32, device=A.device)
1820
+
1821
+
1822
+ def gemm_rms_ref(
1823
+ A: Tensor,
1824
+ B: Tensor,
1825
+ C: Optional[Tensor] = None,
1826
+ norm_weight: Optional[Tensor] = None,
1827
+ eps: float = 1e-6,
1828
+ ) -> Tuple[Tensor, Tensor]:
1829
+ """Reference: D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D = D_raw * norm_weight."""
1830
+ fn = torch.bmm if A.ndim == 3 else torch.mm
1831
+ D = fn(A, B)
1832
+ if C is not None:
1833
+ D = D + C
1834
+ rstd = torch.rsqrt(D.float().square().mean(dim=-1) + eps)
1835
+ if norm_weight is not None:
1836
+ D = D * norm_weight
1837
+ return D, rstd
1838
+
1839
+
1840
+ def gemm_rms(
1841
+ A: Tensor, # (M, K) or (L, M, K)
1842
+ B: Tensor, # (K, N) or (L, K, N)
1843
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N)
1844
+ norm_weight: Optional[Tensor] = None, # (N,) or (L, N)
1845
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N)
1846
+ out_dtype: Optional[torch.dtype] = None,
1847
+ eps: float = 1e-6,
1848
+ dynamic_scheduler: bool = False,
1849
+ tuned: bool = True,
1850
+ ) -> Tuple[Tensor, Tensor]:
1851
+ """GEMM + RMS statistics + optional rowvec scaling.
1852
+
1853
+ D_raw = A @ B (+ C), rstd = rsqrt(mean(D_raw^2) + eps), D_out = D_raw * norm_weight.
1854
+ Returns (D_out, rstd).
1855
+ """
1856
+ out_dtype = A.dtype if out_dtype is None else out_dtype
1857
+ N = B.shape[-1]
1858
+ if out is None:
1859
+ out_shape = (*A.shape[:-1], N)
1860
+ out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
1861
+ rstd = _gemm_rms_out(
1862
+ A,
1863
+ B,
1864
+ out,
1865
+ C=C,
1866
+ norm_weight=norm_weight,
1867
+ eps=eps,
1868
+ dynamic_scheduler=dynamic_scheduler,
1869
+ tuned=tuned,
1870
+ )
1871
+ return out, rstd
1872
+
1873
+
1874
+ ## ── gemm_norm_act ─────────────────────────────────────────────────────────────
1875
+
1876
+
1877
+ @autotune(
1878
+ configs=[AutotuneConfig(config=c) for c in get_all_configs()],
1879
+ key=["activation", "dynamic_scheduler"],
1880
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
1881
+ )
1882
+ def gemm_norm_act_tuned(
1883
+ A: Tensor, # (M, K) or (L, M, K)
1884
+ B: Tensor, # (K, N) or (L, K, N)
1885
+ preact_out: Optional[Tensor], # (M, N) or (L, M, N) — None if not storing preact
1886
+ postact_out: Tensor, # (M, N) or (L, M, N)
1887
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N)
1888
+ rstd: Optional[Tensor] = None, # (M,) or (L, M)
1889
+ activation: ActActivation = None,
1890
+ dynamic_scheduler: bool = False,
1891
+ config: Optional[GemmConfig] = None,
1892
+ ) -> None:
1893
+ if config is None:
1894
+ config = default_config(A.device)
1895
+ if A.ndim == 2:
1896
+ A = A.unsqueeze(0)
1897
+ B = B.mT
1898
+ if B.ndim == 2:
1899
+ B = B.unsqueeze(0)
1900
+ if C is not None and C.ndim == 2:
1901
+ C = C.unsqueeze(0)
1902
+ if preact_out is not None and preact_out.ndim == 2:
1903
+ D = preact_out.unsqueeze(0)
1904
+ else:
1905
+ D = preact_out
1906
+ if postact_out.ndim == 2:
1907
+ PostAct = postact_out.unsqueeze(0)
1908
+ else:
1909
+ PostAct = postact_out
1910
+ if rstd is not None and rstd.ndim == 1:
1911
+ rstd = rstd.unsqueeze(0) # (L, M)
1912
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
1913
+ tile_count_semaphore = (
1914
+ torch.zeros(1, dtype=torch.int32, device=A.device)
1915
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
1916
+ else None
1917
+ )
1918
+ gemm_norm_act_dispatch(
1919
+ A if not config.swap_ab else B,
1920
+ B if not config.swap_ab else A,
1921
+ (D if not config.swap_ab else D.mT) if D is not None else None,
1922
+ (C if not config.swap_ab else C.mT) if C is not None else None,
1923
+ PostAct if not config.swap_ab else PostAct.mT,
1924
+ tile_count_semaphore,
1925
+ activation,
1926
+ config.tile_m,
1927
+ config.tile_n,
1928
+ config.cluster_m,
1929
+ config.cluster_n,
1930
+ config.pingpong,
1931
+ persistent=True,
1932
+ is_dynamic_persistent=dynamic_scheduler,
1933
+ max_swizzle_size=config.max_swizzle_size,
1934
+ colvec=rstd if not config.swap_ab else None,
1935
+ rowvec=rstd if config.swap_ab else None,
1936
+ )
1937
+
1938
+
1939
+ @autotune(
1940
+ configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")],
1941
+ key=["activation", "dynamic_scheduler"],
1942
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
1943
+ )
1944
+ def gemm_norm_gated_tuned(
1945
+ A: Tensor, # (M, K) or (L, M, K)
1946
+ B: Tensor, # (K, N) or (L, K, N)
1947
+ preact_out: Optional[Tensor], # (M, N) or (L, M, N)
1948
+ postact_out: Tensor, # (M, N//2) or (L, M, N//2)
1949
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N)
1950
+ rstd: Optional[Tensor] = None, # (M,) or (L, M)
1951
+ activation: GatedActivation = "swiglu",
1952
+ dynamic_scheduler: bool = False,
1953
+ config: Optional[GemmConfig] = None,
1954
+ ) -> None:
1955
+ if config is None:
1956
+ config = default_config(A.device)
1957
+ if A.ndim == 2:
1958
+ A = A.unsqueeze(0)
1959
+ B = B.mT
1960
+ if B.ndim == 2:
1961
+ B = B.unsqueeze(0)
1962
+ if C is not None and C.ndim == 2:
1963
+ C = C.unsqueeze(0)
1964
+ if preact_out is not None and preact_out.ndim == 2:
1965
+ D = preact_out.unsqueeze(0)
1966
+ else:
1967
+ D = preact_out
1968
+ if postact_out.ndim == 2:
1969
+ PostAct = postact_out.unsqueeze(0)
1970
+ else:
1971
+ PostAct = postact_out
1972
+ if rstd is not None and rstd.ndim == 1:
1973
+ rstd = rstd.unsqueeze(0) # (L, M)
1974
+ dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
1975
+ tile_count_semaphore = (
1976
+ torch.zeros(1, dtype=torch.int32, device=A.device)
1977
+ if dynamic_scheduler and get_device_capacity(A.device)[0] == 9
1978
+ else None
1979
+ )
1980
+ gemm_norm_act_dispatch(
1981
+ A if not config.swap_ab else B,
1982
+ B if not config.swap_ab else A,
1983
+ (D if not config.swap_ab else D.mT) if D is not None else None,
1984
+ (C if not config.swap_ab else C.mT) if C is not None else None,
1985
+ PostAct if not config.swap_ab else PostAct.mT,
1986
+ tile_count_semaphore,
1987
+ activation,
1988
+ config.tile_m,
1989
+ config.tile_n,
1990
+ config.cluster_m,
1991
+ config.cluster_n,
1992
+ config.pingpong,
1993
+ persistent=True,
1994
+ is_dynamic_persistent=dynamic_scheduler,
1995
+ max_swizzle_size=config.max_swizzle_size,
1996
+ colvec=rstd if not config.swap_ab else None,
1997
+ rowvec=rstd if config.swap_ab else None,
1998
+ )
1999
+
2000
+
2001
+ @torch.library.custom_op(
2002
+ add_quack_op_namespace_prefix("gemm_norm_act_out"),
2003
+ mutates_args=("preact_out", "postact_out"),
2004
+ device_types="cuda",
2005
+ schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str? activation=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
2006
+ )
2007
+ def gemm_norm_act_out(
2008
+ A: Tensor,
2009
+ B: Tensor,
2010
+ preact_out: Optional[Tensor],
2011
+ postact_out: Tensor,
2012
+ C: Optional[Tensor] = None,
2013
+ rstd: Optional[Tensor] = None,
2014
+ activation: ActActivation = None,
2015
+ dynamic_scheduler: bool = False,
2016
+ tuned: bool = True,
2017
+ ) -> None:
2018
+ fn = gemm_norm_act_tuned if tuned else partial(gemm_norm_act_tuned.fn, config=None)
2019
+ fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler)
2020
+
2021
+
2022
+ @torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_act_out"))
2023
+ def _gemm_norm_act_out_fake(
2024
+ A,
2025
+ B,
2026
+ preact_out,
2027
+ postact_out,
2028
+ C=None,
2029
+ rstd=None,
2030
+ activation=None,
2031
+ dynamic_scheduler=False,
2032
+ tuned=True,
2033
+ ) -> None:
2034
+ pass
2035
+
2036
+
2037
+ @torch.library.custom_op(
2038
+ add_quack_op_namespace_prefix("gemm_norm_gated_out"),
2039
+ mutates_args=("preact_out", "postact_out"),
2040
+ device_types="cuda",
2041
+ schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? rstd=None, str activation='swiglu', bool dynamic_scheduler=False, bool tuned=True) -> ()",
2042
+ )
2043
+ def gemm_norm_gated_out(
2044
+ A: Tensor,
2045
+ B: Tensor,
2046
+ preact_out: Optional[Tensor],
2047
+ postact_out: Tensor,
2048
+ C: Optional[Tensor] = None,
2049
+ rstd: Optional[Tensor] = None,
2050
+ activation: GatedActivation = "swiglu",
2051
+ dynamic_scheduler: bool = False,
2052
+ tuned: bool = True,
2053
+ ) -> None:
2054
+ fn = gemm_norm_gated_tuned if tuned else partial(gemm_norm_gated_tuned.fn, config=None)
2055
+ fn(A, B, preact_out, postact_out, C, rstd, activation, dynamic_scheduler)
2056
+
2057
+
2058
+ @torch.library.register_fake(add_quack_op_namespace_prefix("gemm_norm_gated_out"))
2059
+ def _gemm_norm_gated_out_fake(
2060
+ A,
2061
+ B,
2062
+ preact_out,
2063
+ postact_out,
2064
+ C=None,
2065
+ rstd=None,
2066
+ activation="swiglu",
2067
+ dynamic_scheduler=False,
2068
+ tuned=True,
2069
+ ) -> None:
2070
+ pass
2071
+
2072
+
2073
+ def gemm_norm_act(
2074
+ A: Tensor, # (M, K) or (L, M, K)
2075
+ B: Tensor, # (K, N) or (L, K, N)
2076
+ rstd: Optional[Tensor] = None, # (M,) or (L, M)
2077
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) — residual
2078
+ activation: Activation = None,
2079
+ preact_out: Optional[Tensor] = None,
2080
+ postact_out: Optional[Tensor] = None,
2081
+ out_dtype: Optional[torch.dtype] = None,
2082
+ postact_dtype: Optional[torch.dtype] = None,
2083
+ store_preact: bool = False,
2084
+ dynamic_scheduler: bool = False,
2085
+ tuned: bool = True,
2086
+ ) -> Tuple[Optional[Tensor], Tensor]:
2087
+ """GEMM + normalize + activation: PostAct = act((A @ B + C) * rstd).
2088
+
2089
+ rstd is a column vector (M,).
2090
+ Returns (preact, postact) where preact is the normalized value before activation.
2091
+ """
2092
+ is_gated = activation in gated_to_pytorch_fn_map
2093
+ out_dtype = A.dtype if out_dtype is None else out_dtype
2094
+ postact_dtype = A.dtype if postact_dtype is None else postact_dtype
2095
+ if A.ndim == 2:
2096
+ out_shape = (A.shape[0], B.shape[-1])
2097
+ else:
2098
+ out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
2099
+ postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape
2100
+ if preact_out is None and store_preact:
2101
+ preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
2102
+ if postact_out is None:
2103
+ postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
2104
+ if is_gated:
2105
+ gemm_norm_gated_out(
2106
+ A,
2107
+ B,
2108
+ preact_out,
2109
+ postact_out,
2110
+ C,
2111
+ rstd,
2112
+ activation,
2113
+ dynamic_scheduler,
2114
+ tuned,
2115
+ )
2116
+ else:
2117
+ gemm_norm_act_out(
2118
+ A,
2119
+ B,
2120
+ preact_out,
2121
+ postact_out,
2122
+ C,
2123
+ rstd,
2124
+ activation,
2125
+ dynamic_scheduler,
2126
+ tuned,
2127
+ )
2128
+ return preact_out, postact_out
2129
+
2130
+
2131
+ gemm_norm_gated = gemm_norm_act
2132
+
2133
+
2134
+ def gemm_norm_act_ref(
2135
+ A: Tensor,
2136
+ B: Tensor,
2137
+ rstd: Optional[Tensor] = None, # (M,) or (L, M)
2138
+ C: Optional[Tensor] = None,
2139
+ activation: Activation = None,
2140
+ store_preact: bool = False,
2141
+ out_dtype: Optional[torch.dtype] = None,
2142
+ postact_dtype: Optional[torch.dtype] = None,
2143
+ ) -> Tuple[Optional[Tensor], Tensor]:
2144
+ """Reference: preact = (A @ B + C) * rstd, postact = act(preact)."""
2145
+ is_gated = activation in gated_to_pytorch_fn_map
2146
+ out_dtype = A.dtype if out_dtype is None else out_dtype
2147
+ postact_dtype = A.dtype if postact_dtype is None else postact_dtype
2148
+ fn = torch.bmm if A.ndim == 3 else torch.mm
2149
+ D = fn(A, B)
2150
+ if C is not None:
2151
+ D = D + C
2152
+ if rstd is not None:
2153
+ D = D * rstd.unsqueeze(-1)
2154
+ preact = D.to(out_dtype) if store_preact else None
2155
+ _act_map = {**act_to_pytorch_fn_map, "silu": F.silu}
2156
+ if is_gated:
2157
+ gate = D[..., ::2]
2158
+ up = D[..., 1::2]
2159
+ postact = gated_to_pytorch_fn_map[activation](gate, up).to(postact_dtype)
2160
+ else:
2161
+ postact = _act_map[activation](D).to(postact_dtype)
2162
+ return preact, postact
2163
+
2164
+
2165
+ gemm_norm_gated_ref = gemm_norm_act_ref
2166
+
2167
+
2168
  # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
2169
  # try:
2170
  # from torch._inductor.fx_passes.reinplace import InplaceableOp
build/torch-cuda/quack/gemm_norm_act.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ # GEMM + normalize (multiply by colvec and rowvec) + activation:
3
+ # PostAct = act((A @ B + C) * colvec * rowvec)
4
+ # colvec is typically rstd (M,), rowvec is typically norm_weight (N,).
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ from torch import Tensor
9
+
10
+ import cutlass
11
+ import cutlass.cute as cute
12
+ from cutlass import Int32, const_expr
13
+ from cutlass.cute.runtime import make_ptr
14
+
15
+ from .compile_utils import make_fake_tensor as fake_tensor
16
+ from .cute_dsl_utils import (
17
+ torch2cute_dtype_map,
18
+ get_device_capacity,
19
+ get_max_active_clusters,
20
+ )
21
+ from .gemm_sm90 import GemmSm90
22
+ from .gemm_sm100 import GemmSm100
23
+ from .gemm_sm120 import GemmSm120
24
+ from .gemm_act import GemmActMixin, GemmGatedMixin
25
+ from .epi_ops import vec_multiply
26
+ from .activation import act_fn_map, gate_fn_map
27
+ from .cache_utils import jit_cache
28
+ from .rounding import RoundingMode
29
+ from .gemm_tvm_ffi_utils import (
30
+ get_major,
31
+ perm3d_single,
32
+ make_scheduler_args,
33
+ make_varlen_args,
34
+ make_fake_scheduler_args,
35
+ make_fake_varlen_args,
36
+ div_for_dtype,
37
+ make_fake_gemm_tensors,
38
+ compile_gemm_kernel,
39
+ )
40
+ from . import utils as utils
41
+
42
+
43
+ class GemmNormActMixin(GemmActMixin):
44
+ """GEMM + normalize + activation: PostAct = act((A @ B + C) * colvec * rowvec).
45
+
46
+ colvec is typically rstd (M,), rowvec is typically norm_weight (N,).
47
+ D stores the normalized (pre-activation) value, PostAct stores act(D).
48
+ """
49
+
50
+ @cute.jit
51
+ def epi_visit_subtile(
52
+ self,
53
+ params: GemmActMixin.EpilogueParams,
54
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
55
+ tRS_rD: cute.Tensor,
56
+ tRS_rC: Optional[cute.Tensor] = None,
57
+ ) -> Optional[cute.Tensor]:
58
+ tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
59
+ tDrColVec = epi_loop_tensors["mColVecBroadcast"]
60
+ # Load accumulator and apply alpha/beta/C
61
+ rD = tRS_rD.load()
62
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
63
+ alpha = utils.load_scalar_or_pointer(params.alpha)
64
+ rD *= alpha
65
+ if const_expr(tRS_rC is not None):
66
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
67
+ rD += tRS_rC.load().to(tRS_rD.element_type)
68
+ else:
69
+ beta = utils.load_scalar_or_pointer(params.beta)
70
+ rD += beta * tRS_rC.load().to(tRS_rD.element_type)
71
+ tRS_rD.store(rD)
72
+ # Multiply by colvec (rstd) and rowvec (norm_weight)
73
+ vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec)
74
+ # Apply activation
75
+ if const_expr(params.act_fn is not None):
76
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rD.layout.shape, self.acc_dtype)
77
+ if const_expr(self.arch < 100):
78
+ for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
79
+ tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
80
+ else:
81
+ for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
82
+ tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
83
+ (tRS_rD[2 * i], tRS_rD[2 * i + 1])
84
+ )
85
+ else:
86
+ tRS_rPostAct = tRS_rD
87
+ return tRS_rPostAct
88
+
89
+
90
+ class GemmNormActSm90(GemmNormActMixin, GemmSm90):
91
+ pass
92
+
93
+
94
+ class GemmNormActSm100(GemmNormActMixin, GemmSm100):
95
+ pass
96
+
97
+
98
+ class GemmNormActSm120(GemmNormActMixin, GemmSm120):
99
+ pass
100
+
101
+
102
+ class GemmNormGatedMixin(GemmGatedMixin):
103
+ """GEMM + normalize + gated activation: PostAct = gated_act((A @ B + C) * colvec * rowvec)."""
104
+
105
+ @cute.jit
106
+ def epi_visit_subtile(
107
+ self,
108
+ params: GemmActMixin.EpilogueParams,
109
+ epi_loop_tensors: Tuple[cute.Tensor, ...],
110
+ tRS_rD: cute.Tensor,
111
+ tRS_rC: Optional[cute.Tensor] = None,
112
+ ) -> Optional[cute.Tensor]:
113
+ tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
114
+ tDrColVec = epi_loop_tensors["mColVecBroadcast"]
115
+ # Load accumulator and apply alpha/beta/C
116
+ rD = tRS_rD.load()
117
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
118
+ alpha = utils.load_scalar_or_pointer(params.alpha)
119
+ rD *= alpha
120
+ if const_expr(tRS_rC is not None):
121
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
122
+ rD += tRS_rC.load().to(tRS_rD.element_type)
123
+ else:
124
+ beta = utils.load_scalar_or_pointer(params.beta)
125
+ rD += beta * tRS_rC.load().to(tRS_rD.element_type)
126
+ tRS_rD.store(rD)
127
+ # Multiply by colvec (rstd) and rowvec (norm_weight)
128
+ vec_multiply(self, tRS_rD, tDrColVec, tDrRowVec)
129
+ # Gated activation on normalized D
130
+ tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout)
131
+ tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype)
132
+ if const_expr(self.arch < 100):
133
+ for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
134
+ tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1])
135
+ else:
136
+ for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
137
+ tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
138
+ (tRS_rD[4 * i], tRS_rD[4 * i + 2]),
139
+ (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]),
140
+ )
141
+ return tRS_rPostAct
142
+
143
+
144
+ class GemmNormGatedSm90(GemmNormGatedMixin, GemmSm90):
145
+ pass
146
+
147
+
148
+ class GemmNormGatedSm100(GemmNormGatedMixin, GemmSm100):
149
+ pass
150
+
151
+
152
+ class GemmNormGatedSm120(GemmNormGatedMixin, GemmSm120):
153
+ pass
154
+
155
+
156
+ @jit_cache
157
+ def _compile_gemm_norm_act(
158
+ a_dtype,
159
+ b_dtype,
160
+ d_dtype,
161
+ c_dtype,
162
+ postact_dtype,
163
+ a_major,
164
+ b_major,
165
+ d_major,
166
+ c_major,
167
+ postact_major,
168
+ tile_shape_mn,
169
+ cluster_shape_mnk,
170
+ pingpong,
171
+ persistent,
172
+ is_dynamic_persistent,
173
+ activation,
174
+ rowvec_dtype,
175
+ colvec_dtype,
176
+ colvec_ndim,
177
+ varlen_m,
178
+ gather_A,
179
+ device_capacity,
180
+ gemm_cls_name,
181
+ rounding_mode=RoundingMode.RN,
182
+ sr_seed_mode=0,
183
+ ):
184
+ sm_to_cls = {
185
+ "norm_act": {
186
+ 9: GemmNormActSm90,
187
+ 10: GemmNormActSm100,
188
+ 11: GemmNormActSm100,
189
+ 12: GemmNormActSm120,
190
+ },
191
+ "norm_gated": {
192
+ 9: GemmNormGatedSm90,
193
+ 10: GemmNormGatedSm100,
194
+ 11: GemmNormGatedSm100,
195
+ 12: GemmNormGatedSm120,
196
+ },
197
+ }
198
+ GemmCls = sm_to_cls[gemm_cls_name][device_capacity[0]]
199
+ pa_leading = 1 if postact_major == "n" else 0
200
+ mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
201
+ a_dtype,
202
+ b_dtype,
203
+ d_dtype,
204
+ c_dtype,
205
+ a_major,
206
+ b_major,
207
+ d_major,
208
+ c_major,
209
+ varlen_m=varlen_m,
210
+ gather_A=gather_A,
211
+ )
212
+ div_pa = div_for_dtype(postact_dtype)
213
+ pa_n = cute.sym_int() if gemm_cls_name == "norm_gated" else n
214
+ pa_leading_dim = 1 if gemm_cls_name == "norm_gated" else pa_leading
215
+ pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l)
216
+ mPostAct = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa)
217
+
218
+ mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
219
+ if colvec_ndim == 2:
220
+ mColVec = fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4)
221
+ elif colvec_ndim == 1:
222
+ mColVec = fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4)
223
+ else:
224
+ mColVec = None
225
+
226
+ act_fn = act_fn_map[activation] if gemm_cls_name == "norm_act" else gate_fn_map[activation]
227
+
228
+ def fake_scalar(mode, dtype=Int32):
229
+ if mode == 0:
230
+ return None
231
+ elif mode == 1:
232
+ return dtype(0)
233
+ else:
234
+ return make_ptr(dtype, 0, cute.AddressSpace.gmem, assumed_align=4)
235
+
236
+ epi_args = GemmCls.EpilogueArguments(
237
+ mPostAct,
238
+ act_fn,
239
+ mRowVecBroadcast=mRowVec,
240
+ mColVecBroadcast=mColVec,
241
+ rounding_mode=rounding_mode,
242
+ sr_seed=fake_scalar(sr_seed_mode),
243
+ )
244
+ scheduler_args = make_fake_scheduler_args(
245
+ (is_dynamic_persistent and device_capacity[0] == 9), False, l
246
+ )
247
+ varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None)
248
+ return compile_gemm_kernel(
249
+ GemmCls,
250
+ a_dtype,
251
+ tile_shape_mn,
252
+ cluster_shape_mnk,
253
+ pingpong,
254
+ persistent,
255
+ gather_A,
256
+ is_dynamic_persistent,
257
+ device_capacity,
258
+ mA,
259
+ mB,
260
+ mD,
261
+ mC,
262
+ epi_args,
263
+ scheduler_args,
264
+ varlen_args,
265
+ )
266
+
267
+
268
+ def gemm_norm_act_fn(
269
+ A: Tensor, # (l, m, k) or (total_m, k) if varlen_m
270
+ B: Tensor, # (l, n, k)
271
+ D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
272
+ C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
273
+ PostAct: Tensor, # (l, m, n) or (total_m, n//2) if gated
274
+ tile_count_semaphore: Optional[Tensor],
275
+ activation: Optional[str],
276
+ tile_M: int,
277
+ tile_N: int,
278
+ cluster_M: int,
279
+ cluster_N: int,
280
+ pingpong: bool = False,
281
+ persistent: bool = True,
282
+ is_dynamic_persistent: bool = False,
283
+ max_swizzle_size: int = 8,
284
+ rowvec: Optional[Tensor] = None, # (l, n) — norm_weight
285
+ colvec: Optional[Tensor] = None, # (l, m) or (total_m,) — rstd
286
+ cu_seqlens_m: Optional[Tensor] = None,
287
+ A_idx: Optional[Tensor] = None,
288
+ rounding_mode: int = RoundingMode.RN,
289
+ sr_seed: int | Tensor = 0,
290
+ ) -> None:
291
+ if activation in gate_fn_map:
292
+ gemm_cls_name = "norm_gated"
293
+ else:
294
+ assert activation in act_fn_map, f"Unsupported activation {activation}"
295
+ gemm_cls_name = "norm_act"
296
+
297
+ varlen_m = cu_seqlens_m is not None
298
+ gather_A = A_idx is not None
299
+ if varlen_m:
300
+ assert persistent, "varlen_m requires persistent=True"
301
+ assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
302
+ if D is not None:
303
+ assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
304
+ assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
305
+ if gather_A:
306
+ assert cu_seqlens_m is not None, "gather_A requires varlen"
307
+ assert cluster_N == 1, "gather_A requires cluster_N=1"
308
+
309
+ A_p = perm3d_single(A, varlen_m)
310
+ B_p = perm3d_single(B)
311
+ D_p = perm3d_single(D, varlen_m)
312
+ C_p = perm3d_single(C, varlen_m)
313
+ PostAct_p = perm3d_single(PostAct, varlen_m)
314
+
315
+ a_major = get_major(A_p, "m", "k")
316
+ b_major = get_major(B_p, "n", "k")
317
+ d_major = get_major(D_p, "m", "n") if D_p is not None else None
318
+ c_major = get_major(C_p, "m", "n") if C_p is not None else None
319
+ postact_major = get_major(PostAct_p, "m", "n")
320
+
321
+ a_dtype = torch2cute_dtype_map[A.dtype]
322
+ b_dtype = torch2cute_dtype_map[B.dtype]
323
+ d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None
324
+ c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
325
+ postact_dtype = torch2cute_dtype_map[PostAct.dtype]
326
+ colvec_ndim = colvec.ndim if colvec is not None else 0
327
+
328
+ device_capacity = get_device_capacity(A.device)
329
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
330
+ if rounding_mode == RoundingMode.RS:
331
+ assert device_capacity[0] == 10, "Stochastic rounding requires SM100"
332
+
333
+ if is_dynamic_persistent and device_capacity[0] == 9:
334
+ assert tile_count_semaphore is not None, (
335
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
336
+ )
337
+
338
+ sr_seed_mode = (
339
+ 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
340
+ )
341
+ compiled_fn = _compile_gemm_norm_act(
342
+ a_dtype,
343
+ b_dtype,
344
+ d_dtype,
345
+ c_dtype,
346
+ postact_dtype,
347
+ a_major,
348
+ b_major,
349
+ d_major,
350
+ c_major,
351
+ postact_major,
352
+ (tile_M, tile_N),
353
+ (cluster_M, cluster_N, 1),
354
+ pingpong,
355
+ persistent,
356
+ is_dynamic_persistent,
357
+ activation,
358
+ torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None,
359
+ torch2cute_dtype_map[colvec.dtype] if colvec is not None else None,
360
+ colvec_ndim,
361
+ varlen_m,
362
+ gather_A,
363
+ device_capacity,
364
+ gemm_cls_name,
365
+ rounding_mode=rounding_mode,
366
+ sr_seed_mode=sr_seed_mode,
367
+ )
368
+
369
+ from .cache_utils import COMPILE_ONLY
370
+
371
+ if COMPILE_ONLY:
372
+ return
373
+
374
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
375
+
376
+ def scalar_arg(scalar, mode, dtype=Int32):
377
+ if mode == 0:
378
+ return None
379
+ elif mode == 1:
380
+ return dtype(scalar)
381
+ else:
382
+ return scalar.data_ptr()
383
+
384
+ epi_args = GemmActMixin.EpilogueArguments(
385
+ PostAct_p,
386
+ None, # act_fn is Constexpr, pass None at call time
387
+ mRowVecBroadcast=rowvec,
388
+ mColVecBroadcast=colvec,
389
+ rounding_mode=None,
390
+ sr_seed=scalar_arg(sr_seed, sr_seed_mode),
391
+ )
392
+ scheduler_args = make_scheduler_args(
393
+ max_active_clusters, max_swizzle_size, tile_count_semaphore
394
+ )
395
+ varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
396
+
397
+ if device_capacity[0] in [10, 11]:
398
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
399
+ else:
400
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
build/torch-cuda/quack/gemm_sm100.py CHANGED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/quack/gemm_sm120.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ # Based on the cute-dsl example:
3
+ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py
4
+ # SM120-style GEMM using warp-level MMA (MmaF16BF16Op) + ldmatrix.
5
+ # Unlike SM90 WGMMA (which reads A/B from SMEM directly), warp-level MMA
6
+ # requires explicit SMEM→RMEM copies via ldmatrix before each MMA instruction.
7
+
8
+ # This is a work in progress and not very optimized.
9
+
10
+ import math
11
+ from typing import Tuple, Type, Callable, Optional
12
+ from functools import partial
13
+
14
+ import cutlass
15
+ import cutlass.cute as cute
16
+ import cutlass.pipeline as pipeline
17
+ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
18
+ from cutlass.cute.nvgpu import cpasync, warp
19
+ from cutlass import Int32, Boolean, const_expr
20
+
21
+ from .varlen_utils import VarlenManager
22
+ from .pipeline import make_pipeline_state
23
+ from . import copy_utils
24
+ from .gemm_sm90 import GemmSm90, NamedBarrierGemm
25
+ from . import sm80_utils
26
+
27
+
28
+ class GemmSm120(GemmSm90):
29
+ """SM120-style GEMM using warp-level MMA instead of WGMMA.
30
+
31
+ Key differences from SM90:
32
+ - Uses MmaF16BF16Op (warp-level, 32 threads) instead of WGMMA (warp-group, 128 threads)
33
+ - Requires explicit SMEM→RMEM copy via ldmatrix before MMA
34
+ - Thread config: num_mma_warps regular warps + 1 DMA warp
35
+ - Pingpong: 2 warp groups of (2,2,1), each processing alternating tiles
36
+ - No fp8 support (warp-level MMA only supports fp16/bf16)
37
+ """
38
+
39
+ arch = 120
40
+
41
+ def __init__(
42
+ self,
43
+ acc_dtype: Type[cutlass.Numeric],
44
+ a_dtype: Type[cutlass.Numeric],
45
+ tile_shape_mn: Tuple[int, int],
46
+ cluster_shape_mnk: Tuple[int, int, int],
47
+ pingpong: bool = False,
48
+ is_persistent: bool = True,
49
+ gather_A: bool = False,
50
+ use_pdl: bool = True,
51
+ ):
52
+ # Don't call super().__init__ — we set up our own config
53
+ self.acc_dtype = acc_dtype
54
+ self.pingpong = pingpong
55
+ self.is_persistent = is_persistent
56
+ self.use_clc_persistence = False
57
+ self.use_pdl = use_pdl
58
+ self.fp8_slow_accum = False
59
+ self.gather_A = gather_A
60
+ if self.pingpong:
61
+ assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
62
+ if gather_A:
63
+ assert cluster_shape_mnk[1] == 1
64
+
65
+ self.cluster_shape_mnk = cluster_shape_mnk
66
+ tile_M, tile_N = tile_shape_mn
67
+ self.cta_tile_shape_mnk = (tile_M, tile_N, 1)
68
+
69
+ # Pingpong: 2 warp groups each with (2,2,1) atom layout
70
+ # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout
71
+ self.mma_inst_mnk = (16, 8, 16)
72
+ if not self.pingpong:
73
+ self.atom_layout_mnk = (4, 2, 1)
74
+ else:
75
+ self.atom_layout_mnk = (2, 2, 1)
76
+ # num_mma_warps = total warps doing MMA (both warp groups in pingpong)
77
+ self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
78
+ # For compatibility with SM90 code that uses warp groups
79
+ self.num_threads_per_warp_group = 128
80
+ assert self.num_mma_warps % 4 == 0
81
+ self.mma_warp_groups = self.num_mma_warps // 4
82
+ if self.pingpong:
83
+ assert self.mma_warp_groups == 2
84
+ # threads_per_cta must be a multiple of 128 (warp group size) so that
85
+ # the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with.
86
+ self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
87
+
88
+ self.num_mcast_ctas_a = cluster_shape_mnk[1]
89
+ if gather_A:
90
+ assert self.num_mcast_ctas_a == 1
91
+ self.num_mcast_ctas_b = cluster_shape_mnk[0]
92
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
93
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
94
+
95
+ self.occupancy = 1
96
+ self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}")
97
+
98
+ # In pingpong, only 1 warp group (4 warps) participates in epilogue at a time
99
+ self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
100
+ self.epilogue_barrier = pipeline.NamedBarrier(
101
+ barrier_id=int(NamedBarrierGemm.Epilogue),
102
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
103
+ )
104
+ self.num_ab_load_warps = 1 if not self.gather_A else 4
105
+ self.ab_load_warp_id = self.num_mma_warps
106
+
107
+ if not self.gather_A:
108
+ self.num_regs_load = 40
109
+ self.num_regs_mma = 232
110
+ else:
111
+ self.num_regs_load = 56
112
+ self.num_regs_mma = 224
113
+
114
+ self.ab_stage = None
115
+ self.epi_stage = None
116
+ self.a_smem_layout_staged = None
117
+ self.b_smem_layout_staged = None
118
+ self.epi_smem_layout_staged = None
119
+ self.epi_tile = None
120
+ self.shared_storage = None
121
+ self.buffer_align_bytes = 1024
122
+
123
+ def _setup_tiled_mma(self):
124
+ """Set up warp-level MMA (MmaF16BF16Op) and tile K dimension."""
125
+ op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk)
126
+ tC = cute.make_layout(self.atom_layout_mnk)
127
+ permutation_mnk = (
128
+ self.atom_layout_mnk[0] * self.mma_inst_mnk[0],
129
+ self.atom_layout_mnk[1] * self.mma_inst_mnk[1] * 2,
130
+ self.atom_layout_mnk[2] * self.mma_inst_mnk[2],
131
+ )
132
+ self.tiled_mma = cute.make_tiled_mma(op, tC, permutation_mnk=permutation_mnk)
133
+ tile_k = self.mma_inst_mnk[2] * 4
134
+ self.cta_tile_shape_mnk = (
135
+ self.cta_tile_shape_mnk[0],
136
+ self.cta_tile_shape_mnk[1],
137
+ tile_k,
138
+ )
139
+
140
+ # __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline,
141
+ # make_sched_pipeline, epilogue are all inherited from GemmSm90.
142
+
143
+ @cute.kernel
144
+ def kernel(
145
+ self,
146
+ tiled_mma: cute.TiledMma,
147
+ tma_atom_a: Optional[cute.CopyAtom],
148
+ mA_mkl: cute.Tensor,
149
+ tma_atom_b: cute.CopyAtom,
150
+ mB_nkl: cute.Tensor,
151
+ tma_atom_d: Optional[cute.CopyAtom],
152
+ mD_mnl: Optional[cute.Tensor],
153
+ tma_atom_c: Optional[cute.CopyAtom],
154
+ mC_mnl: Optional[cute.Tensor],
155
+ epilogue_params,
156
+ varlen_params: VarlenManager.Params,
157
+ cluster_layout_mnk: cute.Layout,
158
+ a_smem_layout: cute.ComposedLayout,
159
+ b_smem_layout: cute.ComposedLayout,
160
+ epi_smem_layout: cute.ComposedLayout,
161
+ epi_c_smem_layout: cute.ComposedLayout,
162
+ tile_sched_params,
163
+ TileSchedulerCls: cutlass.Constexpr[Callable],
164
+ trace_ptr: Optional[cutlass.Int64] = None,
165
+ ):
166
+ from .trace import TraceContext
167
+
168
+ tctx = TraceContext.create(trace_ptr)
169
+
170
+ varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
171
+ varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
172
+ if const_expr(self.gather_A):
173
+ assert varlen_m or varlen_k
174
+ has_D = const_expr(mD_mnl is not None)
175
+ has_C = const_expr(mC_mnl is not None)
176
+
177
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
178
+
179
+ # Prefetch TMA descriptors
180
+ if warp_idx == self.ab_load_warp_id:
181
+ for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
182
+ if const_expr(tma_atom is not None):
183
+ cpasync.prefetch_descriptor(tma_atom)
184
+
185
+ # Allocate shared memory
186
+ smem = cutlass.utils.SmemAllocator()
187
+ storage = smem.allocate(self.shared_storage)
188
+
189
+ ab_pipeline = self.make_ab_pipeline(
190
+ tiled_mma=tiled_mma,
191
+ cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)),
192
+ ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(),
193
+ )
194
+ epi_pipeline = None
195
+ if const_expr(has_C):
196
+ epi_pipeline = self.make_epi_pipeline(
197
+ c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)),
198
+ epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
199
+ )
200
+ sched_pipeline = None
201
+ sched_data = None
202
+ if const_expr(self.is_persistent):
203
+ sched_pipeline = self.make_sched_pipeline(
204
+ cluster_layout_mnk,
205
+ sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
206
+ varlen_k=varlen_k,
207
+ )
208
+ sched_data = storage.sched_data.get_tensor((4, self.sched_stage))
209
+
210
+ # Cluster sync
211
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
212
+
213
+ # SMEM tensors
214
+ sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
215
+ sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
216
+ sD = None
217
+ if const_expr(has_D):
218
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
219
+ sC = None
220
+ if const_expr(has_C):
221
+ sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
222
+ epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
223
+
224
+ varlen_manager = VarlenManager.create(
225
+ varlen_params,
226
+ len_m_static=Int32(
227
+ cute.size(mA_mkl, mode=[0])
228
+ if varlen_k or varlen_params.mAIdx is None
229
+ else varlen_params.mAIdx.shape[0]
230
+ ),
231
+ len_k_static=Int32(cute.size(mA_mkl, mode=[1])),
232
+ )
233
+
234
+ TileSchedulerCls = partial(
235
+ TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline
236
+ )
237
+
238
+ # Cluster wait
239
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
240
+
241
+ if warp_idx >= self.ab_load_warp_id:
242
+ cute.arch.setmaxregister_decrease(self.num_regs_load)
243
+ if (
244
+ warp_idx >= self.ab_load_warp_id
245
+ and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
246
+ ):
247
+ # Get mcast mask
248
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
249
+ block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
250
+ a_mcast_mask = cute.make_layout_image_mask(
251
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
252
+ )
253
+ b_mcast_mask = cute.make_layout_image_mask(
254
+ cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0
255
+ )
256
+ a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0
257
+ b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
258
+
259
+ # Persistent tile scheduling loop
260
+ is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
261
+ if const_expr(cute.size(cluster_layout_mnk) > 1):
262
+ is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
263
+ tile_scheduler = TileSchedulerCls()
264
+ work_tile = tile_scheduler.initial_work_tile_info()
265
+ ab_producer_state = make_pipeline_state(
266
+ pipeline.PipelineUserType.Producer, self.ab_stage
267
+ )
268
+ while work_tile.is_valid_tile:
269
+ tctx.b("tma_load")
270
+ tile_coord_mnkl = work_tile.tile_idx
271
+ batch_idx = tile_coord_mnkl[3]
272
+ # Local_tile partition global tensors
273
+ copy_A, prefetch_A = None, None
274
+ if const_expr(not self.gather_A):
275
+ mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
276
+ # (bM, bK, RestK)
277
+ gA_mk = cute.local_tile(
278
+ mA_mk,
279
+ cute.select(self.cta_tile_shape_mnk, [0, 2]),
280
+ (tile_coord_mnkl[0], None),
281
+ )
282
+ # TMA load A partition_S/D
283
+ copy_A, _, _ = copy_utils.tma_get_copy_fn(
284
+ tma_atom_a,
285
+ cta_coord=block_in_cluster_coord_mnk[1],
286
+ cta_layout=cute.make_layout(
287
+ cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
288
+ ),
289
+ src_tensor=gA_mk,
290
+ dst_tensor=sA,
291
+ mcast_mask=a_mcast_mask,
292
+ )
293
+ else:
294
+ copy_A, prefetch_A = self._make_gather_A_copy(
295
+ mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx
296
+ )
297
+ # (bN, bK, RestK)
298
+ gB_nk = cute.local_tile(
299
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
300
+ cute.select(self.cta_tile_shape_mnk, [1, 2]),
301
+ (tile_coord_mnkl[1], None),
302
+ )
303
+ # TMA load B partition_S/D
304
+ copy_B, _, _ = copy_utils.tma_get_copy_fn(
305
+ tma_atom_b,
306
+ cta_coord=block_in_cluster_coord_mnk[0],
307
+ cta_layout=cute.make_layout(
308
+ cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
309
+ ),
310
+ src_tensor=gB_nk,
311
+ dst_tensor=sB,
312
+ mcast_mask=b_mcast_mask,
313
+ )
314
+ len_k = varlen_manager.len_k(batch_idx)
315
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
316
+ if const_expr(not self.gather_A):
317
+ ab_producer_state = self.load_AB(
318
+ ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
319
+ )
320
+ else:
321
+ ab_producer_state = self.load_AB_gather_A(
322
+ ab_pipeline,
323
+ ab_producer_state,
324
+ copy_A,
325
+ prefetch_A,
326
+ copy_B,
327
+ k_tile_cnt,
328
+ varlen_m=varlen_m,
329
+ )
330
+ tctx.e("tma_load")
331
+ tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
332
+ work_tile = tile_scheduler.get_current_work()
333
+ # End of persistent scheduler loop
334
+ if const_expr(self.pingpong and not varlen_k):
335
+ # Need to write the tile_idx to smem for the next WG in the pingpong mode
336
+ if is_scheduler_warp:
337
+ tile_scheduler.write_work_tile_to_smem(work_tile)
338
+ work_tile = tile_scheduler.get_current_work()
339
+ ab_pipeline.producer_tail(ab_producer_state)
340
+ if is_scheduler_warp:
341
+ tile_scheduler.producer_tail()
342
+
343
+ # =====================================================================
344
+ # MMA warps
345
+ # =====================================================================
346
+ if warp_idx < self.num_mma_warps:
347
+ cute.arch.setmaxregister_increase(self.num_regs_mma)
348
+ is_tma_warp = Boolean(
349
+ (not self.pingpong and warp_idx == 0)
350
+ or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
351
+ )
352
+ tidx, _, _ = cute.arch.thread_idx()
353
+ # For pingpong, adjust tidx to within-warp-group index
354
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
355
+ if const_expr(self.pingpong):
356
+ tidx = tidx % self.num_threads_per_warp_group
357
+
358
+ # ldmatrix copy atoms for SMEM → RMEM
359
+ atom_copy_ldmatrix_A = cute.make_copy_atom(
360
+ warp.LdMatrix8x8x16bOp(self.a_layout.is_m_major_a(), 4),
361
+ self.a_dtype,
362
+ )
363
+ atom_copy_ldmatrix_B = cute.make_copy_atom(
364
+ warp.LdMatrix8x8x16bOp(self.b_layout.is_n_major_b(), 4),
365
+ self.b_dtype,
366
+ )
367
+ smem_tiled_copy_A = cute.make_tiled_copy_A(atom_copy_ldmatrix_A, tiled_mma)
368
+ smem_tiled_copy_B = cute.make_tiled_copy_B(atom_copy_ldmatrix_B, tiled_mma)
369
+ thr_copy_ldmatrix_A = smem_tiled_copy_A.get_slice(tidx)
370
+ thr_copy_ldmatrix_B = smem_tiled_copy_B.get_slice(tidx)
371
+ tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
372
+ tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
373
+
374
+ # Make fragments
375
+ thr_mma = tiled_mma.get_slice(tidx)
376
+ acc, tCsA, tCsB, tCrA, tCrB = sm80_utils.partition_fragment_ABC(
377
+ thr_mma, self.cta_tile_shape_mnk, sA, sB
378
+ )
379
+
380
+ if const_expr(self.pingpong):
381
+ if warp_group_idx == 0:
382
+ # WG0 needs a start signal at the very beginning
383
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
384
+ self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
385
+
386
+ k_tile_cnt_static = cute.ceil_div(
387
+ cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2]
388
+ )
389
+ c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
390
+
391
+ ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
392
+ epi_store_pipeline = self.make_epi_store_pipeline()
393
+ epi_read_state = make_pipeline_state(
394
+ pipeline.PipelineUserType.Consumer, self.epi_c_stage
395
+ )
396
+ epi_producer_state = make_pipeline_state(
397
+ pipeline.PipelineUserType.Producer, self.epi_c_stage
398
+ )
399
+ tile_scheduler = TileSchedulerCls()
400
+ work_tile = tile_scheduler.initial_work_tile_info()
401
+
402
+ if const_expr(self.pingpong):
403
+ if warp_idx >= 4:
404
+ # Advance 2nd Math WG pipeline states to the end of 1st Math WG
405
+ epi_read_state.advance_iters(c_tile_cnt)
406
+ epi_producer_state.advance_iters(c_tile_cnt)
407
+ if const_expr(not varlen_k):
408
+ ab_read_state.advance_iters(k_tile_cnt_static)
409
+ else:
410
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
411
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
412
+ ab_read_state.advance_iters(k_tile_cnt)
413
+ tile_scheduler.advance_to_next_work()
414
+ work_tile = tile_scheduler.get_current_work()
415
+ while work_tile.is_valid_tile:
416
+ tile_coord_mnkl = work_tile.tile_idx
417
+ batch_idx = tile_coord_mnkl[3]
418
+ len_k = varlen_manager.len_k(batch_idx)
419
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
420
+ acc.fill(0.0)
421
+ if const_expr(self.pingpong):
422
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
423
+ tctx.b("mma")
424
+ ab_read_state = self.mma(
425
+ ab_pipeline,
426
+ ab_read_state,
427
+ tiled_mma,
428
+ acc,
429
+ k_tile_cnt,
430
+ smem_tiled_copy_A,
431
+ smem_tiled_copy_B,
432
+ tCsA_copy_view,
433
+ tCsB_copy_view,
434
+ tCrA,
435
+ tCrB,
436
+ )
437
+ if const_expr(self.pingpong):
438
+ # Cue for next WG's MMA to start
439
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma")
440
+ tctx.e("mma")
441
+
442
+ # ============================================================
443
+ # EPILOGUE — reuse SM90's epilogue flow
444
+ # ============================================================
445
+ if const_expr(self.pingpong):
446
+ self.pingpong_barrier_sync(warp_group_idx, "epi")
447
+ tctx.b("epilogue")
448
+
449
+ copy_D = None
450
+ if const_expr(has_D):
451
+ copy_D, _, _ = self.epilog_gmem_copy_and_partition(
452
+ tma_atom_d,
453
+ varlen_manager.offset_batch_epi(mD_mnl, tile_coord_mnkl[3]),
454
+ self.cta_tile_shape_mnk[:2],
455
+ self.epi_tile,
456
+ sD,
457
+ tile_coord_mnkl,
458
+ )
459
+ copy_C = None
460
+ if const_expr(has_C):
461
+ copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
462
+ tma_atom_c,
463
+ varlen_manager.offset_batch_epi(mC_mnl, tile_coord_mnkl[3]),
464
+ self.cta_tile_shape_mnk[:2],
465
+ self.epi_tile,
466
+ sC,
467
+ tile_coord_mnkl,
468
+ )
469
+ copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
470
+
471
+ d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
472
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
473
+ tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
474
+ )
475
+ tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s, tidx)
476
+ load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
477
+ if const_expr(has_C):
478
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
479
+ tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
480
+ )
481
+ else:
482
+ tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
483
+
484
+ self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
485
+
486
+ epi_read_state, epi_producer_state = self.epilogue(
487
+ epilogue_params,
488
+ epi_smem_tensors,
489
+ epi_pipeline,
490
+ epi_store_pipeline,
491
+ epi_read_state,
492
+ epi_producer_state,
493
+ self.epi_tile,
494
+ load_acc_subtile,
495
+ tRS_rD,
496
+ tRS_rC,
497
+ None, # tiled_copy_t2r, for Sm100 only
498
+ tiled_copy_r2s,
499
+ tRS_sD,
500
+ tiled_copy_s2r,
501
+ tSR_rC,
502
+ tSR_sC,
503
+ copy_D,
504
+ copy_C,
505
+ tile_coord_mnkl,
506
+ varlen_manager,
507
+ self.epilogue_barrier,
508
+ tile_scheduler,
509
+ tidx,
510
+ is_tma_warp,
511
+ )
512
+
513
+ if const_expr(self.pingpong):
514
+ # With pingpong, 2 WGs write two different output tiles to the same smem,
515
+ # so we have to make sure the smem content is done reading before signaling
516
+ # the next WG's epilogue.
517
+ if is_tma_warp:
518
+ epi_store_pipeline.producer_tail()
519
+ self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
520
+ tctx.e("epilogue")
521
+
522
+ if const_expr(not self.pingpong):
523
+ tile_scheduler.advance_to_next_work()
524
+ work_tile = tile_scheduler.get_current_work()
525
+ else: # Skip a tile for pingpong
526
+ # Update starting load/store pipeline states for the next tile
527
+ epi_read_state.advance_iters(c_tile_cnt)
528
+ epi_producer_state.advance_iters(c_tile_cnt)
529
+ # Update starting mainloop pipeline state for the next tile
530
+ if const_expr(not varlen_k):
531
+ ab_read_state.advance_iters(k_tile_cnt_static)
532
+ tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
533
+ work_tile = tile_scheduler.get_current_work()
534
+ else:
535
+ tile_scheduler.advance_to_next_work()
536
+ work_tile = tile_scheduler.get_current_work()
537
+ if work_tile.is_valid_tile:
538
+ len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
539
+ k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
540
+ ab_read_state.advance_iters(k_tile_cnt)
541
+ tile_scheduler.advance_to_next_work()
542
+ work_tile = tile_scheduler.get_current_work()
543
+
544
+ # Wait for D store complete
545
+ if const_expr(not self.pingpong):
546
+ if is_tma_warp:
547
+ epi_store_pipeline.producer_tail()
548
+
549
+ tctx.flush()
550
+
551
+ @cute.jit
552
+ def mma(
553
+ self,
554
+ ab_pipeline: cutlass.pipeline.PipelineAsync,
555
+ ab_read_state: cutlass.pipeline.PipelineState,
556
+ tiled_mma: cute.TiledMma,
557
+ acc: cute.Tensor,
558
+ k_tile_cnt: Int32,
559
+ smem_tiled_copy_A: cute.TiledCopy,
560
+ smem_tiled_copy_B: cute.TiledCopy,
561
+ tCsA_copy_view: cute.Tensor,
562
+ tCsB_copy_view: cute.Tensor,
563
+ tCrA: cute.Tensor,
564
+ tCrB: cute.Tensor,
565
+ ) -> cutlass.pipeline.PipelineState:
566
+ """Warp-level MMA mainloop: ldmatrix SMEM→RMEM + warp MMA."""
567
+ tCrA_copy_view = smem_tiled_copy_A.retile(tCrA)
568
+ tCrB_copy_view = smem_tiled_copy_B.retile(tCrB)
569
+ load_sA = partial(cute.copy, smem_tiled_copy_A)
570
+ load_sB = partial(cute.copy, smem_tiled_copy_B)
571
+
572
+ num_k_blocks = cute.size(tCrA, mode=[2])
573
+ peek_ab_full_status = Boolean(True)
574
+ if 0 < k_tile_cnt:
575
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
576
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
577
+
578
+ # Load first k-block
579
+ tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index]
580
+ tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index]
581
+ load_sA(tCsA_p[None, None, 0], tCrA_copy_view[None, None, 0])
582
+ load_sB(tCsB_p[None, None, 0], tCrB_copy_view[None, None, 0])
583
+
584
+ for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
585
+ for k in cutlass.range_constexpr(num_k_blocks):
586
+ k_next = 0 if k + 1 == num_k_blocks else k + 1
587
+ if const_expr(k == num_k_blocks - 1):
588
+ # Don't need to sync_warp: the previous instruction was mma.sync from cute.gemm
589
+ ab_pipeline.consumer_release(ab_read_state)
590
+ ab_read_state.advance()
591
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
592
+ tCsA_p = tCsA_copy_view[None, None, None, ab_read_state.index]
593
+ tCsB_p = tCsB_copy_view[None, None, None, ab_read_state.index]
594
+ ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
595
+ load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next])
596
+ load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next])
597
+ cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
598
+
599
+ # Last k-tile (hoisted)
600
+ if 0 < k_tile_cnt:
601
+ for k in cutlass.range_constexpr(num_k_blocks):
602
+ k_next = 0 if k + 1 == num_k_blocks else k + 1
603
+ if const_expr(k == num_k_blocks - 1):
604
+ ab_pipeline.consumer_release(ab_read_state)
605
+ ab_read_state.advance()
606
+ if const_expr(k_next > 0):
607
+ load_sA(tCsA_p[None, None, k_next], tCrA_copy_view[None, None, k_next])
608
+ load_sB(tCsB_p[None, None, k_next], tCrB_copy_view[None, None, k_next])
609
+ cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
610
+
611
+ return ab_read_state
612
+
613
+ def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s, tidx=None):
614
+ """Retile accumulator for epilogue. Warp-level MMA uses tiled_copy_r2s.retile."""
615
+ if tidx is None:
616
+ tidx = cute.arch.thread_idx()[0]
617
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
618
+ self._epi_size_tRS_rD = cute.size(tRS_rD)
619
+ return thr_copy_r2s.retile(acc)
620
+
621
+ @cute.jit
622
+ def epi_load_acc_subtile(self, tRS_rAcc, tRS_rD, epi_idx):
623
+ """Load acc subtile using retile-based flat indexing (warp-level MMA layout)."""
624
+ size_rD = self._epi_size_tRS_rD
625
+ for i in cutlass.range_constexpr(size_rD):
626
+ tRS_rD[i] = tRS_rAcc[epi_idx * size_rD + i]
build/torch-cuda/quack/gemm_sm90.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # Based on the cute-dsl example:
2
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
3
 
@@ -12,20 +13,24 @@ import cuda.bindings.driver as cuda
12
  import cutlass
13
  import cutlass.cute as cute
14
  import cutlass.pipeline as pipeline
 
15
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
16
  import cutlass.utils.hopper_helpers as sm90_utils
17
  from cutlass import Int32, Float32, Float16, Boolean, const_expr
18
- from cutlass.cutlass_dsl import if_generate
19
  from cutlass.utils import LayoutEnum
20
 
21
 
22
- from .cute_dsl_utils import ParamsBase, ArgumentsBase
 
 
 
23
  from .tile_scheduler import (
24
  TileSchedulerOptions,
25
  TileSchedulerArguments,
26
  TileScheduler,
27
  VarlenMTileSchedulerArguments,
28
  VarlenMTileScheduler,
 
29
  )
30
  from .varlen_utils import VarlenArguments, VarlenManager
31
 
@@ -33,6 +38,7 @@ from .varlen_utils import VarlenArguments, VarlenManager
33
  from .pipeline import make_pipeline_state, PipelineTmaCpAsync
34
  from . import copy_utils as copy_utils
35
  from . import sm90_utils as quack_sm90_utils
 
36
 
37
  """
38
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
@@ -122,9 +128,11 @@ class GemmSm90:
122
  """
123
 
124
  arch = 90
125
- num_epi_tensormaps: int = 0
126
 
127
- EpilogueArguments = ArgumentsBase
 
 
 
128
  EpilogueParams = ParamsBase
129
 
130
  def __init__(
@@ -137,6 +145,9 @@ class GemmSm90:
137
  is_persistent: bool = True,
138
  fp8_fast_accum: bool = False,
139
  gather_A: bool = False,
 
 
 
140
  ):
141
  """
142
  Initializes the configuration for a Hopper dense GEMM kernel.
@@ -155,10 +166,15 @@ class GemmSm90:
155
  self.acc_dtype = acc_dtype
156
  self.pingpong = pingpong
157
  self.is_persistent = is_persistent
 
 
 
 
158
  if self.pingpong:
159
  assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
160
  self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
161
  self.gather_A = gather_A
 
162
  if gather_A:
163
  assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
164
 
@@ -224,10 +240,12 @@ class GemmSm90:
224
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
225
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
226
  self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
 
 
 
 
227
  self.num_ab_load_warps = 1 if not self.gather_A else 4
228
  self.ab_load_warp_id = self.mma_warp_groups * 4
229
- # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1
230
- # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
231
 
232
  regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
233
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
@@ -259,20 +277,8 @@ class GemmSm90:
259
  self.shared_storage = None
260
  self.buffer_align_bytes = 1024
261
 
262
- def _setup_attributes(self, epilogue_args: EpilogueArguments):
263
- """Set up configurations that are dependent on GEMM inputs
264
-
265
- This method configures various attributes based on the input tensor properties
266
- (data types, leading dimensions) and kernel settings:
267
- - Configuring tiled MMA
268
- - Computing MMA/cluster/tile shapes
269
- - Computing cluster layout
270
- - Computing multicast CTAs for A/B
271
- - Computing epilogue subtile
272
- - Setting up A/B/C stage counts in shared memory
273
- - Computing A/B/C shared memory layout
274
- """
275
-
276
  self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
277
  self.a_dtype,
278
  self.b_dtype,
@@ -305,6 +311,21 @@ class GemmSm90:
305
  mma_inst_shape_k * mma_inst_tile_k,
306
  )
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
309
 
310
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
@@ -324,8 +345,6 @@ class GemmSm90:
324
  epilogue_args,
325
  cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
326
  self.occupancy,
327
- # epi_smem will reuse smem ab if not persistent.
328
- overlap_sD_sA=not self.is_persistent,
329
  )
330
  self.sched_stage = 2 if self.pingpong else 1
331
 
@@ -357,10 +376,11 @@ class GemmSm90:
357
  mB: cute.Tensor,
358
  mD: Optional[cute.Tensor],
359
  mC: Optional[cute.Tensor],
360
- epilogue_args: ArgumentsBase,
361
  scheduler_args: TileSchedulerOptions,
362
  varlen_args: Optional[VarlenArguments],
363
  stream: cuda.CUstream,
 
364
  ):
365
  """Execute the GEMM operation in steps:
366
  - Setup static attributes
@@ -379,6 +399,14 @@ class GemmSm90:
379
  :type stream: cuda.CUstream
380
  """
381
 
 
 
 
 
 
 
 
 
382
  # setup static attributes before smem/grid/tma computation
383
  self.a_dtype = mA.element_type
384
  self.b_dtype = mB.element_type
@@ -399,18 +427,8 @@ class GemmSm90:
399
  if const_expr(varlen_args is None):
400
  varlen_args = VarlenArguments()
401
  assert (varlen_args.mAIdx is not None) == self.gather_A
402
-
403
- # Assume all strides are divisible by 128 bits except the last stride
404
- new_stride = lambda t: tuple(
405
- cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
406
- for s in t.stride
407
- )
408
- mA, mD = [
409
- cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
410
- if t is not None
411
- else None
412
- for t in (mA, mD)
413
- ]
414
 
415
  self._setup_attributes(epilogue_args)
416
 
@@ -419,13 +437,15 @@ class GemmSm90:
419
  tma_atom_a, tma_tensor_a = None, None
420
  if const_expr(not self.gather_A):
421
  tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
422
- mA,
 
 
423
  a_smem_layout,
424
  (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
425
  self.cluster_shape_mnk[1],
426
  )
427
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
428
- mB,
429
  b_smem_layout,
430
  (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
431
  self.cluster_shape_mnk[0],
@@ -438,7 +458,13 @@ class GemmSm90:
438
  tma_atom_d, tma_tensor_d = None, None
439
  if const_expr(mD is not None):
440
  tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
441
- mD,
 
 
 
 
 
 
442
  self.epi_smem_layout_staged,
443
  self.epi_tile,
444
  op_type="store"
@@ -454,16 +480,16 @@ class GemmSm90:
454
  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
455
  varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
456
 
457
- TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None)
458
- tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args)
 
 
459
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
460
  grid = TileSchedulerCls.get_grid_shape(
461
  tile_sched_params, scheduler_args.max_active_clusters
462
  )
463
 
464
- epi_smem_size = (
465
- cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0
466
- )
467
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
468
 
469
  @cute.struct
@@ -471,7 +497,7 @@ class GemmSm90:
471
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
472
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
473
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
474
- tile_count: cute.struct.MemRange[Int32, self.sched_stage]
475
  sD: cute.struct.Align[
476
  cute.struct.MemRange[
477
  self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
@@ -516,12 +542,14 @@ class GemmSm90:
516
  self.epi_c_smem_layout_staged,
517
  tile_sched_params,
518
  TileSchedulerCls,
 
519
  ).launch(
520
  grid=grid,
521
  block=[self.threads_per_cta, 1, 1],
522
  cluster=self.cluster_shape_mnk,
523
  stream=stream,
524
  min_blocks_per_mp=1,
 
525
  )
526
  return
527
 
@@ -538,15 +566,16 @@ class GemmSm90:
538
  mD_mnl: Optional[cute.Tensor],
539
  tma_atom_c: Optional[cute.CopyAtom],
540
  mC_mnl: Optional[cute.Tensor],
541
- epilogue_params: ParamsBase,
542
  varlen_params: VarlenManager.Params,
543
  cluster_layout_mnk: cute.Layout,
544
  a_smem_layout: cute.ComposedLayout,
545
  b_smem_layout: cute.ComposedLayout,
546
  epi_smem_layout: cute.ComposedLayout,
547
  epi_c_smem_layout: cute.ComposedLayout,
548
- tile_sched_params: ParamsBase,
549
  TileSchedulerCls: cutlass.Constexpr[Callable],
 
550
  ):
551
  """
552
  GPU device kernel performing the batched GEMM computation.
@@ -575,6 +604,10 @@ class GemmSm90:
575
  :type epi_smem_layout: cute.ComposedLayout
576
  """
577
 
 
 
 
 
578
  varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
579
  varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
580
  assert not (varlen_m and varlen_k)
@@ -585,17 +618,13 @@ class GemmSm90:
585
 
586
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
587
 
588
- # /////////////////////////////////////////////////////////////////////////////
589
- # Prefetch Tma desc
590
- # /////////////////////////////////////////////////////////////////////////////
591
  if warp_idx == self.ab_load_warp_id:
592
  for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
593
  if const_expr(tma_atom is not None):
594
  cpasync.prefetch_descriptor(tma_atom)
595
 
596
- # /////////////////////////////////////////////////////////////////////////////
597
- # Alloc and init AB full/empty + ACC full mbar (pipeline)
598
- # /////////////////////////////////////////////////////////////////////////////
599
  smem = cutlass.utils.SmemAllocator()
600
  storage = smem.allocate(self.shared_storage)
601
 
@@ -611,28 +640,24 @@ class GemmSm90:
611
  epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
612
  )
613
  sched_pipeline = None
614
- tile_count = None
615
- if const_expr(tile_sched_params.tile_count_semaphore is not None):
616
- # Dynamic persistent scheduler
617
  sched_pipeline = self.make_sched_pipeline(
618
  cluster_layout_mnk,
619
  sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
620
  varlen_k=varlen_k,
621
  )
622
- tile_count = storage.tile_count.get_tensor((self.sched_stage,))
 
 
 
623
 
624
- # ///////////////////////////////////////////////////////////////////////////////
625
- # Generate smem tensor A/B
626
- # ///////////////////////////////////////////////////////////////////////////////
627
  sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
628
  sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
629
  sD = None
630
  if const_expr(has_D):
631
- if const_expr(not self.is_persistent):
632
- sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype)
633
- sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer)
634
- else:
635
- sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
636
  sC = None
637
  if const_expr(has_C):
638
  sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
@@ -640,37 +665,32 @@ class GemmSm90:
640
 
641
  varlen_manager = VarlenManager.create(
642
  varlen_params,
643
- has_D,
644
- self.num_epi_tensormaps,
645
  # Only used if not varlen_m
646
  len_m_static=Int32(
647
- mA_mkl.shape[0]
648
  if varlen_k or varlen_params.mAIdx is None
649
  else varlen_params.mAIdx.shape[0]
650
  ),
651
- len_k_static=Int32(mA_mkl.shape[1]),
652
- pingpong=self.pingpong,
653
- warp_idx=warp_idx,
654
  )
655
 
656
  TileSchedulerCls = partial(
657
- TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline
658
  )
659
 
 
 
 
660
  if warp_idx >= self.ab_load_warp_id:
661
- cute.arch.warpgroup_reg_dealloc(self.num_regs_load)
662
  if (
663
  warp_idx >= self.ab_load_warp_id
664
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
665
  ):
666
- is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
667
- # initialize tensormap for A & B
668
- varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp)
669
- tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr()
670
- tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr()
671
- # ///////////////////////////////////////////////////////////////////////////////
672
  # Get mcast mask
673
- # ///////////////////////////////////////////////////////////////////////////////
674
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
675
  block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
676
  a_mcast_mask = cute.make_layout_image_mask(
@@ -686,26 +706,17 @@ class GemmSm90:
686
  is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
687
  if const_expr(cute.size(cluster_layout_mnk) > 1):
688
  is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
689
- tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp)
690
  work_tile = tile_scheduler.initial_work_tile_info()
691
  ab_producer_state = make_pipeline_state(
692
  pipeline.PipelineUserType.Producer, self.ab_stage
693
  )
694
- if const_expr(varlen_k):
695
- # wait tensormap initialization complete before update
696
- varlen_manager.fence_tensormap_init()
697
  while work_tile.is_valid_tile:
 
698
  tile_coord_mnkl = work_tile.tile_idx
699
  batch_idx = tile_coord_mnkl[3]
700
- varlen_manager.update_tensormap_AB(
701
- batch_idx,
702
- self.a_layout,
703
- self.b_layout,
704
- is_tma_warp,
705
- )
706
- # ///////////////////////////////////////////////////////////////////////////
707
- # Local_tile partition global tensors
708
- # ///////////////////////////////////////////////////////////////////////////
709
  if const_expr(not self.gather_A):
710
  mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
711
  # (bM, bK, RestK)
@@ -714,37 +725,7 @@ class GemmSm90:
714
  cute.select(self.cta_tile_shape_mnk, [0, 2]),
715
  (tile_coord_mnkl[0], None),
716
  )
717
- else:
718
- mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
719
- if const_expr(varlen_m):
720
- gAIdx = cute.local_tile(
721
- mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
722
- )
723
- # (M, K)
724
- mA_mk = mA_mkl
725
- else:
726
- assert varlen_k
727
- # (tile_K, RestK)
728
- gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
729
- # (tile_M, K)
730
- mA_mk = cute.local_tile(
731
- mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
732
- )
733
- # (bN, bK, RestK)
734
- gB_nk = cute.local_tile(
735
- varlen_manager.offset_batch_B(mB_nkl, batch_idx),
736
- cute.select(self.cta_tile_shape_mnk, [1, 2]),
737
- (tile_coord_mnkl[1], None),
738
- )
739
- # //////////////////////////////////////////////////////////////////////////
740
- # Partition shared tensor for TMA load A/B
741
- # //////////////////////////////////////////////////////////////////////////
742
- varlen_manager.fence_tensormap_update_AB(is_tma_warp)
743
- len_m = varlen_manager.len_m(batch_idx)
744
- len_k = varlen_manager.len_k(batch_idx)
745
- # TMA load A partition_S/D
746
- copy_A = None
747
- if const_expr(not self.gather_A):
748
  copy_A, _, _ = copy_utils.tma_get_copy_fn(
749
  tma_atom_a,
750
  cta_coord=block_in_cluster_coord_mnk[1],
@@ -754,35 +735,17 @@ class GemmSm90:
754
  src_tensor=gA_mk,
755
  dst_tensor=sA,
756
  mcast_mask=a_mcast_mask,
757
- tma_desc_ptr=tma_desc_a_ptr,
758
  )
759
  else:
760
- tiled_copy_A = self._make_gmem_tiled_copy_A(
761
- mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
762
- )
763
- tidx = (
764
- cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
765
  )
766
- thr_copy_A = tiled_copy_A.get_slice(tidx)
767
- copy_A, prefetch_A = None, None
768
- if const_expr(varlen_m):
769
- copy_A = copy_utils.gather_m_get_copy_fn(
770
- thr_copy_A,
771
- mA_mk,
772
- sA,
773
- gAIdx,
774
- limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
775
- limit_k=len_k,
776
- )
777
- else:
778
- copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
779
- thr_copy_A,
780
- mA_mk,
781
- sA,
782
- gAIdx,
783
- limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
784
- limit_k=len_k,
785
- )
786
  # TMA load B partition_S/D
787
  copy_B, _, _ = copy_utils.tma_get_copy_fn(
788
  tma_atom_b,
@@ -793,8 +756,8 @@ class GemmSm90:
793
  src_tensor=gB_nk,
794
  dst_tensor=sB,
795
  mcast_mask=b_mcast_mask,
796
- tma_desc_ptr=tma_desc_b_ptr,
797
  )
 
798
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
799
  if const_expr(not self.gather_A):
800
  ab_producer_state = self.load_AB(
@@ -810,56 +773,47 @@ class GemmSm90:
810
  k_tile_cnt,
811
  varlen_m=varlen_m,
812
  )
813
- tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp)
814
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
815
  work_tile = tile_scheduler.get_current_work()
816
  # End of persistent scheduler loop
817
  if const_expr(self.pingpong and not varlen_k):
818
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
819
- tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
820
- ab_pipeline.producer_tail(ab_producer_state)
 
 
 
821
  if is_scheduler_warp:
822
  tile_scheduler.producer_tail()
823
 
824
  if warp_idx < self.ab_load_warp_id:
825
- cute.arch.warpgroup_reg_alloc(self.num_regs_mma)
826
  is_tma_warp = Boolean(
827
  (not self.pingpong and warp_idx == 0)
828
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
829
  )
830
- varlen_manager.init_tensormap_epi(
831
- tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp
832
- )
833
- tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr()
834
- tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs()
835
- # //////////////////////////////////////////////////////////////////////////////
836
- # Partition global tensor for TiledMMA_A/B/C
837
- # //////////////////////////////////////////////////////////////////////////////
838
  tidx, _, _ = cute.arch.thread_idx()
839
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
840
  if const_expr(self.pingpong):
841
  tidx = tidx % self.num_threads_per_warp_group
842
  warp_group_thread_layout = cute.make_layout(
843
- self.mma_warp_groups if not self.pingpong else 1,
844
  stride=self.num_threads_per_warp_group,
845
  )
846
  thr_mma = tiled_mma.get_slice(
847
  warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
848
  )
849
 
850
- # //////////////////////////////////////////////////////////////////////////////
851
- # Make fragments
852
- # //////////////////////////////////////////////////////////////////////////////
853
- tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA))
854
- tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB))
855
-
856
- acc_shape = tiled_mma.partition_shape_C(
857
- cute.select(self.cta_tile_shape_mnk, mode=[0, 1])
858
  )
859
- acc = cute.make_fragment(acc_shape, self.acc_dtype)
860
  acc_slow = None
861
  if const_expr(self.fp8_slow_accum):
862
- acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
 
863
 
864
  if const_expr(self.pingpong):
865
  if warp_group_idx == 0:
@@ -867,7 +821,9 @@ class GemmSm90:
867
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
868
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
869
 
870
- k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
 
 
871
  c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
872
 
873
  ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
@@ -879,10 +835,8 @@ class GemmSm90:
879
  pipeline.PipelineUserType.Producer, self.epi_c_stage
880
  )
881
  tile_scheduler = TileSchedulerCls()
882
- work_tile = None
883
  if const_expr(self.pingpong):
884
- if const_expr(varlen_k):
885
- work_tile = tile_scheduler.initial_work_tile_info()
886
  if warp_idx >= 4:
887
  # Advance 2nd Math WG pipeline states to the end of 1st Math WG
888
  epi_read_state.advance_iters(c_tile_cnt)
@@ -893,58 +847,29 @@ class GemmSm90:
893
  len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
894
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
895
  ab_read_state.advance_iters(k_tile_cnt)
 
896
  tile_scheduler.advance_to_next_work()
897
- if const_expr(varlen_k):
898
- work_tile = tile_scheduler.get_current_work()
899
- if const_expr(not varlen_k):
900
- work_tile = tile_scheduler.initial_work_tile_info()
901
- else:
902
- work_tile = tile_scheduler.initial_work_tile_info()
903
- if const_expr(varlen_m):
904
- # wait tensormap initialization complete before update
905
- varlen_manager.fence_tensormap_init()
906
  while work_tile.is_valid_tile:
907
  tile_coord_mnkl = work_tile.tile_idx
908
  batch_idx = tile_coord_mnkl[3]
909
- epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders(
910
- epilogue_params, varlen_params.cu_seqlens_m, batch_idx
911
- )
912
- varlen_manager.update_tensormap_epi(
913
- batch_idx,
914
- self.d_layout,
915
- epi_shapes,
916
- epi_orders,
917
- is_tma_warp,
918
- )
919
  len_k = varlen_manager.len_k(batch_idx)
920
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
921
- ab_read_state, tiled_mma = self.mma(
922
- ab_pipeline,
923
- ab_read_state,
924
- tiled_mma,
925
- tCrA,
926
- tCrB,
927
- acc,
928
- acc_slow,
929
- k_tile_cnt,
930
- warp_group_idx,
931
  )
932
  if const_expr(varlen_k):
933
  if k_tile_cnt == 0:
934
  acc.fill(0.0)
 
935
 
936
- # /////////////////////////////////////////////////////////////////////////////
937
- # EPILOGUE
938
- # /////////////////////////////////////////////////////////////////////////////
939
  if const_expr(self.pingpong):
940
  self.pingpong_barrier_sync(warp_group_idx, "epi")
941
-
942
- epilogue_barrier = pipeline.NamedBarrier(
943
- barrier_id=int(NamedBarrierGemm.Epilogue),
944
- num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
945
- )
946
-
947
- varlen_manager.fence_tensormap_update_epi(is_tma_warp)
948
 
949
  copy_D = None
950
  if const_expr(has_D):
@@ -955,7 +880,6 @@ class GemmSm90:
955
  self.epi_tile,
956
  sD,
957
  tile_coord_mnkl,
958
- tma_desc_ptr=tma_desc_d_ptr,
959
  )
960
  copy_C = None
961
  if const_expr(has_C):
@@ -973,8 +897,8 @@ class GemmSm90:
973
  tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
974
  tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
975
  )
976
- # (R2S, R2S_M, R2S_N)
977
- tRS_rAcc = tiled_copy_r2s.retile(acc)
978
  load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
979
  if const_expr(has_C):
980
  tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
@@ -983,17 +907,11 @@ class GemmSm90:
983
  else:
984
  tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
985
 
986
- # Wait for all warp groups in the thread block to finish, because smem for tensor
987
- # A in the mainloop is reused in the epilogue if not persistent.
988
- if const_expr(not self.is_persistent):
989
- epilogue_barrier.arrive_and_wait()
990
-
991
  self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
992
 
993
  epi_read_state, epi_producer_state = self.epilogue(
994
  epilogue_params,
995
  epi_smem_tensors,
996
- tma_desc_epi_ptrs,
997
  epi_pipeline,
998
  epi_store_pipeline,
999
  epi_read_state,
@@ -1012,7 +930,7 @@ class GemmSm90:
1012
  copy_C,
1013
  tile_coord_mnkl,
1014
  varlen_manager,
1015
- epilogue_barrier,
1016
  tile_scheduler,
1017
  tidx,
1018
  is_tma_warp,
@@ -1025,6 +943,7 @@ class GemmSm90:
1025
  if is_tma_warp:
1026
  epi_store_pipeline.producer_tail()
1027
  self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
 
1028
 
1029
  if const_expr(not self.pingpong):
1030
  tile_scheduler.advance_to_next_work()
@@ -1049,11 +968,17 @@ class GemmSm90:
1049
  work_tile = tile_scheduler.get_current_work()
1050
  # End of persistent scheduler loop
1051
 
 
 
 
 
1052
  # Wait for D store complete
1053
  if const_expr(not self.pingpong):
1054
  if is_tma_warp:
1055
  epi_store_pipeline.producer_tail()
1056
 
 
 
1057
  @cute.jit
1058
  def load_AB(
1059
  self,
@@ -1073,9 +998,7 @@ class GemmSm90:
1073
  peek_ab_empty_status = Boolean(True)
1074
  if 0 < k_tile_cnt:
1075
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1076
- # /////////////////////////////////////////////////////////////////////////
1077
  # TMA load
1078
- # /////////////////////////////////////////////////////////////////////////
1079
  for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1080
  # Wait for A/B buffers to be empty before loading into them
1081
  # Also sets the transaction barrier for the A/B buffers
@@ -1112,9 +1035,7 @@ class GemmSm90:
1112
  peek_ab_empty_status = Boolean(True)
1113
  if 0 < k_tile_cnt:
1114
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
1115
- # /////////////////////////////////////////////////////////////////////////
1116
  # TMA load on B and cp.async on A
1117
- # /////////////////////////////////////////////////////////////////////////
1118
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1119
  prefetch_out = ()
1120
  if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
@@ -1122,11 +1043,7 @@ class GemmSm90:
1122
  # Wait for A/B buffers to be empty before loading into them
1123
  # Also sets the transaction barrier for the A/B buffers
1124
  # A tiny bit faster to rotate the warp that does TMA
1125
- # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id
1126
- # since that's the warp that does the tensormap update.
1127
- is_tma_warp = warp_idx == self.ab_load_warp_id + (
1128
- (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1129
- )
1130
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1131
  smem_idx = ab_producer_state.index
1132
  # A bit faster to load B first while we calculate the indices for A
@@ -1146,9 +1063,7 @@ class GemmSm90:
1146
  prefetch_out = ()
1147
  if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1148
  prefetch_out = (prefetch_A(k_tile, pred=True),)
1149
- is_tma_warp = warp_idx == self.ab_load_warp_id + (
1150
- (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0
1151
- )
1152
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1153
  smem_idx = ab_producer_state.index
1154
  if is_tma_warp:
@@ -1159,41 +1074,78 @@ class GemmSm90:
1159
  ab_producer_state.advance()
1160
  return ab_producer_state
1161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1162
  @cute.jit
1163
  def mma(
1164
  self,
1165
  ab_pipeline: cutlass.pipeline.PipelineAsync,
1166
  ab_read_state: cutlass.pipeline.PipelineState,
1167
- tiled_mma: cute.TiledMma,
1168
- tCrA: cute.Tensor,
1169
- tCrB: cute.Tensor,
1170
  acc: cute.Tensor,
1171
  acc_slow: Optional[cute.Tensor],
1172
  k_tile_cnt: Int32,
1173
  warp_group_idx: Int32,
1174
- ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]:
1175
- # /////////////////////////////////////////////////////////////////////////////
1176
- # Prologue MMAs
1177
- # /////////////////////////////////////////////////////////////////////////////
1178
  k_pipe_mmas = 1
1179
  ab_release_state = ab_read_state.clone()
1180
  num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
1181
- if const_expr(self.pingpong):
1182
- self.pingpong_barrier_sync(warp_group_idx, stage="mma")
1183
  peek_ab_full_status = Boolean(True)
1184
  if 0 < k_tile_cnt:
1185
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1186
- tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1187
- num_k_blocks = cute.size(tCrA, mode=[2])
1188
  for k_tile in cutlass.range(num_prologue_mma):
1189
  # Wait for A/B buffer to be ready
1190
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1191
- warpgroup.fence()
1192
- for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1193
- k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1194
- cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1195
- tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1196
- warpgroup.commit_group()
1197
  ab_read_state.advance()
1198
  peek_ab_full_status = Boolean(True)
1199
  if k_tile + 1 < k_tile_cnt:
@@ -1204,21 +1156,14 @@ class GemmSm90:
1204
  warpgroup.wait_group(0)
1205
  acc_slow.store(acc.load())
1206
 
1207
- # /////////////////////////////////////////////////////////////////////////////
1208
- # MAINLOOP
1209
- # /////////////////////////////////////////////////////////////////////////////
1210
  for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
1211
  # Wait for TMA copies to complete
1212
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1213
- # WGMMA
1214
- warpgroup.fence()
1215
  if const_expr(self.fp8_slow_accum):
1216
- tiled_mma.set(warpgroup.Field.ACCUMULATE, False)
1217
- for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
1218
- k_blk_coord = (None, None, k_blk_idx, ab_read_state.index)
1219
- cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
1220
- tiled_mma.set(warpgroup.Field.ACCUMULATE, True)
1221
- warpgroup.commit_group()
1222
  # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
1223
  if const_expr(not self.fp8_slow_accum):
1224
  warpgroup.wait_group(k_pipe_mmas)
@@ -1242,16 +1187,13 @@ class GemmSm90:
1242
  ab_release_state.advance()
1243
  if const_expr(self.fp8_slow_accum):
1244
  acc.store(acc_slow.load())
1245
- # If we don't return the tiled_mma, we get compiler error
1246
- # "operand #0 does not dominate this use"
1247
- return ab_read_state, tiled_mma
1248
 
1249
  @cute.jit
1250
  def epilogue(
1251
  self,
1252
  params: EpilogueParams,
1253
  epi_smem_tensors: Tuple[cute.Tensor, ...],
1254
- tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
1255
  epi_pipeline: cutlass.pipeline.PipelineAsync,
1256
  epi_store_pipeline: cutlass.pipeline.PipelineAsync,
1257
  epi_read_state: cutlass.pipeline.PipelineState,
@@ -1277,6 +1219,18 @@ class GemmSm90:
1277
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
1278
  has_C = const_expr(tRS_rC is not None)
1279
  has_D = const_expr(copy_D is not None)
 
 
 
 
 
 
 
 
 
 
 
 
1280
  epi_tile_shape = cute.zipped_divide(
1281
  cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
1282
  ).shape[1]
@@ -1306,26 +1260,6 @@ class GemmSm90:
1306
  epi_pipeline.producer_commit(epi_producer_state)
1307
  epi_producer_state.advance()
1308
 
1309
- def tma_store_fn(src_idx, dst_idx):
1310
- # Fence and barrier to make sure shared memory store is visible to TMA store
1311
- cute.arch.fence_proxy(
1312
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1313
- )
1314
- epilogue_barrier.arrive_and_wait()
1315
- # Copy from shared memory to global memory
1316
- if is_tma_warp:
1317
- if const_expr(has_D):
1318
- copy_D(src_idx=src_idx, dst_idx=dst_idx)
1319
- # Can't use if statement here, epi_store_pipeline object isn't captured somehow
1320
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
1321
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
1322
- epilogue_barrier.arrive_and_wait()
1323
-
1324
- # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops
1325
- # with the TMA store. However, currently this doesn't seem to improve perf.
1326
- delay_tma_store = False
1327
-
1328
- src_idx_prev, dst_idx_prev = None, None
1329
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
1330
  # The global memory coordinate for the current epi tile
1331
  gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
@@ -1336,9 +1270,7 @@ class GemmSm90:
1336
  epi_pipeline.consumer_wait(epi_read_state)
1337
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
1338
  # Fence to make sure shared memory read is visible to TMA load
1339
- cute.arch.fence_proxy(
1340
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
1341
- )
1342
  cute.arch.sync_warp()
1343
  with cute.arch.elect_one():
1344
  epi_pipeline.consumer_release(epi_read_state)
@@ -1350,20 +1282,63 @@ class GemmSm90:
1350
  copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
1351
  epi_pipeline.producer_commit(epi_producer_state)
1352
  epi_producer_state.advance()
1353
- tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
1354
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
1355
- if const_expr(delay_tma_store):
1356
- if const_expr(epi_idx > 0):
1357
- tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
1358
- src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
 
 
 
 
 
 
 
 
1359
  # Copy from D registers to shared memory
 
1360
  if const_expr(has_D):
1361
- copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
1362
- if const_expr(not delay_tma_store):
1363
- tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
1364
-
1365
- if const_expr(delay_tma_store):
1366
- tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1367
 
1368
  self.epi_end(
1369
  params,
@@ -1389,8 +1364,18 @@ class GemmSm90:
1389
  mD: Optional[cute.Tensor],
1390
  scheduler_args,
1391
  varlen_args,
 
1392
  ):
1393
  """Create scheduler arguments. Override in subclasses for custom schedulers."""
 
 
 
 
 
 
 
 
 
1394
  if const_expr(varlen_args.mCuSeqlensM is None):
1395
  num_problems = (
1396
  mD.shape[2]
@@ -1402,8 +1387,8 @@ class GemmSm90:
1402
  )
1403
  )
1404
  problem_shape_ntile_mnl = (
1405
- cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]),
1406
- cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
1407
  num_problems,
1408
  )
1409
  tile_sched_args = TileSchedulerArguments(
@@ -1413,13 +1398,13 @@ class GemmSm90:
1413
  cluster_shape_mnk=self.cluster_shape_mnk,
1414
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
1415
  batch_idx_permute=scheduler_args.batch_idx_permute,
1416
- is_persistent=self.is_persistent,
1417
  )
1418
  else:
1419
- assert mD is not None or not self.gather_A
1420
  problem_shape_ntile_mnl = (
1421
  None,
1422
- cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]),
1423
  varlen_args.mCuSeqlensM.shape[0] - 1,
1424
  )
1425
  tile_sched_args = VarlenMTileSchedulerArguments(
@@ -1431,14 +1416,17 @@ class GemmSm90:
1431
  tile_shape_mn=self.cta_tile_shape_mnk[:2],
1432
  cluster_shape_mnk=self.cluster_shape_mnk,
1433
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
1434
- is_persistent=self.is_persistent,
1435
  )
1436
  return tile_sched_args
1437
 
 
 
 
 
1438
  @cute.jit
1439
  def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
1440
- for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
1441
- tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
1442
 
1443
  @cute.jit
1444
  def epi_begin(
@@ -1504,18 +1492,6 @@ class GemmSm90:
1504
  """Subclasses can override this"""
1505
  return []
1506
 
1507
- def epi_get_tensormap_update_shapes_orders(
1508
- self,
1509
- params: EpilogueParams,
1510
- cu_seqlens_m: cute.Tensor,
1511
- batch_idx: Int32,
1512
- *,
1513
- loc=None,
1514
- ip=None,
1515
- ) -> tuple[list[Int32], list[int]]:
1516
- """Subclasses can override this"""
1517
- return [], []
1518
-
1519
  @staticmethod
1520
  def epi_smem_bytes_per_stage(
1521
  args: Optional[EpilogueArguments],
@@ -1579,7 +1555,7 @@ class GemmSm90:
1579
  tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
1580
  sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
1581
  tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
1582
- tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype)
1583
  return tiled_copy_r2s, tRS_rD, tRS_sD
1584
 
1585
  def epilog_smem_load_and_partition(
@@ -1596,7 +1572,7 @@ class GemmSm90:
1596
  tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1597
  thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1598
  tSR_sC = thr_copy_s2r.partition_S(sC)
1599
- tRS_rC = cute.make_fragment(tRS_rD_layout, dtype)
1600
  tSR_rC = thr_copy_s2r.retile(tRS_rC)
1601
  return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1602
 
@@ -1608,7 +1584,6 @@ class GemmSm90:
1608
  epi_tile: cute.Tile,
1609
  sD: cute.Tensor,
1610
  tile_coord_mnkl: cute.Coord,
1611
- tma_desc_ptr: Optional[cute.Pointer] = None,
1612
  ) -> Tuple[cute.Tensor, cute.Tensor]:
1613
  # (bM, bN)
1614
  gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
@@ -1625,7 +1600,6 @@ class GemmSm90:
1625
  cta_layout=cute.make_layout(1),
1626
  src_tensor=src_tensor,
1627
  dst_tensor=dst_tensor,
1628
- tma_desc_ptr=tma_desc_ptr,
1629
  )
1630
 
1631
  def make_ab_pipeline(
@@ -1651,6 +1625,7 @@ class GemmSm90:
1651
  consumer_group=ab_pipeline_consumer_group,
1652
  tx_count=self.num_tma_load_bytes,
1653
  cta_layout_vmnk=cluster_layout_vmnk,
 
1654
  )
1655
 
1656
  def make_epi_pipeline(
@@ -1670,6 +1645,7 @@ class GemmSm90:
1670
  producer_group=epi_pipeline_producer_group,
1671
  consumer_group=epi_pipeline_consumer_group,
1672
  tx_count=tma_copy_c_bytes,
 
1673
  )
1674
 
1675
  def make_epi_store_pipeline(self):
@@ -1686,13 +1662,13 @@ class GemmSm90:
1686
  # Threads/warps participating in this pipeline
1687
  sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1688
  cluster_size = cute.size(cluster_layout_mnk)
1689
- # Each warp that are not the scheduler warp will contribute 1 to the arrive count
1690
  # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
1691
  # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
1692
  consumer_arrive_cnt = (
1693
  (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
1694
  + self.num_ab_load_warps
1695
- ) * cluster_size - 1
1696
  sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1697
  pipeline.Agent.Thread, consumer_arrive_cnt
1698
  )
@@ -1703,6 +1679,7 @@ class GemmSm90:
1703
  consumer_group=sched_pipeline_consumer_group,
1704
  # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1705
  consumer_mask=None if const_expr(cluster_size == 1) else 0,
 
1706
  )
1707
 
1708
  @classmethod
@@ -1717,7 +1694,6 @@ class GemmSm90:
1717
  epilogue_args: EpilogueArguments,
1718
  smem_capacity: int,
1719
  occupancy: int,
1720
- overlap_sD_sA: bool = False,
1721
  ) -> Tuple[int, int]:
1722
  """Computes the number of stages for A/B/C operands based on heuristics.
1723
 
@@ -1738,16 +1714,11 @@ class GemmSm90:
1738
  """
1739
 
1740
  epi_stage = 4 if epi_tile[1] <= 16 else 2
1741
- if overlap_sD_sA:
1742
- epi_bytes = 0
1743
- else:
1744
- d_bytes_per_stage = (
1745
- cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1746
- )
1747
- epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1748
- epilogue_args, cta_tile_shape_mnk, epi_tile
1749
- )
1750
- epi_bytes = epi_bytes_per_stage * epi_stage
1751
  epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1752
  if c_dtype is not None:
1753
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
@@ -1765,7 +1736,7 @@ class GemmSm90:
1765
  # Refine epilogue stages:
1766
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
1767
  # Add remaining unused smem to epilogue
1768
- if not overlap_sD_sA and epi_bytes_per_stage > 0:
1769
  epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
1770
  return ab_stage, epi_stage, epi_c_stage
1771
 
@@ -2030,20 +2001,10 @@ class GemmSm90:
2030
  :rtype: bool
2031
  """
2032
  is_valid = True
2033
- if a_dtype not in {
2034
- Float16,
2035
- cutlass.BFloat16,
2036
- cutlass.Float8E4M3FN,
2037
- cutlass.Float8E5M2,
2038
- }:
2039
  is_valid = False
2040
  # tested b_dtype
2041
- if b_dtype not in {
2042
- Float16,
2043
- cutlass.BFloat16,
2044
- cutlass.Float8E4M3FN,
2045
- cutlass.Float8E5M2,
2046
- }:
2047
  is_valid = False
2048
  if acc_dtype not in {Float32, Float16}:
2049
  is_valid = False
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
  # Based on the cute-dsl example:
3
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py
4
 
 
13
  import cutlass
14
  import cutlass.cute as cute
15
  import cutlass.pipeline as pipeline
16
+ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
17
  from cutlass.cute.nvgpu import cpasync, warp, warpgroup
18
  import cutlass.utils.hopper_helpers as sm90_utils
19
  from cutlass import Int32, Float32, Float16, Boolean, const_expr
 
20
  from cutlass.utils import LayoutEnum
21
 
22
 
23
+ from dataclasses import dataclass
24
+
25
+ from .cute_dsl_utils import ParamsBase
26
+ from . import layout_utils
27
  from .tile_scheduler import (
28
  TileSchedulerOptions,
29
  TileSchedulerArguments,
30
  TileScheduler,
31
  VarlenMTileSchedulerArguments,
32
  VarlenMTileScheduler,
33
+ PersistenceMode,
34
  )
35
  from .varlen_utils import VarlenArguments, VarlenManager
36
 
 
38
  from .pipeline import make_pipeline_state, PipelineTmaCpAsync
39
  from . import copy_utils as copy_utils
40
  from . import sm90_utils as quack_sm90_utils
41
+ from .rounding import RoundingMode
42
 
43
  """
44
  A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
 
128
  """
129
 
130
  arch = 90
 
131
 
132
+ @dataclass
133
+ class EpilogueArguments:
134
+ pass
135
+
136
  EpilogueParams = ParamsBase
137
 
138
  def __init__(
 
145
  is_persistent: bool = True,
146
  fp8_fast_accum: bool = False,
147
  gather_A: bool = False,
148
+ use_clc_persistence: bool = False,
149
+ concat_layout: tuple | None = None,
150
+ use_pdl: bool = True,
151
  ):
152
  """
153
  Initializes the configuration for a Hopper dense GEMM kernel.
 
166
  self.acc_dtype = acc_dtype
167
  self.pingpong = pingpong
168
  self.is_persistent = is_persistent
169
+ self.use_clc_persistence = use_clc_persistence
170
+ if self.use_clc_persistence:
171
+ assert self.arch == 100
172
+ self.use_pdl = use_pdl
173
  if self.pingpong:
174
  assert self.is_persistent, "Pingpong gemm requires persistent scheduler"
175
  self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8
176
  self.gather_A = gather_A
177
+ self.concat_layout = concat_layout or ()
178
  if gather_A:
179
  assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A "
180
 
 
240
  self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
241
  self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
242
  self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
243
+ self.epilogue_barrier = pipeline.NamedBarrier(
244
+ barrier_id=int(NamedBarrierGemm.Epilogue),
245
+ num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
246
+ )
247
  self.num_ab_load_warps = 1 if not self.gather_A else 4
248
  self.ab_load_warp_id = self.mma_warp_groups * 4
 
 
249
 
250
  regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // (
251
  math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group
 
277
  self.shared_storage = None
278
  self.buffer_align_bytes = 1024
279
 
280
+ def _setup_tiled_mma(self):
281
+ """Set up tiled MMA and tile K dimension. Override for different MMA types."""
 
 
 
 
 
 
 
 
 
 
 
 
282
  self.tiled_mma = sm90_utils.make_trivial_tiled_mma(
283
  self.a_dtype,
284
  self.b_dtype,
 
311
  mma_inst_shape_k * mma_inst_tile_k,
312
  )
313
 
314
+ def _setup_attributes(self, epilogue_args: EpilogueArguments):
315
+ """Set up configurations that are dependent on GEMM inputs
316
+
317
+ This method configures various attributes based on the input tensor properties
318
+ (data types, leading dimensions) and kernel settings:
319
+ - Configuring tiled MMA
320
+ - Computing MMA/cluster/tile shapes
321
+ - Computing cluster layout
322
+ - Computing multicast CTAs for A/B
323
+ - Computing epilogue subtile
324
+ - Setting up A/B/C stage counts in shared memory
325
+ - Computing A/B/C shared memory layout
326
+ """
327
+ self._setup_tiled_mma()
328
+
329
  self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
330
 
331
  self.epi_tile = self._sm90_compute_tile_shape_or_override(
 
345
  epilogue_args,
346
  cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
347
  self.occupancy,
 
 
348
  )
349
  self.sched_stage = 2 if self.pingpong else 1
350
 
 
376
  mB: cute.Tensor,
377
  mD: Optional[cute.Tensor],
378
  mC: Optional[cute.Tensor],
379
+ epilogue_args: tuple,
380
  scheduler_args: TileSchedulerOptions,
381
  varlen_args: Optional[VarlenArguments],
382
  stream: cuda.CUstream,
383
+ trace_ptr: Optional[cutlass.Int64] = None,
384
  ):
385
  """Execute the GEMM operation in steps:
386
  - Setup static attributes
 
399
  :type stream: cuda.CUstream
400
  """
401
 
402
+ # Concat layout: interleave the non-contiguous dim (detected via leading_dim).
403
+ mA, mB, mD, mC = [
404
+ layout_utils.concat_to_interleave(mT, 1 - mT.leading_dim)
405
+ if const_expr(name in self.concat_layout and mT is not None)
406
+ else mT
407
+ for name, mT in [("A", mA), ("B", mB), ("out", mD), ("C", mC)]
408
+ ]
409
+
410
  # setup static attributes before smem/grid/tma computation
411
  self.a_dtype = mA.element_type
412
  self.b_dtype = mB.element_type
 
427
  if const_expr(varlen_args is None):
428
  varlen_args = VarlenArguments()
429
  assert (varlen_args.mAIdx is not None) == self.gather_A
430
+ varlen_m = varlen_args.mCuSeqlensM is not None
431
+ varlen_k = varlen_args.mCuSeqlensK is not None
 
 
 
 
 
 
 
 
 
 
432
 
433
  self._setup_attributes(epilogue_args)
434
 
 
437
  tma_atom_a, tma_tensor_a = None, None
438
  if const_expr(not self.gather_A):
439
  tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors(
440
+ copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1)
441
+ if varlen_k and not self.gather_A
442
+ else mA,
443
  a_smem_layout,
444
  (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
445
  self.cluster_shape_mnk[1],
446
  )
447
  tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
448
+ copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB,
449
  b_smem_layout,
450
  (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
451
  self.cluster_shape_mnk[0],
 
458
  tma_atom_d, tma_tensor_d = None, None
459
  if const_expr(mD is not None):
460
  tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
461
+ copy_utils.create_ragged_tensor_for_tma(
462
+ mD,
463
+ ragged_dim=0,
464
+ ptr_shift=True,
465
+ )
466
+ if varlen_m
467
+ else mD,
468
  self.epi_smem_layout_staged,
469
  self.epi_tile,
470
  op_type="store"
 
480
  epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
481
  varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
482
 
483
+ TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m)
484
+ tile_sched_args = self.get_scheduler_arguments(
485
+ mA, mB, mD, scheduler_args, varlen_args, epilogue_args
486
+ )
487
  tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args)
488
  grid = TileSchedulerCls.get_grid_shape(
489
  tile_sched_params, scheduler_args.max_active_clusters
490
  )
491
 
492
+ epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
 
 
493
  epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
494
 
495
  @cute.struct
 
497
  ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
498
  epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
499
  sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
500
+ sched_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
501
  sD: cute.struct.Align[
502
  cute.struct.MemRange[
503
  self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size
 
542
  self.epi_c_smem_layout_staged,
543
  tile_sched_params,
544
  TileSchedulerCls,
545
+ trace_ptr,
546
  ).launch(
547
  grid=grid,
548
  block=[self.threads_per_cta, 1, 1],
549
  cluster=self.cluster_shape_mnk,
550
  stream=stream,
551
  min_blocks_per_mp=1,
552
+ use_pdl=self.use_pdl,
553
  )
554
  return
555
 
 
566
  mD_mnl: Optional[cute.Tensor],
567
  tma_atom_c: Optional[cute.CopyAtom],
568
  mC_mnl: Optional[cute.Tensor],
569
+ epilogue_params,
570
  varlen_params: VarlenManager.Params,
571
  cluster_layout_mnk: cute.Layout,
572
  a_smem_layout: cute.ComposedLayout,
573
  b_smem_layout: cute.ComposedLayout,
574
  epi_smem_layout: cute.ComposedLayout,
575
  epi_c_smem_layout: cute.ComposedLayout,
576
+ tile_sched_params,
577
  TileSchedulerCls: cutlass.Constexpr[Callable],
578
+ trace_ptr: Optional[cutlass.Int64] = None,
579
  ):
580
  """
581
  GPU device kernel performing the batched GEMM computation.
 
604
  :type epi_smem_layout: cute.ComposedLayout
605
  """
606
 
607
+ from .trace import TraceContext
608
+
609
+ tctx = TraceContext.create(trace_ptr)
610
+
611
  varlen_m = const_expr(varlen_params.cu_seqlens_m is not None)
612
  varlen_k = const_expr(varlen_params.cu_seqlens_k is not None)
613
  assert not (varlen_m and varlen_k)
 
618
 
619
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
620
 
621
+ # Prefetch Tma desc
 
 
622
  if warp_idx == self.ab_load_warp_id:
623
  for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c):
624
  if const_expr(tma_atom is not None):
625
  cpasync.prefetch_descriptor(tma_atom)
626
 
627
+ # Alloc and init AB full/empty + ACC full mbar (pipeline)
 
 
628
  smem = cutlass.utils.SmemAllocator()
629
  storage = smem.allocate(self.shared_storage)
630
 
 
640
  epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(),
641
  )
642
  sched_pipeline = None
643
+ sched_data = None
644
+ if const_expr(self.is_persistent):
 
645
  sched_pipeline = self.make_sched_pipeline(
646
  cluster_layout_mnk,
647
  sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(),
648
  varlen_k=varlen_k,
649
  )
650
+ sched_data = storage.sched_data.get_tensor((4, self.sched_stage))
651
+
652
+ # Cluster arrive after barrier init
653
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
654
 
655
+ # Generate smem tensor A/B
 
 
656
  sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
657
  sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
658
  sD = None
659
  if const_expr(has_D):
660
+ sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
 
 
 
 
661
  sC = None
662
  if const_expr(has_C):
663
  sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
 
665
 
666
  varlen_manager = VarlenManager.create(
667
  varlen_params,
 
 
668
  # Only used if not varlen_m
669
  len_m_static=Int32(
670
+ cute.size(mA_mkl, mode=[0])
671
  if varlen_k or varlen_params.mAIdx is None
672
  else varlen_params.mAIdx.shape[0]
673
  ),
674
+ len_k_static=Int32(cute.size(mA_mkl, mode=[1])),
 
 
675
  )
676
 
677
  TileSchedulerCls = partial(
678
+ TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline
679
  )
680
 
681
+ # Cluster wait for barrier init
682
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1])
683
+
684
  if warp_idx >= self.ab_load_warp_id:
685
+ cute.arch.setmaxregister_decrease(self.num_regs_load)
686
  if (
687
  warp_idx >= self.ab_load_warp_id
688
  and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps
689
  ):
690
+ # PDL: wait for prior kernel before any TMA loads (matches cutlass C++ sm90 mainloop producer)
691
+ if const_expr(self.use_pdl):
692
+ cute.arch.griddepcontrol_wait()
 
 
 
693
  # Get mcast mask
 
694
  cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
695
  block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
696
  a_mcast_mask = cute.make_layout_image_mask(
 
706
  is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
707
  if const_expr(cute.size(cluster_layout_mnk) > 1):
708
  is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
709
+ tile_scheduler = TileSchedulerCls()
710
  work_tile = tile_scheduler.initial_work_tile_info()
711
  ab_producer_state = make_pipeline_state(
712
  pipeline.PipelineUserType.Producer, self.ab_stage
713
  )
 
 
 
714
  while work_tile.is_valid_tile:
715
+ tctx.b("tma_load")
716
  tile_coord_mnkl = work_tile.tile_idx
717
  batch_idx = tile_coord_mnkl[3]
718
+ # Local_tile partition global tensors
719
+ copy_A, prefetch_A = None, None
 
 
 
 
 
 
 
720
  if const_expr(not self.gather_A):
721
  mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
722
  # (bM, bK, RestK)
 
725
  cute.select(self.cta_tile_shape_mnk, [0, 2]),
726
  (tile_coord_mnkl[0], None),
727
  )
728
+ # TMA load A partition_S/D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
729
  copy_A, _, _ = copy_utils.tma_get_copy_fn(
730
  tma_atom_a,
731
  cta_coord=block_in_cluster_coord_mnk[1],
 
735
  src_tensor=gA_mk,
736
  dst_tensor=sA,
737
  mcast_mask=a_mcast_mask,
 
738
  )
739
  else:
740
+ copy_A, prefetch_A = self._make_gather_A_copy(
741
+ mA_mkl, sA, varlen_manager, tile_coord_mnkl, batch_idx
 
 
 
742
  )
743
+ # (bN, bK, RestK)
744
+ gB_nk = cute.local_tile(
745
+ varlen_manager.offset_batch_B(mB_nkl, batch_idx),
746
+ cute.select(self.cta_tile_shape_mnk, [1, 2]),
747
+ (tile_coord_mnkl[1], None),
748
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749
  # TMA load B partition_S/D
750
  copy_B, _, _ = copy_utils.tma_get_copy_fn(
751
  tma_atom_b,
 
756
  src_tensor=gB_nk,
757
  dst_tensor=sB,
758
  mcast_mask=b_mcast_mask,
 
759
  )
760
+ len_k = varlen_manager.len_k(batch_idx)
761
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
762
  if const_expr(not self.gather_A):
763
  ab_producer_state = self.load_AB(
 
773
  k_tile_cnt,
774
  varlen_m=varlen_m,
775
  )
776
+ tctx.e("tma_load")
777
  tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
778
  work_tile = tile_scheduler.get_current_work()
779
  # End of persistent scheduler loop
780
  if const_expr(self.pingpong and not varlen_k):
781
  # Need to write the tile_idx to smem for the next WG in the pingpong mode
782
+ if is_scheduler_warp:
783
+ tile_scheduler.write_work_tile_to_smem(work_tile)
784
+ work_tile = tile_scheduler.get_current_work()
785
+ if warp_idx == self.ab_load_warp_id:
786
+ ab_pipeline.producer_tail(ab_producer_state)
787
  if is_scheduler_warp:
788
  tile_scheduler.producer_tail()
789
 
790
  if warp_idx < self.ab_load_warp_id:
791
+ cute.arch.setmaxregister_increase(self.num_regs_mma)
792
  is_tma_warp = Boolean(
793
  (not self.pingpong and warp_idx == 0)
794
  or (self.pingpong and (warp_idx == 0 or warp_idx == 4))
795
  )
796
+ # Partition global tensor for TiledMMA_A/B/C
 
 
 
 
 
 
 
797
  tidx, _, _ = cute.arch.thread_idx()
798
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
799
  if const_expr(self.pingpong):
800
  tidx = tidx % self.num_threads_per_warp_group
801
  warp_group_thread_layout = cute.make_layout(
802
+ self.mma_warp_groups if const_expr(not self.pingpong) else 1,
803
  stride=self.num_threads_per_warp_group,
804
  )
805
  thr_mma = tiled_mma.get_slice(
806
  warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)
807
  )
808
 
809
+ # Make fragments
810
+ acc, tCrA, tCrB = quack_sm90_utils.partition_fragment_ABC(
811
+ thr_mma, self.cta_tile_shape_mnk, sA, sB
 
 
 
 
 
812
  )
 
813
  acc_slow = None
814
  if const_expr(self.fp8_slow_accum):
815
+ acc_slow = cute.make_rmem_tensor(acc.shape, self.acc_dtype)
816
+ mma_fn = partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc, tCrA, tCrB)
817
 
818
  if const_expr(self.pingpong):
819
  if warp_group_idx == 0:
 
821
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma")
822
  self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi")
823
 
824
+ k_tile_cnt_static = cute.ceil_div(
825
+ cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2]
826
+ )
827
  c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
828
 
829
  ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
 
835
  pipeline.PipelineUserType.Producer, self.epi_c_stage
836
  )
837
  tile_scheduler = TileSchedulerCls()
838
+ work_tile = tile_scheduler.initial_work_tile_info()
839
  if const_expr(self.pingpong):
 
 
840
  if warp_idx >= 4:
841
  # Advance 2nd Math WG pipeline states to the end of 1st Math WG
842
  epi_read_state.advance_iters(c_tile_cnt)
 
847
  len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
848
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
849
  ab_read_state.advance_iters(k_tile_cnt)
850
+ # TODO: do we need to check if work_tile is valid?
851
  tile_scheduler.advance_to_next_work()
852
+ work_tile = tile_scheduler.get_current_work()
 
 
 
 
 
 
 
 
853
  while work_tile.is_valid_tile:
854
  tile_coord_mnkl = work_tile.tile_idx
855
  batch_idx = tile_coord_mnkl[3]
 
 
 
 
 
 
 
 
 
 
856
  len_k = varlen_manager.len_k(batch_idx)
857
  k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
858
+ if const_expr(self.pingpong):
859
+ self.pingpong_barrier_sync(warp_group_idx, stage="mma")
860
+ tctx.b("mma")
861
+ ab_read_state = self.mma(
862
+ ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx
 
 
 
 
 
863
  )
864
  if const_expr(varlen_k):
865
  if k_tile_cnt == 0:
866
  acc.fill(0.0)
867
+ tctx.e("mma")
868
 
869
+ # EPILOGUE
 
 
870
  if const_expr(self.pingpong):
871
  self.pingpong_barrier_sync(warp_group_idx, "epi")
872
+ tctx.b("epilogue")
 
 
 
 
 
 
873
 
874
  copy_D = None
875
  if const_expr(has_D):
 
880
  self.epi_tile,
881
  sD,
882
  tile_coord_mnkl,
 
883
  )
884
  copy_C = None
885
  if const_expr(has_C):
 
897
  tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
898
  tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
899
  )
900
+ # (R2S, R2S_M, R2S_N, num_epi)
901
+ tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s)
902
  load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
903
  if const_expr(has_C):
904
  tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
 
907
  else:
908
  tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
909
 
 
 
 
 
 
910
  self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
911
 
912
  epi_read_state, epi_producer_state = self.epilogue(
913
  epilogue_params,
914
  epi_smem_tensors,
 
915
  epi_pipeline,
916
  epi_store_pipeline,
917
  epi_read_state,
 
930
  copy_C,
931
  tile_coord_mnkl,
932
  varlen_manager,
933
+ self.epilogue_barrier,
934
  tile_scheduler,
935
  tidx,
936
  is_tma_warp,
 
943
  if is_tma_warp:
944
  epi_store_pipeline.producer_tail()
945
  self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi")
946
+ tctx.e("epilogue")
947
 
948
  if const_expr(not self.pingpong):
949
  tile_scheduler.advance_to_next_work()
 
968
  work_tile = tile_scheduler.get_current_work()
969
  # End of persistent scheduler loop
970
 
971
+ # PDL: hint next kernel to launch (matches cutlass C++ sm90 consumer)
972
+ if const_expr(self.use_pdl):
973
+ cute.arch.griddepcontrol_launch_dependents()
974
+
975
  # Wait for D store complete
976
  if const_expr(not self.pingpong):
977
  if is_tma_warp:
978
  epi_store_pipeline.producer_tail()
979
 
980
+ tctx.flush()
981
+
982
  @cute.jit
983
  def load_AB(
984
  self,
 
998
  peek_ab_empty_status = Boolean(True)
999
  if 0 < k_tile_cnt:
1000
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
 
1001
  # TMA load
 
1002
  for k_tile in cutlass.range(k_tile_cnt, unroll=1):
1003
  # Wait for A/B buffers to be empty before loading into them
1004
  # Also sets the transaction barrier for the A/B buffers
 
1035
  peek_ab_empty_status = Boolean(True)
1036
  if 0 < k_tile_cnt:
1037
  peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
 
1038
  # TMA load on B and cp.async on A
 
1039
  for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
1040
  prefetch_out = ()
1041
  if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
 
1043
  # Wait for A/B buffers to be empty before loading into them
1044
  # Also sets the transaction barrier for the A/B buffers
1045
  # A tiny bit faster to rotate the warp that does TMA
1046
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps)
 
 
 
 
1047
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1048
  smem_idx = ab_producer_state.index
1049
  # A bit faster to load B first while we calculate the indices for A
 
1063
  prefetch_out = ()
1064
  if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
1065
  prefetch_out = (prefetch_A(k_tile, pred=True),)
1066
+ is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
 
 
1067
  ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
1068
  smem_idx = ab_producer_state.index
1069
  if is_tma_warp:
 
1074
  ab_producer_state.advance()
1075
  return ab_producer_state
1076
 
1077
+ @cute.jit
1078
+ def _make_gather_A_copy(
1079
+ self,
1080
+ mA_mkl: cute.Tensor,
1081
+ sA: cute.Tensor,
1082
+ varlen_manager: VarlenManager,
1083
+ tile_coord_mnkl,
1084
+ batch_idx: Int32,
1085
+ ):
1086
+ """Create copy_A and prefetch_A for gather_A (shared by SM90/SM120 DMA)."""
1087
+ varlen_m = varlen_manager.varlen_m
1088
+ mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
1089
+ if const_expr(varlen_m):
1090
+ gAIdx = cute.local_tile(mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],))
1091
+ mA_mk = mA_mkl
1092
+ else:
1093
+ gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
1094
+ mA_mk = cute.local_tile(
1095
+ mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None)
1096
+ )
1097
+ len_m = varlen_manager.len_m(batch_idx)
1098
+ len_k = varlen_manager.len_k(batch_idx)
1099
+ tiled_copy_A = self._make_gmem_tiled_copy_A(
1100
+ mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
1101
+ )
1102
+ dma_tidx = cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
1103
+ thr_copy_A = tiled_copy_A.get_slice(dma_tidx)
1104
+ copy_A, prefetch_A = None, None
1105
+ if const_expr(varlen_m):
1106
+ copy_A = copy_utils.gather_m_get_copy_fn(
1107
+ thr_copy_A,
1108
+ mA_mk,
1109
+ sA,
1110
+ gAIdx,
1111
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
1112
+ limit_k=len_k,
1113
+ )
1114
+ else:
1115
+ copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn(
1116
+ thr_copy_A,
1117
+ mA_mk,
1118
+ sA,
1119
+ gAIdx,
1120
+ limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0],
1121
+ limit_k=len_k,
1122
+ )
1123
+ return copy_A, prefetch_A
1124
+
1125
  @cute.jit
1126
  def mma(
1127
  self,
1128
  ab_pipeline: cutlass.pipeline.PipelineAsync,
1129
  ab_read_state: cutlass.pipeline.PipelineState,
1130
+ mma_fn: Callable,
 
 
1131
  acc: cute.Tensor,
1132
  acc_slow: Optional[cute.Tensor],
1133
  k_tile_cnt: Int32,
1134
  warp_group_idx: Int32,
1135
+ ) -> cutlass.pipeline.PipelineState:
1136
+ # Prologue MMAs
 
 
1137
  k_pipe_mmas = 1
1138
  ab_release_state = ab_read_state.clone()
1139
  num_prologue_mma = min(k_pipe_mmas, k_tile_cnt)
 
 
1140
  peek_ab_full_status = Boolean(True)
1141
  if 0 < k_tile_cnt:
1142
  peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
1143
+ zero_init = Boolean(True)
 
1144
  for k_tile in cutlass.range(num_prologue_mma):
1145
  # Wait for A/B buffer to be ready
1146
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
1147
+ mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
1148
+ zero_init = Boolean(False)
 
 
 
 
1149
  ab_read_state.advance()
1150
  peek_ab_full_status = Boolean(True)
1151
  if k_tile + 1 < k_tile_cnt:
 
1156
  warpgroup.wait_group(0)
1157
  acc_slow.store(acc.load())
1158
 
1159
+ # MAINLOOP
 
 
1160
  for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1):
1161
  # Wait for TMA copies to complete
1162
  ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
 
 
1163
  if const_expr(self.fp8_slow_accum):
1164
+ zero_init = Boolean(True)
1165
+ mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
1166
+ zero_init = Boolean(False)
 
 
 
1167
  # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
1168
  if const_expr(not self.fp8_slow_accum):
1169
  warpgroup.wait_group(k_pipe_mmas)
 
1187
  ab_release_state.advance()
1188
  if const_expr(self.fp8_slow_accum):
1189
  acc.store(acc_slow.load())
1190
+ return ab_read_state
 
 
1191
 
1192
  @cute.jit
1193
  def epilogue(
1194
  self,
1195
  params: EpilogueParams,
1196
  epi_smem_tensors: Tuple[cute.Tensor, ...],
 
1197
  epi_pipeline: cutlass.pipeline.PipelineAsync,
1198
  epi_store_pipeline: cutlass.pipeline.PipelineAsync,
1199
  epi_read_state: cutlass.pipeline.PipelineState,
 
1219
  ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
1220
  has_C = const_expr(tRS_rC is not None)
1221
  has_D = const_expr(copy_D is not None)
1222
+
1223
+ # Setup postact output (returns None for default epilogue, context tuple for Act)
1224
+ postact_ctx = self.epi_setup_postact(
1225
+ params,
1226
+ epi_smem_tensors,
1227
+ tiled_copy_r2s,
1228
+ tiled_copy_t2r,
1229
+ tile_coord_mnkl,
1230
+ varlen_manager,
1231
+ tidx,
1232
+ )
1233
+
1234
  epi_tile_shape = cute.zipped_divide(
1235
  cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
1236
  ).shape[1]
 
1260
  epi_pipeline.producer_commit(epi_producer_state)
1261
  epi_producer_state.advance()
1262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1263
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
1264
  # The global memory coordinate for the current epi tile
1265
  gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
 
1270
  epi_pipeline.consumer_wait(epi_read_state)
1271
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
1272
  # Fence to make sure shared memory read is visible to TMA load
1273
+ cute.arch.fence_view_async_shared()
 
 
1274
  cute.arch.sync_warp()
1275
  with cute.arch.elect_one():
1276
  epi_pipeline.consumer_release(epi_read_state)
 
1282
  copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
1283
  epi_pipeline.producer_commit(epi_producer_state)
1284
  epi_producer_state.advance()
1285
+ tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
1286
+ # Convert and store postact if this epilogue produces one
1287
+ if const_expr(postact_ctx is not None):
1288
+ tRS_rPostAct_out = self.epi_convert_postact(
1289
+ tRS_rPostAct,
1290
+ epi_loop_tensors["sr_seed"],
1291
+ tidx,
1292
+ tile_coord_mnkl,
1293
+ num_prev_subtiles,
1294
+ epi_idx,
1295
+ )
1296
+ if is_tma_warp:
1297
+ epi_store_pipeline.producer_acquire()
1298
+ epilogue_barrier.arrive_and_wait()
1299
  # Copy from D registers to shared memory
1300
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
1301
  if const_expr(has_D):
1302
+ if const_expr(
1303
+ self.rounding_mode == RoundingMode.RS
1304
+ and self.acc_dtype == cutlass.Float32
1305
+ and self.d_dtype == cutlass.BFloat16
1306
+ ):
1307
+ seed = epi_loop_tensors["sr_seed"] + (
1308
+ tile_coord_mnkl[0] * 65537
1309
+ + tile_coord_mnkl[1] * 257
1310
+ + tile_coord_mnkl[3] * 17
1311
+ + (num_prev_subtiles + epi_idx) * 7
1312
+ )
1313
+ copy_utils.sr_cvt_copy(
1314
+ tiled_copy_r2s,
1315
+ tRS_rD,
1316
+ tRS_sD[None, None, None, epi_buffer],
1317
+ seed,
1318
+ tidx,
1319
+ )
1320
+ else:
1321
+ copy_utils.cvt_copy(
1322
+ tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
1323
+ )
1324
+ # Copy postact from registers to shared memory
1325
+ if const_expr(postact_ctx is not None):
1326
+ tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx
1327
+ cute.copy(
1328
+ tiled_copy_postact_r2s,
1329
+ tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
1330
+ tRS_sPostAct[None, None, None, epi_buffer],
1331
+ )
1332
+ # Fence and barrier to make sure shared memory store is visible to TMA store
1333
+ cute.arch.fence_view_async_shared()
1334
+ epilogue_barrier.arrive_and_wait()
1335
+ # Copy from shared memory to global memory
1336
+ if is_tma_warp:
1337
+ if const_expr(has_D):
1338
+ copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
1339
+ if const_expr(postact_ctx is not None):
1340
+ copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
1341
+ epi_store_pipeline.producer_commit()
1342
 
1343
  self.epi_end(
1344
  params,
 
1364
  mD: Optional[cute.Tensor],
1365
  scheduler_args,
1366
  varlen_args,
1367
+ epilogue_args,
1368
  ):
1369
  """Create scheduler arguments. Override in subclasses for custom schedulers."""
1370
+ if const_expr(not self.is_persistent):
1371
+ persistence_mode = PersistenceMode.NONE
1372
+ else:
1373
+ if const_expr(self.arch >= 100 and self.use_clc_persistence):
1374
+ persistence_mode = PersistenceMode.CLC
1375
+ elif const_expr(scheduler_args.tile_count_semaphore is not None):
1376
+ persistence_mode = PersistenceMode.DYNAMIC
1377
+ else:
1378
+ persistence_mode = PersistenceMode.STATIC
1379
  if const_expr(varlen_args.mCuSeqlensM is None):
1380
  num_problems = (
1381
  mD.shape[2]
 
1387
  )
1388
  )
1389
  problem_shape_ntile_mnl = (
1390
+ cute.ceil_div(cute.size(mA, mode=[0]), self.cta_tile_shape_mnk[0]),
1391
+ cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]),
1392
  num_problems,
1393
  )
1394
  tile_sched_args = TileSchedulerArguments(
 
1398
  cluster_shape_mnk=self.cluster_shape_mnk,
1399
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
1400
  batch_idx_permute=scheduler_args.batch_idx_permute,
1401
+ persistence_mode=persistence_mode,
1402
  )
1403
  else:
1404
+ assert (mD is not None) or (epilogue_args.mPostAct is not None) or (not self.gather_A)
1405
  problem_shape_ntile_mnl = (
1406
  None,
1407
+ cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]),
1408
  varlen_args.mCuSeqlensM.shape[0] - 1,
1409
  )
1410
  tile_sched_args = VarlenMTileSchedulerArguments(
 
1416
  tile_shape_mn=self.cta_tile_shape_mnk[:2],
1417
  cluster_shape_mnk=self.cluster_shape_mnk,
1418
  tile_count_semaphore=scheduler_args.tile_count_semaphore,
1419
+ persistence_mode=persistence_mode,
1420
  )
1421
  return tile_sched_args
1422
 
1423
+ def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s):
1424
+ """Retile accumulator for epilogue subtile access. SM90 uses flat_divide."""
1425
+ return cute.flat_divide(acc, tRS_rD.layout)
1426
+
1427
  @cute.jit
1428
  def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
1429
+ cute.autovec_copy(tRS_rAcc[None, None, None, epi_idx], tRS_rD)
 
1430
 
1431
  @cute.jit
1432
  def epi_begin(
 
1492
  """Subclasses can override this"""
1493
  return []
1494
 
 
 
 
 
 
 
 
 
 
 
 
 
1495
  @staticmethod
1496
  def epi_smem_bytes_per_stage(
1497
  args: Optional[EpilogueArguments],
 
1555
  tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
1556
  sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
1557
  tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
1558
+ tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype)
1559
  return tiled_copy_r2s, tRS_rD, tRS_sD
1560
 
1561
  def epilog_smem_load_and_partition(
 
1572
  tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom)
1573
  thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
1574
  tSR_sC = thr_copy_s2r.partition_S(sC)
1575
+ tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, dtype)
1576
  tSR_rC = thr_copy_s2r.retile(tRS_rC)
1577
  return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC
1578
 
 
1584
  epi_tile: cute.Tile,
1585
  sD: cute.Tensor,
1586
  tile_coord_mnkl: cute.Coord,
 
1587
  ) -> Tuple[cute.Tensor, cute.Tensor]:
1588
  # (bM, bN)
1589
  gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
 
1600
  cta_layout=cute.make_layout(1),
1601
  src_tensor=src_tensor,
1602
  dst_tensor=dst_tensor,
 
1603
  )
1604
 
1605
  def make_ab_pipeline(
 
1625
  consumer_group=ab_pipeline_consumer_group,
1626
  tx_count=self.num_tma_load_bytes,
1627
  cta_layout_vmnk=cluster_layout_vmnk,
1628
+ defer_sync=True,
1629
  )
1630
 
1631
  def make_epi_pipeline(
 
1645
  producer_group=epi_pipeline_producer_group,
1646
  consumer_group=epi_pipeline_consumer_group,
1647
  tx_count=tma_copy_c_bytes,
1648
+ defer_sync=True,
1649
  )
1650
 
1651
  def make_epi_store_pipeline(self):
 
1662
  # Threads/warps participating in this pipeline
1663
  sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
1664
  cluster_size = cute.size(cluster_layout_mnk)
1665
+ # Each warp will contribute 1 to the arrive count
1666
  # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier
1667
  # at each round. If pingpong and not varlen_k, then only 4 mma warp will participate.
1668
  consumer_arrive_cnt = (
1669
  (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4
1670
  + self.num_ab_load_warps
1671
+ ) * cluster_size
1672
  sched_pipeline_consumer_group = pipeline.CooperativeGroup(
1673
  pipeline.Agent.Thread, consumer_arrive_cnt
1674
  )
 
1679
  consumer_group=sched_pipeline_consumer_group,
1680
  # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster.
1681
  consumer_mask=None if const_expr(cluster_size == 1) else 0,
1682
+ defer_sync=True,
1683
  )
1684
 
1685
  @classmethod
 
1694
  epilogue_args: EpilogueArguments,
1695
  smem_capacity: int,
1696
  occupancy: int,
 
1697
  ) -> Tuple[int, int]:
1698
  """Computes the number of stages for A/B/C operands based on heuristics.
1699
 
 
1714
  """
1715
 
1716
  epi_stage = 4 if epi_tile[1] <= 16 else 2
1717
+ d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
1718
+ epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
1719
+ epilogue_args, cta_tile_shape_mnk, epi_tile
1720
+ )
1721
+ epi_bytes = epi_bytes_per_stage * epi_stage
 
 
 
 
 
1722
  epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2)
1723
  if c_dtype is not None:
1724
  epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage
 
1736
  # Refine epilogue stages:
1737
  # Calculate remaining smem after allocating for A/B stages and reserved bytes
1738
  # Add remaining unused smem to epilogue
1739
+ if epi_bytes_per_stage > 0:
1740
  epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
1741
  return ab_stage, epi_stage, epi_c_stage
1742
 
 
2001
  :rtype: bool
2002
  """
2003
  is_valid = True
2004
+ if a_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
 
 
 
 
 
2005
  is_valid = False
2006
  # tested b_dtype
2007
+ if b_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
 
 
 
 
 
2008
  is_valid = False
2009
  if acc_dtype not in {Float32, Float16}:
2010
  is_valid = False
build/torch-cuda/quack/gemm_sq_reduce.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ # GEMM with column vector reduction of squared output and optional rowvec scaling:
3
+ # D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
4
+
5
+ from typing import NamedTuple, Optional
6
+
7
+ from torch import Tensor
8
+
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+ from cutlass import Float32, const_expr
12
+
13
+ from .cute_dsl_utils import (
14
+ mlir_namedtuple,
15
+ torch2cute_dtype_map,
16
+ get_device_capacity,
17
+ get_max_active_clusters,
18
+ )
19
+ from .epi_ops import ColVecReduce, colvec_reduce_accumulate, vec_multiply
20
+ from .gemm_sm90 import GemmSm90
21
+ from .gemm_sm100 import GemmSm100
22
+ from .gemm_sm120 import GemmSm120
23
+ from .gemm_default_epi import GemmDefaultEpiMixin
24
+ from .rounding import RoundingMode
25
+ from .compile_utils import make_fake_tensor as fake_tensor
26
+ from .cache_utils import jit_cache
27
+ from .gemm_tvm_ffi_utils import (
28
+ get_majors,
29
+ get_dtypes,
30
+ perm3d,
31
+ make_scheduler_args,
32
+ make_varlen_args,
33
+ make_fake_scheduler_args,
34
+ make_fake_varlen_args,
35
+ make_fake_gemm_tensors,
36
+ compile_gemm_kernel,
37
+ )
38
+ from . import utils as utils
39
+
40
+
41
+ class GemmSqReduceMixin(GemmDefaultEpiMixin):
42
+ """GEMM + sq_reduce + optional rowvec scaling.
43
+
44
+ D_raw = A @ B (+ C), reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
45
+ The sq_sum is computed BEFORE the rowvec scaling.
46
+ """
47
+
48
+ _epi_ops = (*GemmDefaultEpiMixin._epi_ops, ColVecReduce("mColVecReduce"))
49
+
50
+ @mlir_namedtuple
51
+ class EpilogueArguments(NamedTuple):
52
+ alpha: Optional[Float32 | cute.Tensor] = None
53
+ beta: Optional[Float32 | cute.Tensor] = None
54
+ mRowVecBroadcast: Optional[cute.Tensor] = None
55
+ mColVecBroadcast: Optional[cute.Tensor] = None
56
+ mColVecReduce: Optional[cute.Tensor] = None
57
+ add_to_output: cutlass.Constexpr[bool] = False
58
+ rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
59
+ sr_seed: None = None
60
+
61
+ # EpilogueParams auto-generated from _epi_ops
62
+
63
+ def epi_to_underlying_arguments(self, args, *, loc=None, ip=None):
64
+ self.rounding_mode = args.rounding_mode
65
+ d = self._epi_ops_to_params_dict(args)
66
+ return self.EpilogueParams(**d)
67
+
68
+ @cute.jit
69
+ def epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None):
70
+ tDrColVecReduce = epi_loop_tensors["mColVecReduce"]
71
+ tDrRowVec = epi_loop_tensors["mRowVecBroadcast"]
72
+ # Load accumulator, apply alpha/beta/C (skip rowvec/colvec — we handle rowvec below)
73
+ rD = tRS_rD.load()
74
+ if const_expr(hasattr(params, "alpha") and params.alpha is not None):
75
+ alpha = utils.load_scalar_or_pointer(params.alpha)
76
+ rD *= alpha
77
+ if const_expr(tRS_rC is not None):
78
+ if const_expr(not hasattr(params, "beta") or params.beta is None):
79
+ rD += tRS_rC.load().to(tRS_rD.element_type)
80
+ else:
81
+ beta = utils.load_scalar_or_pointer(params.beta)
82
+ rD += beta * tRS_rC.load().to(tRS_rD.element_type)
83
+ tRS_rD.store(rD)
84
+ # Accumulate sq_sum BEFORE rowvec scaling: reduce[m] += sum_n(D[m,n]^2)
85
+ colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rD, rScale=tRS_rD)
86
+ # Multiply by rowvec (norm_weight) AFTER sq_sum
87
+ vec_multiply(self, tRS_rD, None, tDrRowVec)
88
+ return None
89
+
90
+
91
+ class GemmSqReduceSm90(GemmSqReduceMixin, GemmSm90):
92
+ pass
93
+
94
+
95
+ class GemmSqReduceSm100(GemmSqReduceMixin, GemmSm100):
96
+ pass
97
+
98
+
99
+ class GemmSqReduceSm120(GemmSqReduceMixin, GemmSm120):
100
+ pass
101
+
102
+
103
+ @jit_cache
104
+ def _compile_gemm_sq_reduce(
105
+ a_dtype,
106
+ b_dtype,
107
+ d_dtype,
108
+ c_dtype,
109
+ a_major,
110
+ b_major,
111
+ d_major,
112
+ c_major,
113
+ tile_shape_mn,
114
+ cluster_shape_mnk,
115
+ pingpong,
116
+ persistent,
117
+ is_dynamic_persistent,
118
+ colvec_reduce_dtype,
119
+ colvec_reduce_ndim,
120
+ rowvec_dtype,
121
+ device_capacity,
122
+ ):
123
+ sm_to_cls = {
124
+ 9: GemmSqReduceSm90,
125
+ 10: GemmSqReduceSm100,
126
+ 11: GemmSqReduceSm100,
127
+ 12: GemmSqReduceSm120,
128
+ }
129
+ GemmCls = sm_to_cls[device_capacity[0]]
130
+ mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
131
+ a_dtype,
132
+ b_dtype,
133
+ d_dtype,
134
+ c_dtype,
135
+ a_major,
136
+ b_major,
137
+ d_major,
138
+ c_major,
139
+ )
140
+ n_tiles = cute.sym_int()
141
+ if colvec_reduce_ndim == 3:
142
+ mColVecReduce = fake_tensor(
143
+ colvec_reduce_dtype,
144
+ (l, m, n_tiles),
145
+ leading_dim=2,
146
+ divisibility=1,
147
+ )
148
+ else:
149
+ mColVecReduce = fake_tensor(
150
+ colvec_reduce_dtype,
151
+ (m, n_tiles),
152
+ leading_dim=1,
153
+ divisibility=1,
154
+ )
155
+ mRowVec = fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4)
156
+ epi_args = GemmCls.EpilogueArguments(
157
+ mRowVecBroadcast=mRowVec,
158
+ mColVecReduce=mColVecReduce,
159
+ )
160
+ scheduler_args = make_fake_scheduler_args(
161
+ (is_dynamic_persistent and device_capacity[0] == 9), False, l
162
+ )
163
+ varlen_args = make_fake_varlen_args(False, False, False, None)
164
+ return compile_gemm_kernel(
165
+ GemmCls,
166
+ a_dtype,
167
+ tile_shape_mn,
168
+ cluster_shape_mnk,
169
+ pingpong,
170
+ persistent,
171
+ False,
172
+ is_dynamic_persistent,
173
+ device_capacity,
174
+ mA,
175
+ mB,
176
+ mD,
177
+ mC,
178
+ epi_args,
179
+ scheduler_args,
180
+ varlen_args,
181
+ )
182
+
183
+
184
+ def gemm_sq_reduce(
185
+ A: Tensor, # (l, m, k)
186
+ B: Tensor, # (l, n, k)
187
+ D: Tensor, # (l, m, n)
188
+ C: Optional[Tensor], # (l, m, n)
189
+ colvec_reduce: Tensor, # (l, m, ceildiv(n, tile_n))
190
+ tile_count_semaphore: Optional[Tensor], # (1,)
191
+ tile_M: int,
192
+ tile_N: int,
193
+ cluster_M: int,
194
+ cluster_N: int,
195
+ pingpong: bool = False,
196
+ persistent: bool = True,
197
+ is_dynamic_persistent: bool = False,
198
+ max_swizzle_size: int = 8,
199
+ rowvec: Optional[Tensor] = None, # (l, n) — norm_weight
200
+ ) -> None:
201
+ """GEMM + sq_reduce + optional rowvec scaling.
202
+
203
+ D_raw = A @ B (+ C), colvec_reduce[m] = sum_n(D_raw[m,n]^2), D_out = D_raw * rowvec.
204
+ """
205
+ device_capacity = get_device_capacity(A.device)
206
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
207
+ if device_capacity[0] == 12:
208
+ raise NotImplementedError("SM120 GEMM sq reduce epilogue is not yet supported")
209
+
210
+ A_p, B_p, D_p, C_p = perm3d(A, B, D, C)
211
+ a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
212
+ a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
213
+
214
+ if is_dynamic_persistent and device_capacity[0] == 9:
215
+ assert tile_count_semaphore is not None, (
216
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
217
+ )
218
+
219
+ compiled_fn = _compile_gemm_sq_reduce(
220
+ a_dtype,
221
+ b_dtype,
222
+ d_dtype,
223
+ c_dtype,
224
+ a_major,
225
+ b_major,
226
+ d_major,
227
+ c_major,
228
+ (tile_M, tile_N),
229
+ (cluster_M, cluster_N, 1),
230
+ pingpong,
231
+ persistent,
232
+ is_dynamic_persistent,
233
+ torch2cute_dtype_map[colvec_reduce.dtype],
234
+ colvec_reduce.ndim,
235
+ torch2cute_dtype_map[rowvec.dtype] if rowvec is not None else None,
236
+ device_capacity,
237
+ )
238
+
239
+ from .cache_utils import COMPILE_ONLY
240
+
241
+ if COMPILE_ONLY:
242
+ return
243
+
244
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
245
+ epi_args = GemmSqReduceMixin.EpilogueArguments(
246
+ mRowVecBroadcast=rowvec,
247
+ mColVecReduce=colvec_reduce,
248
+ add_to_output=None, # Constexpr, pass None at runtime
249
+ rounding_mode=None, # Constexpr, pass None at runtime
250
+ )
251
+ scheduler_args = make_scheduler_args(
252
+ max_active_clusters, max_swizzle_size, tile_count_semaphore
253
+ )
254
+ varlen_args = make_varlen_args(None, None, None)
255
+
256
+ if device_capacity[0] in [10, 11]:
257
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
258
+ else:
259
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
build/torch-cuda/quack/gemm_symmetric.py CHANGED
@@ -1,25 +1,36 @@
1
  from typing import Tuple, Optional, Callable
2
- from functools import partial
3
  from torch import Tensor
4
- from .gemm_act import GemmActMixin, act_fn_map, gemm_act
 
 
 
 
 
 
 
 
 
5
  from .gemm_sm90 import GemmSm90
6
  from .gemm_sm100 import GemmSm100
 
 
 
 
 
 
 
 
 
 
 
7
  from .tile_scheduler import TriangularTileScheduler
8
- from .gemm_wrapper_utils import GemmWrapperBase
9
- from .cute_dsl_utils import get_device_capacity, get_max_active_clusters
10
  from .varlen_utils import VarlenManager
11
  from . import copy_utils as copy_utils
12
- import cutlass
13
- import cutlass.cute as cute
14
- import cutlass.torch as cutlass_torch
15
- from cutlass.cute.runtime import make_ptr
16
- from cutlass import Int32, Float32, Boolean, const_expr
17
- import cutlass.utils.hopper_helpers as sm90_utils_og
18
- import cutlass.utils.blackwell_helpers as sm100_utils
19
- from cutlass.cutlass_dsl import if_generate
20
 
21
 
22
- class GemmSymmetricMixin(GemmActMixin, GemmSm90):
23
  def get_scheduler_class(self, varlen_m: bool = False):
24
  return TriangularTileScheduler
25
 
@@ -28,7 +39,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
28
  self,
29
  params: GemmActMixin.EpilogueParams,
30
  epi_smem_tensors: Tuple[cute.Tensor, ...],
31
- tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
32
  epi_pipeline: cutlass.pipeline.PipelineAsync,
33
  epi_store_pipeline: cutlass.pipeline.PipelineAsync,
34
  epi_read_state: cutlass.pipeline.PipelineState,
@@ -55,31 +65,14 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
55
  has_C = const_expr(tRS_rC is not None)
56
  has_D = const_expr(copy_D is not None)
57
 
58
- tma_atom_postact = params.tma_atom_postact
59
- mPostAct_mnl = params.mPostAct_mnl
60
- sRowVec, sColVec, sPostAct = epi_smem_tensors
61
- get_smem_store_op = (
62
- partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
63
- if self.arch == 100
64
- else sm90_utils_og.sm90_get_smem_store_op
65
- )
66
- copy_atom_postact_r2s = get_smem_store_op(
67
- self.postact_layout, self.postact_dtype, self.acc_dtype
68
- )
69
- # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
70
- # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
71
- tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
72
- tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
73
- (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
74
- batch_idx = tile_coord_mnkl[3]
75
- copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
76
- tma_atom_postact,
77
- varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
78
- self.cta_tile_shape_postact_mn,
79
- params.epi_tile_postact,
80
- sPostAct,
81
  tile_coord_mnkl,
82
- tma_desc_ptr=tma_desc_postact_ptr,
 
83
  )
84
 
85
  # We iterate over epi tiles in the N dimension first before the M dimension
@@ -111,30 +104,6 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
111
  epi_pipeline.producer_commit(epi_producer_state)
112
  epi_producer_state.advance()
113
 
114
- def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
115
- pid_m = tile_coord_mnkl[0]
116
- pid_n = tile_coord_mnkl[1]
117
- # Fence and barrier to make sure shared memory store is visible to TMA store
118
- cute.arch.fence_proxy(
119
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
120
- )
121
- epilogue_barrier.arrive_and_wait()
122
- # Copy from shared memory to global memory
123
- if is_tma_warp:
124
- square_tile_m = pid_m // self.cluster_shape_mnk[0]
125
- square_tile_n = pid_n // self.cluster_shape_mnk[1]
126
- if const_expr(has_D):
127
- copy_D(src_idx=src_idx, dst_idx=dst_idx)
128
- if square_tile_m != square_tile_n: # don't write twice to the same tile
129
- copy_postact(src_idx=src_idx, dst_idx=dst_idx)
130
- # Can't use if statement here, epi_store_pipeline object isn't captured somehow
131
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
132
- if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
133
- epilogue_barrier.arrive_and_wait()
134
-
135
- delay_tma_store = True
136
-
137
- src_idx_prev, dst_idx_prev = None, None
138
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
139
  # The global memory coordinate for the current epi tile
140
  gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
@@ -145,9 +114,7 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
145
  epi_pipeline.consumer_wait(epi_read_state)
146
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
147
  # Fence to make sure shared memory read is visible to TMA load
148
- cute.arch.fence_proxy(
149
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
150
- )
151
  cute.arch.sync_warp()
152
  with cute.arch.elect_one():
153
  epi_pipeline.consumer_release(epi_read_state)
@@ -160,30 +127,61 @@ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
160
  epi_pipeline.producer_commit(epi_producer_state)
161
  epi_producer_state.advance()
162
  tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
163
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
164
- if const_expr(delay_tma_store):
165
- if const_expr(epi_idx > 0):
166
- tma_store_fn(
167
- src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
168
- )
169
- src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
 
 
 
 
170
  # Copy from D registers to shared memory
 
171
  if const_expr(has_D):
172
- copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  cute.copy(
174
  tiled_copy_postact_r2s,
175
- tiled_copy_postact_r2s.retile(tRS_rPostAct),
176
  tRS_sPostAct[None, None, None, epi_buffer],
177
  )
178
- if const_expr(not delay_tma_store):
179
- tma_store_fn(
180
- src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
181
- )
182
-
183
- if const_expr(delay_tma_store):
184
- tma_store_fn(
185
- src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
186
- )
 
 
 
 
 
187
 
188
  self.epi_end(
189
  params,
@@ -207,6 +205,97 @@ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
207
  pass
208
 
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  def gemm_symmetric(
211
  A: Tensor, # (l, m, k)
212
  B: Tensor, # (l, m, k)
@@ -219,112 +308,87 @@ def gemm_symmetric(
219
  cluster_N: int,
220
  pingpong: bool = False,
221
  persistent: bool = True,
 
222
  max_swizzle_size: int = 8,
223
  alpha: float | Tensor = 1.0,
224
  beta: float | Tensor = 1.0,
225
  ) -> None:
226
- # Tranpose D so the "activation" is a write to the mirrored tile
227
  PostAct = D.mT
228
 
229
- L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
230
- A, B, D, C, additional_tensors={"PostAct": PostAct}
231
- )
232
- assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
233
- GemmWrapperBase.permute_tensors(tensor_infos)
234
- GemmWrapperBase.extract_dtypes(tensor_infos)
235
- major_configs = {
236
- "A": ("m", "k", "l"),
237
- "B": ("n", "k", "l"),
238
- "D": ("m", "n", "l"),
239
- "C": ("m", "n", "l"),
240
- "PostAct": ("m", "n", "l"),
241
- }
242
- GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
243
 
244
  device_capacity = get_device_capacity(A.device)
245
- assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
246
- GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100
 
 
 
 
247
 
248
- acc_dtype = Float32
249
  tile_shape_mn = (tile_M, tile_N)
250
  cluster_shape_mnk = (cluster_M, cluster_N, 1)
251
- if not GemmCls.is_valid_dtypes(
252
- tensor_infos["A"].dtype,
253
- tensor_infos["B"].dtype,
254
- acc_dtype,
255
- tensor_infos["D"].dtype,
256
- tensor_infos["A"].major,
257
- tensor_infos["B"].major,
258
- ):
259
- raise TypeError("Skipping due to unsupported combination of types and majors")
260
 
261
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
262
- GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)
263
-
264
- def scalar_arg(scalar: float | Tensor):
265
- if isinstance(scalar, float):
266
- return Float32(scalar) if scalar != 1.0 else None
267
- else:
268
- assert isinstance(scalar, Tensor)
269
- return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
270
-
271
- activation = None # Equivalent to identity
272
- act_fn = act_fn_map[activation]
273
- epi_args = GemmCls.EpilogueArguments(
274
- tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
275
- )
276
- scheduler_args = GemmWrapperBase.create_scheduler_args(
277
- max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
278
- )
279
- varlen_args = None
280
-
281
- current_stream = cutlass_torch.current_stream()
282
- compile_key = GemmWrapperBase.get_compile_key(
283
- tensor_infos,
284
- activation,
285
  tile_shape_mn,
286
  cluster_shape_mnk,
287
  pingpong,
288
  persistent,
289
- tile_count_semaphore is not None,
 
 
290
  device_capacity,
291
- max_swizzle_size,
292
- 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
293
- 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
294
- key_tensor_names=("A", "B", "D", "PostAct", "C"),
295
- )
296
- cache = gemm_act.compile_cache
297
- if compile_key not in cache:
298
- if device_capacity[0] == 9:
299
- GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
300
- gemm_obj = GemmCls(
301
- acc_dtype,
302
- tensor_infos["A"].dtype,
303
- tile_shape_mn,
304
- cluster_shape_mnk,
305
- gather_A=False,
306
- )
307
- cache[compile_key] = cute.compile(
308
- gemm_obj,
309
- tensor_infos["A"].cute_tensor,
310
- tensor_infos["B"].cute_tensor,
311
- tensor_infos["D"].cute_tensor,
312
- tensor_infos["C"].cute_tensor,
313
- epi_args,
314
- scheduler_args,
315
- varlen_args,
316
- current_stream,
317
- )
318
- cache[compile_key](
319
- tensor_infos["A"].cute_tensor,
320
- tensor_infos["B"].cute_tensor,
321
- tensor_infos["D"].cute_tensor,
322
- tensor_infos["C"].cute_tensor,
323
- epi_args,
324
- scheduler_args,
325
- varlen_args,
326
- current_stream,
327
  )
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
- gemm_act.compile_cache = {}
 
 
 
 
1
  from typing import Tuple, Optional, Callable
2
+
3
  from torch import Tensor
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ from cutlass import Int32, Float32, Boolean, const_expr
8
+ from cutlass.cute.runtime import make_ptr
9
+
10
+ from .compile_utils import make_fake_tensor as fake_tensor
11
+ from .cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map
12
+ from .activation import act_fn_map
13
+ from .gemm_act import GemmActMixin
14
  from .gemm_sm90 import GemmSm90
15
  from .gemm_sm100 import GemmSm100
16
+ from .gemm_sm120 import GemmSm120
17
+ from .gemm_tvm_ffi_utils import (
18
+ div_for_dtype,
19
+ perm3d,
20
+ get_majors,
21
+ get_dtypes,
22
+ make_scheduler_args,
23
+ make_fake_scheduler_args,
24
+ compile_gemm_kernel,
25
+ )
26
+ from .cache_utils import jit_cache
27
  from .tile_scheduler import TriangularTileScheduler
 
 
28
  from .varlen_utils import VarlenManager
29
  from . import copy_utils as copy_utils
30
+ from .rounding import RoundingMode
 
 
 
 
 
 
 
31
 
32
 
33
+ class GemmSymmetricMixin(GemmActMixin):
34
  def get_scheduler_class(self, varlen_m: bool = False):
35
  return TriangularTileScheduler
36
 
 
39
  self,
40
  params: GemmActMixin.EpilogueParams,
41
  epi_smem_tensors: Tuple[cute.Tensor, ...],
 
42
  epi_pipeline: cutlass.pipeline.PipelineAsync,
43
  epi_store_pipeline: cutlass.pipeline.PipelineAsync,
44
  epi_read_state: cutlass.pipeline.PipelineState,
 
65
  has_C = const_expr(tRS_rC is not None)
66
  has_D = const_expr(copy_D is not None)
67
 
68
+ tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = self.epi_setup_postact(
69
+ params,
70
+ epi_smem_tensors,
71
+ tiled_copy_r2s,
72
+ tiled_copy_t2r,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  tile_coord_mnkl,
74
+ varlen_manager,
75
+ tidx,
76
  )
77
 
78
  # We iterate over epi tiles in the N dimension first before the M dimension
 
104
  epi_pipeline.producer_commit(epi_producer_state)
105
  epi_producer_state.advance()
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  for epi_idx in cutlass.range_constexpr(epi_tile_num):
108
  # The global memory coordinate for the current epi tile
109
  gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
 
114
  epi_pipeline.consumer_wait(epi_read_state)
115
  cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
116
  # Fence to make sure shared memory read is visible to TMA load
117
+ cute.arch.fence_view_async_shared()
 
 
118
  cute.arch.sync_warp()
119
  with cute.arch.elect_one():
120
  epi_pipeline.consumer_release(epi_read_state)
 
127
  epi_pipeline.producer_commit(epi_producer_state)
128
  epi_producer_state.advance()
129
  tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
130
+ tRS_rPostAct_out = self.epi_convert_postact(
131
+ tRS_rPostAct,
132
+ epi_loop_tensors["sr_seed"],
133
+ tidx,
134
+ tile_coord_mnkl,
135
+ num_prev_subtiles,
136
+ epi_idx,
137
+ )
138
+ if is_tma_warp:
139
+ epi_store_pipeline.producer_acquire()
140
+ epilogue_barrier.arrive_and_wait()
141
  # Copy from D registers to shared memory
142
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
143
  if const_expr(has_D):
144
+ if const_expr(
145
+ self.rounding_mode == RoundingMode.RS
146
+ and self.acc_dtype == cutlass.Float32
147
+ and self.d_dtype == cutlass.BFloat16
148
+ ):
149
+ seed = epi_loop_tensors["sr_seed"] + (
150
+ tile_coord_mnkl[0] * 65537
151
+ + tile_coord_mnkl[1] * 257
152
+ + tile_coord_mnkl[3] * 17
153
+ + (num_prev_subtiles + epi_idx) * 7
154
+ )
155
+ copy_utils.sr_cvt_copy(
156
+ tiled_copy_r2s,
157
+ tRS_rD,
158
+ tRS_sD[None, None, None, epi_buffer],
159
+ seed,
160
+ tidx,
161
+ )
162
+ else:
163
+ copy_utils.cvt_copy(
164
+ tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
165
+ )
166
  cute.copy(
167
  tiled_copy_postact_r2s,
168
+ tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
169
  tRS_sPostAct[None, None, None, epi_buffer],
170
  )
171
+ pid_m = tile_coord_mnkl[0]
172
+ pid_n = tile_coord_mnkl[1]
173
+ # Fence and barrier to make sure shared memory store is visible to TMA store
174
+ cute.arch.fence_view_async_shared()
175
+ epilogue_barrier.arrive_and_wait()
176
+ # Copy from shared memory to global memory
177
+ if is_tma_warp:
178
+ square_tile_m = pid_m // self.cluster_shape_mnk[0]
179
+ square_tile_n = pid_n // self.cluster_shape_mnk[1]
180
+ if const_expr(has_D):
181
+ copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
182
+ if square_tile_m != square_tile_n: # don't write twice to the same tile
183
+ copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
184
+ epi_store_pipeline.producer_commit()
185
 
186
  self.epi_end(
187
  params,
 
205
  pass
206
 
207
 
208
+ class GemmSymmetricSm120(GemmSymmetricMixin, GemmSm120):
209
+ pass
210
+
211
+
212
+ @jit_cache
213
+ def _compile_gemm_symmetric(
214
+ a_dtype,
215
+ b_dtype,
216
+ d_dtype,
217
+ c_dtype,
218
+ c_major,
219
+ postact_dtype,
220
+ a_major,
221
+ b_major,
222
+ d_major,
223
+ postact_major,
224
+ tile_shape_mn,
225
+ cluster_shape_mnk,
226
+ pingpong,
227
+ persistent,
228
+ is_dynamic_persistent,
229
+ alpha_mode,
230
+ beta_mode,
231
+ device_capacity,
232
+ ):
233
+ sm_to_cls = {
234
+ 9: GemmSymmetricSm90,
235
+ 10: GemmSymmetricSm100,
236
+ 11: GemmSymmetricSm100,
237
+ 12: GemmSymmetricSm120,
238
+ }
239
+ GemmCls = sm_to_cls[device_capacity[0]]
240
+ # Symmetric GEMM: m == n, so reuse the same sym_int for shape checking
241
+ m, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int()
242
+ a_leading = 1 if a_major == "k" else 0
243
+ b_leading = 1 if b_major == "k" else 0
244
+ d_leading = 1 if d_major == "n" else 0
245
+ c_leading = 1 if c_major == "n" else 0
246
+ div_a, div_b = div_for_dtype(a_dtype), div_for_dtype(b_dtype)
247
+ div_d, div_c = div_for_dtype(d_dtype), div_for_dtype(c_dtype) if c_dtype else 1
248
+ mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a)
249
+ mB = fake_tensor(b_dtype, (m, k, l), leading_dim=b_leading, divisibility=div_b)
250
+ mD = fake_tensor(d_dtype, (m, m, l), leading_dim=d_leading, divisibility=div_d)
251
+ mC = fake_tensor(c_dtype, (m, m, l), leading_dim=c_leading, divisibility=div_c)
252
+ # PostAct = D.mT, so it has the opposite major from D (m↔n swapped)
253
+ div_pa = div_for_dtype(postact_dtype)
254
+ postact_leading = 1 if postact_major == "n" else 0
255
+ mPostAct = fake_tensor(
256
+ postact_dtype, (m, m, l), leading_dim=postact_leading, divisibility=div_pa
257
+ )
258
+
259
+ def fake_scalar(mode):
260
+ if mode == 0:
261
+ return None
262
+ elif mode == 1:
263
+ return Float32(1.0)
264
+ else:
265
+ return make_ptr(Float32, 0, cute.AddressSpace.gmem, assumed_align=4)
266
+
267
+ activation = None # identity
268
+ act_fn = act_fn_map[activation]
269
+ epi_args = GemmCls.EpilogueArguments(
270
+ mPostAct,
271
+ act_fn,
272
+ alpha=fake_scalar(alpha_mode),
273
+ beta=fake_scalar(beta_mode),
274
+ )
275
+ scheduler_args = make_fake_scheduler_args(
276
+ (is_dynamic_persistent and device_capacity[0] == 9), False, l
277
+ )
278
+ varlen_args = None
279
+ return compile_gemm_kernel(
280
+ GemmCls,
281
+ a_dtype,
282
+ tile_shape_mn,
283
+ cluster_shape_mnk,
284
+ pingpong,
285
+ persistent,
286
+ False,
287
+ is_dynamic_persistent,
288
+ device_capacity,
289
+ mA,
290
+ mB,
291
+ mD,
292
+ mC,
293
+ epi_args,
294
+ scheduler_args,
295
+ varlen_args,
296
+ )
297
+
298
+
299
  def gemm_symmetric(
300
  A: Tensor, # (l, m, k)
301
  B: Tensor, # (l, m, k)
 
308
  cluster_N: int,
309
  pingpong: bool = False,
310
  persistent: bool = True,
311
+ is_dynamic_persistent: bool = False,
312
  max_swizzle_size: int = 8,
313
  alpha: float | Tensor = 1.0,
314
  beta: float | Tensor = 1.0,
315
  ) -> None:
316
+ # Transpose D so the "activation" is a write to the mirrored tile
317
  PostAct = D.mT
318
 
319
+ A_p, B_p, D_p, C_p = perm3d(A, B, D, C)
320
+ PostAct_p = PostAct.permute(1, 2, 0) if PostAct.ndim == 3 else PostAct
321
+ a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
322
+ a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
323
+ postact_dtype = torch2cute_dtype_map[PostAct.dtype]
324
+ # PostAct = D.mT has swapped major: if D is n-major, PostAct is m-major
325
+ postact_major = "n" if PostAct_p.stride(1) == 1 else "m"
 
 
 
 
 
 
 
326
 
327
  device_capacity = get_device_capacity(A.device)
328
+ assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
329
+
330
+ if is_dynamic_persistent and device_capacity[0] == 9:
331
+ assert tile_count_semaphore is not None, (
332
+ "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
333
+ )
334
 
 
335
  tile_shape_mn = (tile_M, tile_N)
336
  cluster_shape_mnk = (cluster_M, cluster_N, 1)
337
+ alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0)
338
+ beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0)
 
 
 
 
 
 
 
339
 
340
+ compiled_fn = _compile_gemm_symmetric(
341
+ a_dtype,
342
+ b_dtype,
343
+ d_dtype,
344
+ c_dtype,
345
+ c_major,
346
+ postact_dtype,
347
+ a_major,
348
+ b_major,
349
+ d_major,
350
+ postact_major,
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  tile_shape_mn,
352
  cluster_shape_mnk,
353
  pingpong,
354
  persistent,
355
+ is_dynamic_persistent,
356
+ alpha_mode,
357
+ beta_mode,
358
  device_capacity,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  )
360
 
361
+ from .cache_utils import COMPILE_ONLY
362
+
363
+ if COMPILE_ONLY:
364
+ return
365
+
366
+ max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
367
+
368
+ def scalar_arg(scalar, mode):
369
+ if mode == 0:
370
+ return None
371
+ elif mode == 1:
372
+ return Float32(scalar)
373
+ else:
374
+ return scalar.data_ptr()
375
+
376
+ epi_args = GemmActMixin.EpilogueArguments(
377
+ PostAct_p,
378
+ None, # act_fn is Constexpr, baked in at compile time
379
+ alpha=scalar_arg(alpha, alpha_mode),
380
+ beta=scalar_arg(beta, beta_mode),
381
+ rounding_mode=None,
382
+ sr_seed=None,
383
+ )
384
+ scheduler_args = make_scheduler_args(
385
+ max_active_clusters,
386
+ max_swizzle_size,
387
+ tile_count_semaphore,
388
+ )
389
+ varlen_args = None
390
 
391
+ if device_capacity[0] in [10, 11]:
392
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
393
+ else:
394
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
build/torch-cuda/quack/gemm_tvm_ffi_utils.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ # Shared utilities for TVM-FFI GEMM compilation.
3
+
4
+ from functools import partial
5
+
6
+
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, Int64, Float32
9
+ from cutlass.cute.runtime import make_ptr
10
+
11
+ from .compile_utils import make_fake_tensor as fake_tensor
12
+ from .cute_dsl_utils import torch2cute_dtype_map
13
+ from .tile_scheduler import TileSchedulerOptions
14
+ from .varlen_utils import VarlenArguments
15
+
16
+
17
+ def div_for_dtype(dtype):
18
+ """16-byte alignment: divisibility in elements = 128 // dtype_width_bits."""
19
+ return 128 // dtype.width
20
+
21
+
22
+ def perm3d_single(t, varlen_m=False):
23
+ """Permute a single 3D tensor from (L, *, *) to (*, *, L), skipping for varlen_m or 2D."""
24
+ return t.permute(1, 2, 0) if t is not None and t.ndim == 3 and not varlen_m else t
25
+
26
+
27
+ def perm3d(A, B, D, C, varlen_m=False, varlen_k=False):
28
+ """Permute 3D tensors from (L, *, *) to (*, *, L)."""
29
+
30
+ def _perm(t):
31
+ return t.permute(1, 2, 0) if t is not None and t.ndim == 3 else t
32
+
33
+ if varlen_m:
34
+ return A, _perm(B), D, C
35
+ elif varlen_k:
36
+ return A, B, _perm(D), _perm(C)
37
+ else:
38
+ return _perm(A), _perm(B), _perm(D), _perm(C)
39
+
40
+
41
+ def get_major(t, dim0, dim1):
42
+ return dim1 if t.stride(1) == 1 else dim0
43
+
44
+
45
+ def get_majors(A_p, B_p, D_p, C_p):
46
+ a_major = get_major(A_p, "m", "k")
47
+ b_major = get_major(B_p, "n", "k")
48
+ d_major = get_major(D_p, "m", "n")
49
+ c_major = get_major(C_p, "m", "n") if C_p is not None else None
50
+ return a_major, b_major, d_major, c_major
51
+
52
+
53
+ def get_dtypes(A, B, D, C):
54
+ a_dtype = torch2cute_dtype_map[A.dtype]
55
+ b_dtype = torch2cute_dtype_map[B.dtype]
56
+ d_dtype = torch2cute_dtype_map[D.dtype]
57
+ c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None
58
+ return a_dtype, b_dtype, d_dtype, c_dtype
59
+
60
+
61
+ def make_scheduler_args(
62
+ max_active_clusters, max_swizzle_size, tile_count_semaphore, batch_idx_permute=None
63
+ ):
64
+ return TileSchedulerOptions(
65
+ max_active_clusters=Int32(max_active_clusters),
66
+ raster_order=None,
67
+ max_swizzle_size=max_swizzle_size,
68
+ tile_count_semaphore=(
69
+ tile_count_semaphore.data_ptr() if tile_count_semaphore is not None else None
70
+ ),
71
+ batch_idx_permute=batch_idx_permute,
72
+ )
73
+
74
+
75
+ def make_fake_scheduler_args(has_semaphore, has_batch_idx_permute, l_sym):
76
+ return TileSchedulerOptions(
77
+ max_active_clusters=Int32(1),
78
+ max_swizzle_size=Int32(8),
79
+ tile_count_semaphore=(
80
+ make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) if has_semaphore else None
81
+ ),
82
+ batch_idx_permute=(
83
+ fake_tensor(Int32, (l_sym,), leading_dim=0, divisibility=4)
84
+ if has_batch_idx_permute
85
+ else None
86
+ ),
87
+ )
88
+
89
+
90
+ def make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx):
91
+ if cu_seqlens_m is None and cu_seqlens_k is None:
92
+ return None
93
+ return VarlenArguments(
94
+ mCuSeqlensM=cu_seqlens_m,
95
+ mCuSeqlensK=cu_seqlens_k,
96
+ mAIdx=A_idx,
97
+ )
98
+
99
+
100
+ def make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len):
101
+ if not varlen_m and not varlen_k:
102
+ return None
103
+ num_seqlens = cute.sym_int()
104
+ return VarlenArguments(
105
+ mCuSeqlensM=(
106
+ fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_m else None
107
+ ),
108
+ mCuSeqlensK=(
109
+ fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_k else None
110
+ ),
111
+ mAIdx=(
112
+ fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) if gather_A else None
113
+ ),
114
+ )
115
+
116
+
117
+ def make_fake_gemm_tensors(
118
+ a_dtype,
119
+ b_dtype,
120
+ d_dtype,
121
+ c_dtype,
122
+ a_major,
123
+ b_major,
124
+ d_major,
125
+ c_major,
126
+ varlen_m=False,
127
+ varlen_k=False,
128
+ gather_A=False,
129
+ ):
130
+ """Create fake tensors for mA, mB, mD, mC with shared sym_ints.
131
+ Pass dtype=None to get None for that tensor (e.g. optional C).
132
+ Returns (mA, mB, mD, mC, m, n, k, l).
133
+ When varlen_m, m is total_m (flattened M of D/C). When varlen_k, k is total_k.
134
+ """
135
+ a_leading = 1 if a_major == "k" else 0
136
+ b_leading = 1 if b_major == "k" else 0
137
+ d_leading = 1 if d_major == "n" else 0
138
+ c_leading = 1 if c_major == "n" else 0
139
+ m, n, k, l = cute.sym_int(), cute.sym_int(), cute.sym_int(), cute.sym_int()
140
+ div_a = div_for_dtype(a_dtype)
141
+ div_b = div_for_dtype(b_dtype)
142
+ div_d = div_for_dtype(d_dtype) if d_dtype is not None else 1
143
+ div_c = div_for_dtype(c_dtype) if c_dtype is not None else 1
144
+ if varlen_m:
145
+ # m is total_m in this case: the flattened M dimension of D/C
146
+ m = cute.sym_int()
147
+ a_m = cute.sym_int() if gather_A else m
148
+ mA = fake_tensor(a_dtype, (a_m, k), leading_dim=a_leading, divisibility=div_a)
149
+ mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b)
150
+ mD = fake_tensor(d_dtype, (m, n), leading_dim=d_leading, divisibility=div_d)
151
+ mC = fake_tensor(c_dtype, (m, n), leading_dim=c_leading, divisibility=div_c)
152
+ elif varlen_k:
153
+ # k is total_k in this case: the flattened K dimension of A/B
154
+ k = cute.sym_int()
155
+ a_k = cute.sym_int() if gather_A else k
156
+ mA = fake_tensor(a_dtype, (m, a_k), leading_dim=a_leading, divisibility=div_a)
157
+ mB = fake_tensor(b_dtype, (n, k), leading_dim=b_leading, divisibility=div_b)
158
+ mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d)
159
+ mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c)
160
+ else:
161
+ mA = fake_tensor(a_dtype, (m, k, l), leading_dim=a_leading, divisibility=div_a)
162
+ mB = fake_tensor(b_dtype, (n, k, l), leading_dim=b_leading, divisibility=div_b)
163
+ mD = fake_tensor(d_dtype, (m, n, l), leading_dim=d_leading, divisibility=div_d)
164
+ mC = fake_tensor(c_dtype, (m, n, l), leading_dim=c_leading, divisibility=div_c)
165
+ return mA, mB, mD, mC, m, n, k, l
166
+
167
+
168
+ def compile_gemm_kernel(
169
+ GemmCls,
170
+ a_dtype,
171
+ tile_shape_mn,
172
+ cluster_shape_mnk,
173
+ pingpong,
174
+ persistent,
175
+ gather_A,
176
+ is_dynamic_persistent,
177
+ device_capacity,
178
+ mA,
179
+ mB,
180
+ mD,
181
+ mC,
182
+ epi_args,
183
+ scheduler_args,
184
+ varlen_args,
185
+ post_init=None,
186
+ mSFA=None,
187
+ mSFB=None,
188
+ has_trace_ptr=False,
189
+ use_tma_gather=False,
190
+ concat_layout=None,
191
+ ):
192
+ """Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI."""
193
+ if device_capacity[0] in [9, 12]:
194
+ GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
195
+ elif device_capacity[0] in [10, 11]:
196
+ GemmCls = partial(
197
+ GemmCls,
198
+ use_clc_persistence=is_dynamic_persistent,
199
+ use_tma_gather=use_tma_gather,
200
+ )
201
+ gemm_obj = GemmCls(
202
+ Float32,
203
+ a_dtype,
204
+ tile_shape_mn,
205
+ cluster_shape_mnk,
206
+ gather_A=gather_A,
207
+ concat_layout=concat_layout,
208
+ )
209
+ if post_init:
210
+ post_init(gemm_obj)
211
+ stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
212
+ sf_args = () if device_capacity[0] in (9, 12) else (mSFA, mSFB)
213
+ # Trace pointer: Optional[Int64]. Compile with Int64(0) when tracing is
214
+ # requested, None otherwise. TVM-FFI caches each variant separately.
215
+ trace_ptr = Int64(0) if has_trace_ptr else None
216
+ return cute.compile(
217
+ gemm_obj,
218
+ mA,
219
+ mB,
220
+ mD,
221
+ mC,
222
+ epi_args,
223
+ scheduler_args,
224
+ varlen_args,
225
+ stream,
226
+ *sf_args,
227
+ trace_ptr,
228
+ options="--enable-tvm-ffi",
229
+ )
build/torch-cuda/quack/gemm_wrapper_utils.py DELETED
@@ -1,317 +0,0 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- from typing import Optional, Tuple, Dict, Any
3
- from dataclasses import dataclass
4
-
5
- import torch
6
- from torch import Tensor
7
-
8
- import cutlass.cute as cute
9
- from cutlass import Int32
10
- from cutlass.cute.runtime import from_dlpack, make_ptr
11
-
12
- from .cute_dsl_utils import torch2cute_dtype_map
13
- from .varlen_utils import VarlenArguments
14
- from .tile_scheduler import TileSchedulerOptions
15
-
16
-
17
- @dataclass
18
- class GemmTensorInfo:
19
- tensor: Optional[Tensor]
20
- dtype: Optional[Any] = None
21
- major: Optional[str] = None
22
- cute_tensor: Optional[cute.Tensor] = None
23
-
24
-
25
- class GemmWrapperBase:
26
- @staticmethod
27
- def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
28
- assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
29
- assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
30
-
31
- @staticmethod
32
- def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
33
- assert tensor.shape == expected_shape, (
34
- f"{name} must have shape {expected_shape}, got {tensor.shape}"
35
- )
36
-
37
- @staticmethod
38
- def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
39
- # Tensor is already permuted to (dims[0], dims[1], dims[2])
40
- # stride(1) == 1 means dims[1] is contiguous (innermost)
41
- return dims[1] if tensor.stride(1) == 1 else dims[0]
42
-
43
- @staticmethod
44
- def create_cute_tensor(
45
- tensor: Optional[Tensor],
46
- major: Optional[str],
47
- dims: Tuple[str, str, str],
48
- assumed_align: int = 16,
49
- ) -> Optional[cute.Tensor]:
50
- if tensor is None:
51
- return None
52
- # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1])
53
- # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
54
- leading_dim = 1 if major == dims[1] else 0
55
- return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
56
- leading_dim=leading_dim
57
- )
58
-
59
- @staticmethod
60
- def validate_and_prepare_tensors(
61
- A: Tensor,
62
- B: Tensor,
63
- D: Optional[Tensor] = None,
64
- C: Optional[Tensor] = None,
65
- additional_tensors: Optional[Dict[str, Tensor]] = None,
66
- cu_seqlens_m: Optional[Tensor] = None,
67
- cu_seqlens_k: Optional[Tensor] = None,
68
- A_idx: Optional[Tensor] = None,
69
- ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
70
- assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
71
- "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
72
- )
73
- assert B.dtype == A.dtype, "A and B must have the same dtype"
74
-
75
- # Validate A_idx if provided (for gather_A case)
76
- gather_A = A_idx is not None
77
- if gather_A:
78
- assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
79
- "gather_A requires either varlen_m or varlen_k"
80
- )
81
- assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
82
- assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
83
-
84
- # Determine mode and extract dimensions
85
- if cu_seqlens_m is not None:
86
- # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
87
- assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
88
- assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
89
-
90
- if gather_A:
91
- # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
92
- total_M = A_idx.shape[0]
93
- _, K = A.shape
94
- else:
95
- total_M, K = A.shape
96
-
97
- L, N, K_B = B.shape
98
- assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
99
- assert cu_seqlens_m.shape == (L + 1,), (
100
- f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
101
- )
102
- M = total_M
103
- dc_shape = (total_M, N)
104
- dc_ndim = 2
105
- elif cu_seqlens_k is not None:
106
- # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
107
- assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
108
- assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
109
-
110
- if gather_A:
111
- # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
112
- M, _ = A.shape
113
- total_K = A_idx.shape[0]
114
- else:
115
- M, total_K = A.shape
116
-
117
- N, K_B = B.shape
118
- assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
119
- L = cu_seqlens_k.shape[0] - 1
120
- assert cu_seqlens_k.shape == (L + 1,), (
121
- f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
122
- )
123
- K = total_K
124
- dc_shape = (L, M, N)
125
- dc_ndim = 3
126
- else:
127
- # Normal case - all tensors must be 3D
128
- GemmWrapperBase.validate_tensor(A, "A", 3)
129
- GemmWrapperBase.validate_tensor(B, "B", 3)
130
- L, M, K = A.shape
131
- _, N, K_B = B.shape
132
- assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
133
- GemmWrapperBase.validate_shape(B, (L, N, K), "B")
134
- dc_shape = (L, M, N)
135
- dc_ndim = 3
136
-
137
- # Validate D and C shapes uniformly
138
- for tensor, name in [(D, "D"), (C, "C")]:
139
- if tensor is not None:
140
- assert tensor.dim() == dc_ndim, (
141
- f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
142
- )
143
- assert tensor.shape == dc_shape, (
144
- f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
145
- )
146
-
147
- tensors = {
148
- "A": GemmTensorInfo(A),
149
- "B": GemmTensorInfo(B),
150
- "D": GemmTensorInfo(D),
151
- "C": GemmTensorInfo(C),
152
- }
153
-
154
- if additional_tensors:
155
- for name, tensor in additional_tensors.items():
156
- if tensor is not None:
157
- assert tensor.dim() == dc_ndim, (
158
- f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
159
- )
160
- assert tensor.shape == dc_shape, (
161
- f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
162
- )
163
- tensors[name] = GemmTensorInfo(tensor)
164
-
165
- return L, M, K, N, tensors
166
-
167
- @staticmethod
168
- def permute_tensors(
169
- tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
170
- ) -> None:
171
- # Determine which tensors need permutation
172
- if varlen_m:
173
- # Only B needs permutation (3D tensor)
174
- tensors_to_permute = ["B"]
175
- elif varlen_k:
176
- # Only D and C need permutation (3D tensors)
177
- tensors_to_permute = ["D", "C"]
178
- else:
179
- # All tensors need permutation
180
- tensors_to_permute = None
181
-
182
- # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
183
- for name, info in tensors.items():
184
- if info.tensor is not None and info.tensor.ndim == 3:
185
- if tensors_to_permute is None or name in tensors_to_permute:
186
- info.tensor = info.tensor.permute(1, 2, 0)
187
-
188
- @staticmethod
189
- def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
190
- for name, info in tensors.items():
191
- if info.tensor is not None:
192
- info.dtype = torch2cute_dtype_map[info.tensor.dtype]
193
-
194
- @staticmethod
195
- def determine_major_orders(
196
- tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
197
- ) -> None:
198
- for name, dims in major_configs.items():
199
- if name in tensors and tensors[name].tensor is not None:
200
- tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)
201
-
202
- @staticmethod
203
- def create_cute_tensors(
204
- tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
205
- ) -> None:
206
- for name, info in tensors.items():
207
- if info.tensor is not None and name in major_configs:
208
- info.cute_tensor = GemmWrapperBase.create_cute_tensor(
209
- info.tensor, info.major, major_configs[name]
210
- )
211
-
212
- @staticmethod
213
- def create_scheduler_args(
214
- max_active_clusters: int,
215
- tile_count_semaphore: Optional[Tensor] = None,
216
- batch_idx_permute: Optional[Tensor] = None,
217
- max_swizzle_size: int = 8,
218
- ) -> TileSchedulerOptions:
219
- return TileSchedulerOptions(
220
- Int32(max_active_clusters),
221
- tile_count_semaphore=make_ptr(
222
- Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
223
- )
224
- if tile_count_semaphore is not None
225
- else None,
226
- batch_idx_permute=(
227
- from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
228
- )
229
- if batch_idx_permute is not None
230
- else None,
231
- max_swizzle_size=Int32(max_swizzle_size),
232
- )
233
-
234
- @staticmethod
235
- def create_varlen_args(
236
- cu_seqlens_m: Optional[Tensor],
237
- cu_seqlens_k: Optional[Tensor],
238
- A_idx: Optional[Tensor],
239
- max_active_clusters: int,
240
- cluster_shape_mnk: Tuple[int, int, int],
241
- tensors: Dict[str, GemmTensorInfo],
242
- num_epi_tensormaps: int = 0,
243
- pingpong: bool = False,
244
- ) -> Optional[Any]:
245
- if cu_seqlens_m is None and cu_seqlens_k is None:
246
- return None
247
- # When varlen_m, we assume persistent=True
248
- # Grid size depends on num_active_clusters and cluster size
249
- cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
250
- num_blocks = max_active_clusters * cluster_size
251
- # Calculate number of tensormaps needed
252
- if cu_seqlens_m is not None:
253
- # For varlen_m: need tensormaps for D and epilogue tensors
254
- num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
255
- if tensors["D"].tensor is not None:
256
- num_tensormaps += 1 if not pingpong else 2 # D tensormap
257
- else:
258
- # For varlen_k: need tensormaps for A & B
259
- num_tensormaps = 2 if A_idx is None else 1
260
- # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
261
- tensormap_size = 128 // 8 # 16 int64s
262
- if num_tensormaps > 0:
263
- device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
264
- tensormaps = torch.empty(
265
- (num_blocks, num_tensormaps, tensormap_size),
266
- dtype=torch.int64,
267
- device=device,
268
- )
269
- tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
270
- mode=0, stride_order=(0, 1, 2)
271
- )
272
- else:
273
- tensormaps_cute = None
274
-
275
- return VarlenArguments(
276
- mCuSeqlensM=(
277
- from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
278
- if cu_seqlens_m is not None
279
- else None
280
- ),
281
- mCuSeqlensK=(
282
- from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
283
- if cu_seqlens_k is not None
284
- else None
285
- ),
286
- mTensormaps=tensormaps_cute,
287
- mAIdx=(
288
- from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
289
- if A_idx is not None
290
- else None
291
- ),
292
- )
293
-
294
- @staticmethod
295
- def get_compile_key(
296
- tensors: Dict[str, GemmTensorInfo],
297
- activation: Optional[str],
298
- tile_shape_mn: Tuple[int, int],
299
- cluster_shape_mnk: Tuple[int, int, int],
300
- pingpong: bool,
301
- persistent: bool,
302
- has_semaphore: bool,
303
- *args,
304
- key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
305
- ) -> Tuple:
306
- key_parts = []
307
- for name in key_tensor_names:
308
- if name in tensors:
309
- key_parts.append(tensors[name].dtype)
310
- key_parts.append(activation)
311
- key_parts.extend([tile_shape_mn, cluster_shape_mnk])
312
- for name in key_tensor_names:
313
- if name in tensors:
314
- key_parts.append(tensors[name].major)
315
- key_parts.extend([pingpong, persistent, has_semaphore])
316
- key_parts.extend(args)
317
- return tuple(key_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/quack/layout_utils.py CHANGED
@@ -6,8 +6,6 @@ import cutlass.cute as cute
6
 
7
  from cutlass import Int32, const_expr
8
 
9
- from .utils import prmt
10
-
11
 
12
  def transpose_view(a: cute.Tensor) -> cute.Tensor:
13
  """Transpose the first two dimensions of a tensor on smem."""
@@ -20,6 +18,19 @@ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
20
  return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
24
  shape = (*a.shape[:dim], size, *a.shape[dim:])
25
  stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
@@ -55,8 +66,8 @@ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
55
  lower0 = lower if lane_03 else upper
56
  upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
57
  lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
58
- t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
59
- t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
60
 
61
 
62
  @cute.jit
@@ -154,41 +165,43 @@ def concat_layout(*layouts: cute.Layout) -> cute.Layout:
154
  )
155
 
156
 
157
- def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
158
  """
159
  For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
160
  For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
161
  """
162
  acc_layout_col_major = cute.make_layout(acc_layout.shape)
163
- acc_layout_mn = cute.make_layout(
 
164
  (
165
- (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
166
- (
167
- acc_layout_col_major.shape[0][0],
168
- *acc_layout_col_major.shape[0][2:],
169
- acc_layout_col_major.shape[2],
170
- ), # MMA_N
171
- *acc_layout_col_major.shape[3:],
172
- ),
173
- stride=(
174
- (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
175
- (
176
- acc_layout_col_major.stride[0][0],
177
- *acc_layout_col_major.stride[0][2:],
178
- acc_layout_col_major.stride[2],
179
- ), # MMA_N
180
- *acc_layout_col_major.stride[3:],
181
- ),
182
  )
 
 
 
 
183
  return cute.composition(acc_layout, acc_layout_mn)
184
 
185
 
186
- def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
187
- return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
188
 
189
 
190
- def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
191
- return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
192
 
193
 
194
  @cute.jit
@@ -196,10 +209,12 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
196
  # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
197
  # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
198
  # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
 
199
  # TODO: Sm90 FP8
200
  if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
 
201
  l = cute.logical_divide(
202
- acc_layout, ((None, None, 2), None, None)
203
  ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
204
  rA_mma_view = cute.make_layout(
205
  (
@@ -293,3 +308,77 @@ def mma_partition_A_vec(
293
  sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
294
  tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
295
  return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from cutlass import Int32, const_expr
8
 
 
 
9
 
10
  def transpose_view(a: cute.Tensor) -> cute.Tensor:
11
  """Transpose the first two dimensions of a tensor on smem."""
 
18
  return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
19
 
20
 
21
+ def concat_to_interleave(a: cute.Tensor, dim: int) -> cute.Tensor:
22
+ """Reshape a concat [first_half; second_half] layout to interleaved along `dim`.
23
+
24
+ Splits dimension `dim` (size 2N) into hierarchical (2, N) so that elements
25
+ from the first half and second half alternate: [first_0, second_0, first_1, ...].
26
+ Used to convert gated MLP weight layout from concat [gate; up] to interleaved.
27
+ """
28
+ half = cute.size(a, mode=[dim]) // 2
29
+ shape = (*a.shape[:dim], (2, half), *a.shape[dim + 1 :])
30
+ stride = (*a.stride[:dim], (half * a.stride[dim], a.stride[dim]), *a.stride[dim + 1 :])
31
+ return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
32
+
33
+
34
  def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
35
  shape = (*a.shape[:dim], size, *a.shape[dim:])
36
  stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
 
66
  lower0 = lower if lane_03 else upper
67
  upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
68
  lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
69
+ t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
70
+ t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
71
 
72
 
73
  @cute.jit
 
165
  )
166
 
167
 
168
+ def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
169
  """
170
  For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
171
  For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
172
  """
173
  acc_layout_col_major = cute.make_layout(acc_layout.shape)
174
+ shape = (
175
+ (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
176
  (
177
+ acc_layout_col_major.shape[0][0],
178
+ *acc_layout_col_major.shape[0][2:],
179
+ acc_layout_col_major.shape[2],
180
+ ), # MMA_N
181
+ *acc_layout_col_major.shape[3:],
182
+ )
183
+ stride = (
184
+ (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
185
+ (
186
+ acc_layout_col_major.stride[0][0],
187
+ *acc_layout_col_major.stride[0][2:],
188
+ acc_layout_col_major.stride[2],
189
+ ), # MMA_N
190
+ *acc_layout_col_major.stride[3:],
 
 
 
191
  )
192
+ if const_expr(transpose):
193
+ shape = (shape[1], shape[0], *shape[2:])
194
+ stride = (stride[1], stride[0], *stride[2:])
195
+ acc_layout_mn = cute.make_layout(shape, stride=stride)
196
  return cute.composition(acc_layout, acc_layout_mn)
197
 
198
 
199
+ def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
200
+ return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
201
 
202
 
203
+ def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
204
+ return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
205
 
206
 
207
  @cute.jit
 
209
  # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
210
  # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
211
  # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
212
+ # If N / 8 is odd, we'll convert to ((2, 2, 1), MMA_M, N / 8, MMA_N).
213
  # TODO: Sm90 FP8
214
  if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
215
+ div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1
216
  l = cute.logical_divide(
217
+ acc_layout, ((None, None, div), None, None)
218
  ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
219
  rA_mma_view = cute.make_layout(
220
  (
 
308
  sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
309
  tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
310
  return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
311
+
312
+
313
+ def copy_partition_S_vec(
314
+ sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
315
+ ) -> cute.Tensor:
316
+ assert cute.rank(sVec) == 2
317
+ assert sVec.stride[0] == 1
318
+ stage = sVec.shape[1]
319
+ shape = (
320
+ (sVec.shape[0], expand_shape, stage)
321
+ if const_expr(is_colvec)
322
+ else (expand_shape, sVec.shape[0], stage)
323
+ )
324
+ stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
325
+ sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
326
+ tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr))
327
+ return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
328
+
329
+
330
+ def copy_partition_D_vec(
331
+ sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
332
+ ) -> cute.Tensor:
333
+ assert cute.rank(sVec) == 2
334
+ assert sVec.stride[0] == 1
335
+ stage = sVec.shape[1]
336
+ shape = (
337
+ (sVec.shape[0], expand_shape, stage)
338
+ if const_expr(is_colvec)
339
+ else (expand_shape, sVec.shape[0], stage)
340
+ )
341
+ stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
342
+ sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
343
+ tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr))
344
+ return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
345
+
346
+
347
+ def tile_atom_to_shape_SF_strided(
348
+ shape: cute.Shape,
349
+ sf_vec_size: int,
350
+ sf_strides,
351
+ ) -> cute.Layout:
352
+ """Build an SFA/SFB layout matching `shape` (A or B operand shape) but
353
+ honoring the scale tensor's actual strides instead of hardcoded packed
354
+ ones.
355
+
356
+ Mirrors `cutlass.utils.blockscaled_layout.tile_atom_to_shape_SF(shape,
357
+ sf_vec_size)`, except outer-mode strides come from `sf_strides` (pass
358
+ `mSFA.stride` / `mSFB.stride` directly). The inner 512-B atom
359
+ `((32, 4), (sf_vec_size, 4)) : ((16, 4), (0, 1))` is hardware-fixed.
360
+
361
+ Implementation uses `cute.blocked_product(atom, outer)`; `blocked_product`
362
+ scales the outer layout's strides by `cosize(atom) == 512`, so we divide
363
+ the byte strides by 512 (one tile) before handing them in.
364
+
365
+ Args:
366
+ shape: A/B operand shape. Rank-3 `(m/n, k, l)` or rank-2
367
+ `(total_mn, k)` (varlen_m).
368
+ sf_vec_size: Scale factor vector size (16 or 32).
369
+ sf_strides: Strides of the scale tensor, which has logical shape
370
+ `(L, rmn, rk, 512)` (rank 4). Only `sf_strides[0..2]` are used:
371
+ `sf_strides[1]` as the rmn stride, `sf_strides[2]` as the rk
372
+ stride, and `sf_strides[0]` as the L stride (only for rank-3
373
+ `shape`).
374
+ """
375
+ from cutlass.utils.blockscaled_layout import BlockScaledBasicChunk
376
+
377
+ atom = BlockScaledBasicChunk(sf_vec_size).layout
378
+ rmn = cute.ceil_div(shape[0], 128)
379
+ rk = cute.ceil_div(shape[1], sf_vec_size * 4)
380
+ outer = cute.make_layout((rmn, rk), stride=(sf_strides[1] // 512, sf_strides[2] // 512))
381
+ sf_layout = cute.blocked_product(atom, outer)
382
+ if const_expr(len(shape) == 3):
383
+ sf_layout = cute.append(sf_layout, cute.make_layout(shape[2], stride=sf_strides[0]))
384
+ return sf_layout
build/torch-cuda/quack/linear.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+
10
+ from .gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
11
+ from .gemm_interface import gemm_gated, gemm_dgated
12
+ from .gemm_interface import act_to_pytorch_fn_map, gated_to_pytorch_fn_map
13
+
14
+
15
+ def _ensure_contiguous(t):
16
+ """Ensure last-dim stride is 1. Under torch.compile use unconditional .contiguous()
17
+ (dynamo can't inspect strides on fake tensors); otherwise check first to avoid copies.
18
+ """
19
+ if torch.compiler.is_compiling():
20
+ return t.contiguous()
21
+ return t if t.stride(-1) == 1 else t.contiguous()
22
+
23
+
24
+ def linear_fwd_convert_type(*tensors):
25
+ autocast_dtype = torch.get_autocast_dtype("cuda")
26
+ if torch.is_autocast_enabled():
27
+ tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
28
+ return tensors
29
+
30
+
31
+ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
32
+ needs_input_grad, needs_weight_grad = needs_x_w_grad
33
+ if not needs_input_grad:
34
+ weight, weight_og = None, None
35
+ if not needs_weight_grad:
36
+ x = None
37
+ ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
38
+
39
+
40
+ def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
41
+ if ctx.needs_input_grad[0]:
42
+ assert weight is not None
43
+ return matmul_fn(dout, weight)
44
+ else:
45
+ return None
46
+
47
+
48
+ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
49
+ if ctx.needs_input_grad[1]:
50
+ assert x is not None
51
+ x = x.reshape(-1, x.shape[-1])
52
+ # fuse_grad_accum is not compatible with torch.compile
53
+ if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
54
+ dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
55
+ else:
56
+ # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
57
+ matmul_inplace_fn(dout.T, x, weight_og.grad)
58
+ dweight = weight_og.grad
59
+ weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again
60
+ else:
61
+ dweight = None
62
+ return dweight
63
+
64
+
65
+ def _recompute_act_postact(preact, activation):
66
+ """Recompute postact from preact using the activation function (no GEMM)."""
67
+ return act_to_pytorch_fn_map[activation](preact)
68
+
69
+
70
+ def _recompute_gated_postact(preact, activation):
71
+ """Recompute gated postact from interleaved preact (no GEMM)."""
72
+ return gated_to_pytorch_fn_map[activation](preact[..., ::2], preact[..., 1::2])
73
+
74
+
75
+ # --- Ops bundles: matmul function configurations ---
76
+ # Each ops class is a namespace holding the matmul functions for a specific variant
77
+ # (tuned/untuned, act/gated, etc.). Passed as a non-tensor arg to apply() and stored on ctx.
78
+
79
+
80
+ class _LinearOps:
81
+ matmul_fwd_fn = gemm
82
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
83
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
84
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
85
+
86
+
87
+ class _LinearUntunedOps(_LinearOps):
88
+ matmul_fwd_fn = partial(gemm, tuned=False)
89
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
90
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
91
+
92
+
93
+ class _LinearActOps(_LinearOps):
94
+ matmul_fwd_fn = gemm_act
95
+
96
+
97
+ class _LinearActUntunedOps(_LinearUntunedOps):
98
+ matmul_fwd_fn = partial(gemm_act, tuned=False)
99
+
100
+
101
+ class _LinearGatedOps(_LinearOps):
102
+ matmul_fwd_fn = gemm_gated
103
+
104
+
105
+ class _LinearGatedUntunedOps:
106
+ matmul_fwd_fn = partial(gemm_gated, tuned=False)
107
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
108
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
109
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
110
+
111
+
112
+ class _LinearGatedConcatOps(_LinearGatedOps):
113
+ matmul_fwd_fn = partial(gemm_gated, concat_layout=("B", "bias"))
114
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",))
115
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, concat_layout=("out",))
116
+ matmul_bwd_dw_inplace = partial(
117
+ gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out")
118
+ )
119
+
120
+
121
+ class _LinearGatedConcatUntunedOps(_LinearGatedUntunedOps):
122
+ matmul_fwd_fn = partial(gemm_gated, tuned=False, concat_layout=("B", "bias"))
123
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",))
124
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",))
125
+ matmul_bwd_dw_inplace = partial(
126
+ gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("C", "out")
127
+ )
128
+
129
+
130
+ class _DActLinearOps(_LinearOps):
131
+ matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
132
+ recompute_postact = staticmethod(_recompute_act_postact)
133
+
134
+
135
+ class _DActLinearUntunedOps(_LinearUntunedOps):
136
+ matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
137
+ recompute_postact = staticmethod(_recompute_act_postact)
138
+
139
+
140
+ class _DGatedLinearOps(_LinearOps):
141
+ matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True)
142
+ recompute_postact = staticmethod(_recompute_gated_postact)
143
+
144
+
145
+ class _DGatedLinearUntunedOps(_LinearUntunedOps):
146
+ matmul_bwd_dx = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
147
+ recompute_postact = staticmethod(_recompute_gated_postact)
148
+
149
+
150
+ # --- Autograd Functions (all @staticmethod, torch.compile-compatible) ---
151
+
152
+
153
+ class LinearFunc(torch.autograd.Function):
154
+ @staticmethod
155
+ def forward(ctx, x, weight, bias, fuse_grad_accum, ops):
156
+ """
157
+ x: (..., in_features)
158
+ weight: (out_features, in_features)
159
+ bias: (out_features,) or None
160
+ out: (..., out_features)
161
+ """
162
+ # Convert types while autocast is still enabled, then disable it for the body.
163
+ x, weight = linear_fwd_convert_type(x, weight)
164
+ with torch.amp.autocast("cuda", enabled=False):
165
+ ctx.weight_dtype = weight.dtype
166
+ ctx.fuse_grad_accum = fuse_grad_accum
167
+ ctx.ops = ops
168
+ weight_og = weight
169
+ batch_shape = x.shape[:-1]
170
+ x = x.reshape(-1, x.shape[-1])
171
+ out = ops.matmul_fwd_fn(x, weight.T, bias=bias)
172
+ linear_fwd_postprocess(
173
+ ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
174
+ )
175
+ ctx.bias_dtype = bias.dtype if bias is not None else None
176
+ ctx.compute_dbias = bias is not None and ctx.needs_input_grad[2]
177
+ return out.reshape(*batch_shape, out.shape[-1])
178
+
179
+ @staticmethod
180
+ def backward(ctx, dout):
181
+ """
182
+ dout: (..., out_features)
183
+ """
184
+ with torch.amp.autocast("cuda", enabled=False):
185
+ ops = ctx.ops
186
+ x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum
187
+ batch_shape = dout.shape[:-1]
188
+ dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
189
+ dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
190
+ dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx)
191
+ dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
192
+ dweight = linear_bwd_compute_weight_grad(
193
+ ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
194
+ )
195
+ return dx, dweight, dbias, None, None
196
+
197
+
198
+ def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
199
+ ops = _LinearOps if tuned else _LinearUntunedOps
200
+ return LinearFunc.apply(x, weight, bias, fuse_grad_accum, ops)
201
+
202
+
203
+ class LinearActFunc(torch.autograd.Function):
204
+ @staticmethod
205
+ def forward(ctx, x, weight, activation, bias, store_preact, fuse_grad_accum, ops):
206
+ """
207
+ x: (..., in_features)
208
+ weight: (out_features, in_features)
209
+ bias: (out_features,) or None
210
+ out: (..., out_features)
211
+ Return both out and post-activation, but only out is differentiable.
212
+ """
213
+ x, weight = linear_fwd_convert_type(x, weight)
214
+ with torch.amp.autocast("cuda", enabled=False):
215
+ ctx.weight_dtype = weight.dtype
216
+ ctx.fuse_grad_accum = fuse_grad_accum
217
+ ctx.ops = ops
218
+ weight_og = weight
219
+ batch_shape = x.shape[:-1]
220
+ x = x.reshape(-1, x.shape[-1])
221
+ out, postact = ops.matmul_fwd_fn(
222
+ x, weight.T, bias=bias, activation=activation, store_preact=store_preact
223
+ )
224
+ linear_fwd_postprocess(
225
+ ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
226
+ )
227
+ if out is not None:
228
+ out = out.reshape(*batch_shape, out.shape[-1])
229
+ ctx.bias_dtype = bias.dtype if bias is not None else None
230
+ ctx.compute_dbias = bias is not None and ctx.needs_input_grad[3]
231
+ ctx.mark_non_differentiable(postact)
232
+ ctx.set_materialize_grads(False) # We don't want to materialize grads for postact
233
+ return out, postact.reshape(*batch_shape, postact.shape[-1])
234
+
235
+ @staticmethod
236
+ def backward(ctx, dout, *args):
237
+ with torch.amp.autocast("cuda", enabled=False):
238
+ ops = ctx.ops
239
+ x, weight, weight_og = ctx.saved_tensors
240
+ batch_shape = dout.shape[:-1]
241
+ dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
242
+ dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
243
+ dx = linear_bwd_compute_input_grad(ctx, dout, weight, ops.matmul_bwd_dx)
244
+ dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
245
+ dweight = linear_bwd_compute_weight_grad(
246
+ ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
247
+ )
248
+ return dx, dweight, None, dbias, None, None, None
249
+
250
+
251
+ def linear_act_func(
252
+ x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
253
+ ):
254
+ ops = _LinearActOps if tuned else _LinearActUntunedOps
255
+ return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops)
256
+
257
+
258
+ def linear_gated_func(
259
+ x,
260
+ weight,
261
+ activation,
262
+ bias=None,
263
+ store_preact=True,
264
+ fuse_grad_accum=False,
265
+ tuned=True,
266
+ concat_layout=False,
267
+ ):
268
+ if concat_layout:
269
+ ops = _LinearGatedConcatOps if tuned else _LinearGatedConcatUntunedOps
270
+ else:
271
+ ops = _LinearGatedOps if tuned else _LinearGatedUntunedOps
272
+ return LinearActFunc.apply(x, weight, activation, bias, store_preact, fuse_grad_accum, ops)
273
+
274
+
275
+ class DActLinearFunc(torch.autograd.Function):
276
+ @staticmethod
277
+ def forward(ctx, preact, weight, x, activation, bias, fuse_grad_accum, ops):
278
+ """
279
+ x: (..., in_features)
280
+ weight: (out_features, in_features)
281
+ bias: (out_features,) or None
282
+ out: (..., out_features)
283
+ Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
284
+ """
285
+ x, weight = linear_fwd_convert_type(x, weight)
286
+ with torch.amp.autocast("cuda", enabled=False):
287
+ ctx.weight_dtype = weight.dtype
288
+ ctx.fuse_grad_accum = fuse_grad_accum
289
+ ctx.ops = ops
290
+ weight_og = weight
291
+ batch_shape = x.shape[:-1]
292
+ x = x.reshape(-1, x.shape[-1])
293
+ out = ops.matmul_fwd_fn(x, weight.T, bias=bias)
294
+ # Store preact instead of x, we will recompute x (postact) in backward.
295
+ # dpreact needs gemm_dact(dout, weight, preact) → needs both weight and preact.
296
+ # dweight needs postact: if dpreact is also needed, postact comes from gemm_dact;
297
+ # otherwise we can recompute postact = act(preact) cheaply without weight.
298
+ need_preact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1]
299
+ need_weight = ctx.needs_input_grad[0] # only gemm_dact needs weight
300
+ linear_fwd_postprocess(
301
+ ctx, preact, weight, weight_og, needs_x_w_grad=(need_weight, need_preact)
302
+ )
303
+ ctx.activation = activation
304
+ ctx.bias_dtype = bias.dtype if bias is not None else None
305
+ ctx.compute_dbias = bias is not None and ctx.needs_input_grad[4]
306
+ return out.reshape(*batch_shape, out.shape[-1])
307
+
308
+ @staticmethod
309
+ def backward(ctx, dout):
310
+ """
311
+ dout: (..., out_features)
312
+ """
313
+ with torch.amp.autocast("cuda", enabled=False):
314
+ ops = ctx.ops
315
+ # weight_og is None if not ctx.fuse_grad_accum
316
+ preact, weight, weight_og = ctx.saved_tensors
317
+ batch_shape = dout.shape[:-1]
318
+ dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
319
+ dbias = dout.sum(0, dtype=ctx.bias_dtype) if ctx.compute_dbias else None
320
+ if ctx.needs_input_grad[0]:
321
+ # Need dpreact: gemm_dact(dout, weight, preact) → (dpreact, postact)
322
+ preact = preact.reshape(-1, preact.shape[-1])
323
+ assert weight is not None
324
+ dpreact, x = ops.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
325
+ elif ctx.needs_input_grad[1]:
326
+ # Only need dweight: recompute postact from preact cheaply (no GEMM needed)
327
+ preact = preact.reshape(-1, preact.shape[-1])
328
+ x = ops.recompute_postact(preact, ctx.activation)
329
+ dpreact = None
330
+ else:
331
+ dpreact, x = None, None
332
+ dpreact = (
333
+ dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
334
+ )
335
+ dweight = linear_bwd_compute_weight_grad(
336
+ ctx, dout, x, weight_og, ops.matmul_bwd_dw, ops.matmul_bwd_dw_inplace
337
+ )
338
+ return dpreact, dweight, None, None, dbias, None, None
339
+
340
+
341
+ def act_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True):
342
+ ops = _DActLinearOps if tuned else _DActLinearUntunedOps
343
+ return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops)
344
+
345
+
346
+ def gated_linear_func(preact, weight, x, activation, bias=None, fuse_grad_accum=False, tuned=True):
347
+ ops = _DGatedLinearOps if tuned else _DGatedLinearUntunedOps
348
+ return DActLinearFunc.apply(preact, weight, x, activation, bias, fuse_grad_accum, ops)
349
+
350
+
351
+ class Linear(nn.Linear):
352
+ def __init__(
353
+ self,
354
+ in_features: int,
355
+ out_features: int,
356
+ bias: bool = False,
357
+ device=None,
358
+ dtype=None,
359
+ fuse_grad_accum: bool = False,
360
+ ) -> None:
361
+ super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
362
+ self.fuse_grad_accum = fuse_grad_accum
363
+
364
+ def forward(self, input: Tensor) -> Tensor:
365
+ if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
366
+ return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
367
+ else:
368
+ return F.linear(input, self.weight, self.bias)
build/torch-cuda/quack/linear_cross_entropy.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao
2
+ from typing import Optional, Literal
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.amp import custom_fwd, custom_bwd
9
+
10
+ from .cross_entropy import cross_entropy, cross_entropy_fwd_out
11
+ from .gemm_interface import gemm, gemm_add, gemm_add_inplace
12
+ from .linear import linear_fwd_convert_type
13
+
14
+
15
+ def linear_cross_entropy_func(
16
+ x: Tensor, # (..., d)
17
+ weight: Tensor, # (V, d)
18
+ bias: Optional[Tensor], # (V,) or None
19
+ target: Tensor, # (...,), int or long
20
+ ignore_index: int = -100,
21
+ reduction: Literal["none", "mean", "sum"] = "mean",
22
+ inplace_backward: bool = False,
23
+ ) -> Tensor:
24
+ y = F.linear(x, weight, bias) # (..., V)
25
+ return cross_entropy(
26
+ y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
27
+ )
28
+
29
+
30
+ def linear_cross_entropy_func_ref(
31
+ x: Tensor, # (..., d)
32
+ weight: Tensor, # (V, d)
33
+ bias: Optional[Tensor], # (V,) or None
34
+ target: Tensor, # (...,), int or long
35
+ ignore_index: int = -100,
36
+ reduction: Literal["none", "mean", "sum"] = "mean",
37
+ ) -> Tensor:
38
+ y = F.linear(x, weight, bias) # (..., V)
39
+ return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
40
+
41
+
42
+ def chunked_linear_cross_entropy_fwd(
43
+ x: Tensor, # (B*L, d) where B is batch, L is seqlen
44
+ weight: Tensor, # (V, d) where V is vocab size
45
+ target: Tensor, # (B*L,)
46
+ chunk_size: int = 4096,
47
+ ignore_index: int = -100,
48
+ tuned: bool = True,
49
+ ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
50
+ """
51
+ Chunked forward pass for linear cross entropy.
52
+
53
+ Splits input along batch dimension, computes matmul and cross_entropy_fwd
54
+ for each chunk, stores dx for each chunk, and accumulates dw.
55
+
56
+ Returns:
57
+ loss: (B*L,) loss values
58
+ dx: (B*L, d) gradient w.r.t. input
59
+ dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
60
+ last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
61
+ last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
62
+ """
63
+ B_L, d = x.shape
64
+ V, _ = weight.shape
65
+ device = x.device
66
+ num_chunks = (B_L + chunk_size - 1) // chunk_size
67
+ # Since we use gemm with TMA we require some alignment
68
+ assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
69
+ assert B_L % 8 == 0
70
+ # Pre-allocate outputs
71
+ loss = torch.empty(B_L, device=device, dtype=torch.float32)
72
+ logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
73
+ dx = torch.empty_like(x)
74
+ # Last chunk of dw will be deferred to the backward pass
75
+ dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
76
+ last_dlogits_chunk = None
77
+ last_x_chunk = None
78
+
79
+ # Process in chunks
80
+ for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
81
+ zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
82
+ ):
83
+ chunk_len = x_chunk.shape[0]
84
+ logits_chunk = logits_chunk_preallocated[:chunk_len] # (chunk_len, V)
85
+ torch.mm(x_chunk, weight.mT, out=logits_chunk)
86
+ # Compute cross entropy forward with gradients
87
+ dlogits_chunk = logits_chunk # inplace_backward
88
+ cross_entropy_fwd_out(
89
+ logits_chunk,
90
+ target_chunk,
91
+ None, # target_logit
92
+ loss=loss_chunk,
93
+ lse=None, # we don't need lse here
94
+ dx=dlogits_chunk,
95
+ ignore_index=ignore_index,
96
+ )
97
+ # Compute dx for this chunk: dlogits @ weight
98
+ torch.mm(dlogits_chunk, weight, out=dx_chunk) # (chunk_len, d)
99
+ # Compute dw for all chunks except the last
100
+ if i == num_chunks - 1:
101
+ # Last chunk: save for backward pass
102
+ last_dlogits_chunk = dlogits_chunk
103
+ last_x_chunk = x_chunk
104
+ elif i == 0:
105
+ # First chunk: dw = dlogits.T @ x_chunk
106
+ gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
107
+ else:
108
+ # Middle chunks: dw += dlogits.T @ x_chunk
109
+ gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
110
+ return loss, dx, dw, last_dlogits_chunk, last_x_chunk
111
+
112
+
113
+ class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
114
+ @staticmethod
115
+ @custom_fwd(device_type="cuda")
116
+ def forward(
117
+ ctx,
118
+ x: Tensor,
119
+ weight: Tensor,
120
+ target: Tensor,
121
+ ignore_index: int = -100,
122
+ reduction: Literal["mean", "sum"] = "mean",
123
+ chunk_size: int = 4096,
124
+ tuned: bool = True,
125
+ ):
126
+ """
127
+ Forward pass computes loss and stores dx and dw for backward.
128
+ """
129
+ ctx.weight_dtype = weight.dtype
130
+ x, weight = linear_fwd_convert_type(x, weight)
131
+ batch_shape = x.shape[:-1]
132
+ x = x.reshape(-1, x.shape[-1])
133
+ # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
134
+ loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
135
+ x, weight, target, chunk_size, ignore_index, tuned=tuned
136
+ )
137
+ loss_sum = loss.sum()
138
+ loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
139
+ ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
140
+ ctx.batch_shape = batch_shape
141
+ ctx.ignore_index = ignore_index
142
+ ctx.reduction = reduction
143
+ ctx.tuned = tuned
144
+ return loss_sum if loss_scale is None else loss_sum * loss_scale
145
+
146
+ @staticmethod
147
+ @custom_bwd(device_type="cuda")
148
+ def backward(ctx, dloss):
149
+ """
150
+ Backward pass scales pre-computed gradients by dloss and completes
151
+ the last chunk's dw computation.
152
+ dloss is a scalar.
153
+ """
154
+ dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
155
+ tuned = ctx.tuned
156
+ if loss_scale is not None:
157
+ dloss = dloss * loss_scale
158
+ # TODO: the case where x or weight doesn't require grad
159
+ dx.mul_(dloss)
160
+ dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
161
+ # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
162
+ if dw is None:
163
+ # Only had one chunk, compute dw directly with dloss scaling
164
+ dw = gemm(
165
+ last_dlogits_chunk.T,
166
+ last_x_chunk,
167
+ out_dtype=ctx.weight_dtype,
168
+ alpha=dloss,
169
+ tuned=tuned,
170
+ )
171
+ else:
172
+ # Add last chunk's contribution with dloss scaling
173
+ # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
174
+ # We use alpha=dloss, beta=dloss
175
+ if ctx.weight_dtype == dw.dtype:
176
+ gemm_add_inplace(
177
+ last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
178
+ )
179
+ else:
180
+ dw = gemm_add(
181
+ last_dlogits_chunk.T,
182
+ last_x_chunk,
183
+ dw,
184
+ alpha=dloss,
185
+ beta=dloss,
186
+ out_dtype=ctx.weight_dtype,
187
+ tuned=tuned,
188
+ )
189
+ return dx, dw, None, None, None, None, None
190
+
191
+
192
+ def chunked_linear_cross_entropy(
193
+ x: Tensor,
194
+ weight: Tensor,
195
+ target: Tensor,
196
+ chunk_size: int = 4096,
197
+ ignore_index: int = -100,
198
+ reduction: Literal["mean", "sum"] = "mean",
199
+ tuned: bool = True,
200
+ ) -> Tensor:
201
+ """
202
+ Chunked linear cross entropy with automatic differentiation support.
203
+
204
+ Args:
205
+ x: Input tensor of shape (B*L, d)
206
+ weight: Weight tensor of shape (V, d)
207
+ target: Target indices of shape (B*L,)
208
+ chunk_size: Size of chunks to process
209
+ ignore_index: Index to ignore in loss computation
210
+ reduction: Type of reduction to apply
211
+ tuned: Whether to use tuned kernels
212
+
213
+ Returns:
214
+ Loss tensor with specified reduction
215
+ """
216
+ if reduction not in ["mean", "sum"]:
217
+ raise ValueError(f"Invalid reduction: {reduction}")
218
+ loss = ChunkedLinearCrossEntropyFunction.apply(
219
+ x, weight, target, ignore_index, reduction, chunk_size, tuned
220
+ )
221
+ return loss
222
+
223
+
224
+ class LinearCrossEntropy(nn.Linear):
225
+ def __init__(
226
+ self,
227
+ in_features: int,
228
+ out_features: int,
229
+ bias: bool = False,
230
+ ignore_index: int = -100,
231
+ reduction: Literal["none", "mean", "sum"] = "mean",
232
+ chunk_size: Optional[int] = None,
233
+ inplace_backward: bool = False,
234
+ tuned: bool = True,
235
+ device=None,
236
+ dtype=None,
237
+ ) -> None:
238
+ super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
239
+ self.ignore_index = ignore_index
240
+ self.reduction = reduction
241
+ self.chunk_size = chunk_size
242
+ self.inplace_backward = inplace_backward
243
+ self.tuned = tuned
244
+
245
+ def forward(self, input: Tensor, target: Tensor) -> Tensor:
246
+ if (
247
+ self.bias is None
248
+ and input.is_cuda
249
+ and input.stride(-1) == 1
250
+ and self.in_features % 8 == 0
251
+ and self.out_features % 8 == 0
252
+ and input.shape[:-1].numel() % 8 == 0
253
+ and self.chunk_size is not None
254
+ and self.chunk_size % 8 == 0
255
+ and self.reduction in ["mean", "sum"]
256
+ ):
257
+ return chunked_linear_cross_entropy(
258
+ input,
259
+ self.weight,
260
+ target,
261
+ chunk_size=self.chunk_size,
262
+ ignore_index=self.ignore_index,
263
+ reduction=self.reduction,
264
+ tuned=self.tuned,
265
+ )
266
+ else:
267
+ return linear_cross_entropy_func(
268
+ input,
269
+ self.weight,
270
+ self.bias,
271
+ target,
272
+ ignore_index=self.ignore_index,
273
+ reduction=self.reduction,
274
+ inplace_backward=self.inplace_backward,
275
+ )
build/torch-cuda/quack/mlp.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao
2
+ from typing import Literal
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+
9
+ from einops import rearrange
10
+
11
+ from .linear import linear_act_func, act_linear_func
12
+ from .linear import linear_gated_func, gated_linear_func
13
+ from .linear import linear_fwd_convert_type
14
+ from .linear import _recompute_act_postact, _recompute_gated_postact
15
+ from .activation import gate_fn_map
16
+ from .gemm_interface import (
17
+ act_to_pytorch_fn_map,
18
+ gated_to_pytorch_fn_map,
19
+ gemm,
20
+ gemm_add_inplace,
21
+ gemm_gated,
22
+ gemm_dgated,
23
+ gemm_act,
24
+ gemm_dact,
25
+ )
26
+
27
+ Activation = Literal[
28
+ "gelu_tanh_approx",
29
+ "relu",
30
+ "relu_sq",
31
+ "swiglu",
32
+ "swiglu_oai",
33
+ "reglu",
34
+ "geglu",
35
+ "glu",
36
+ ]
37
+
38
+
39
+ # --- Ops bundles for MLP recompute variants ---
40
+
41
+
42
+ class _MLPOps:
43
+ matmul_fwd = gemm
44
+ matmul_fwd_act = gemm_act
45
+ matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True)
46
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
47
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
48
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
49
+ recompute_postact = staticmethod(_recompute_act_postact)
50
+
51
+
52
+ class _MLPUntunedOps:
53
+ matmul_fwd = partial(gemm, tuned=False)
54
+ matmul_fwd_act = partial(gemm_act, tuned=False)
55
+ matmul_bwd_dact = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
56
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
57
+ matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
58
+ matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
59
+ recompute_postact = staticmethod(_recompute_act_postact)
60
+
61
+
62
+ class _MLPGatedOps(_MLPOps):
63
+ matmul_fwd_act = gemm_gated
64
+ matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True)
65
+ recompute_postact = staticmethod(_recompute_gated_postact)
66
+
67
+
68
+ class _MLPGatedUntunedOps(_MLPUntunedOps):
69
+ matmul_fwd_act = partial(gemm_gated, tuned=False)
70
+ matmul_bwd_dact = partial(gemm_dgated, dynamic_scheduler=True, tuned=False)
71
+ recompute_postact = staticmethod(_recompute_gated_postact)
72
+
73
+
74
+ class _MLPGatedConcatOps(_MLPGatedOps):
75
+ matmul_fwd_act = partial(gemm_gated, concat_layout=("B",))
76
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, concat_layout=("B",))
77
+ matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, concat_layout=("out",))
78
+ matmul_bwd_dw1_inplace = partial(
79
+ gemm_add_inplace, dynamic_scheduler=True, concat_layout=("C", "out")
80
+ )
81
+ recompute_fwd = partial(gemm, concat_layout=("B",))
82
+
83
+
84
+ class _MLPGatedConcatUntunedOps(_MLPGatedUntunedOps):
85
+ matmul_fwd_act = partial(gemm_gated, tuned=False, concat_layout=("B",))
86
+ matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("B",))
87
+ matmul_bwd_dw1 = partial(gemm, dynamic_scheduler=True, tuned=False, concat_layout=("out",))
88
+ matmul_bwd_dw1_inplace = partial(
89
+ gemm_add_inplace, dynamic_scheduler=True, tuned=False, concat_layout=("out",)
90
+ )
91
+ recompute_fwd = partial(gemm, tuned=False, concat_layout=("B",))
92
+
93
+
94
+ class MLPRecomputeFunc(torch.autograd.Function):
95
+ """MLP with activation recomputation: saves only x (not preact) to reduce memory.
96
+
97
+ In backward, recomputes preact = x @ W1.T (one extra matmul) instead of loading it
98
+ from saved tensors. This trades compute for memory:
99
+ - Saves: batch * 2 * hidden * dtype_size bytes of activation memory
100
+ - Costs: one extra GEMM (x @ W1.T) during backward
101
+
102
+ Ops class selects between non-gated (gemm_act/gemm_dact) and gated (gemm_gated/gemm_dgated)
103
+ variants, as well as tuned/untuned.
104
+ """
105
+
106
+ @staticmethod
107
+ def forward(ctx, x, weight1, weight2, activation, fuse_grad_accum, ops):
108
+ x, weight1, weight2 = linear_fwd_convert_type(x, weight1, weight2)
109
+ with torch.amp.autocast("cuda", enabled=False):
110
+ ctx.weight_dtype = weight1.dtype
111
+ ctx.fuse_grad_accum = fuse_grad_accum
112
+ ctx.activation = activation
113
+ ctx.ops = ops
114
+ weight1_og, weight2_og = weight1, weight2
115
+ batch_shape = x.shape[:-1]
116
+ x_flat = x.reshape(-1, x.shape[-1])
117
+ _preact, postact = ops.matmul_fwd_act(x_flat, weight1.T, activation=activation)
118
+ out = ops.matmul_fwd(postact, weight2.T)
119
+ # Save only x and weights — no preact (the whole point of recompute)
120
+ needs_input_grad = ctx.needs_input_grad
121
+ any_grad = needs_input_grad[0] or needs_input_grad[1] or needs_input_grad[2]
122
+ need_dact = needs_input_grad[0] or needs_input_grad[1] # gemm_dact for dpreact
123
+ saved_x = x if any_grad else None # recompute preact = x @ W1.T
124
+ saved_w1 = weight1 if any_grad else None # recompute + dx
125
+ saved_w2 = weight2 if need_dact else None # only gemm_dact needs W2
126
+ ctx.save_for_backward(
127
+ saved_x,
128
+ saved_w1,
129
+ saved_w2,
130
+ weight1_og if fuse_grad_accum else None,
131
+ weight2_og if fuse_grad_accum else None,
132
+ )
133
+ return out.reshape(*batch_shape, out.shape[-1])
134
+
135
+ @staticmethod
136
+ def backward(ctx, dout):
137
+ with torch.amp.autocast("cuda", enabled=False):
138
+ ops = ctx.ops
139
+ x, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
140
+ batch_shape = dout.shape[:-1]
141
+ dout = dout.reshape(-1, dout.shape[-1]).contiguous()
142
+ # Recompute preact = x @ W1.T (the extra matmul we trade for memory)
143
+ x_flat = x.reshape(-1, x.shape[-1]) if x is not None else None
144
+ need_dact = ctx.needs_input_grad[0] or ctx.needs_input_grad[1]
145
+ any_grad = need_dact or ctx.needs_input_grad[2]
146
+ # concat ops override recompute_fwd to produce interleaved preact matching forward
147
+ recompute_fwd = getattr(ops, "recompute_fwd", ops.matmul_fwd)
148
+ if need_dact:
149
+ preact = recompute_fwd(x_flat, weight1.T)
150
+ # gemm_dact computes: dpreact = d_act(dout @ W2, preact) AND recomputes postact
151
+ dpreact, postact = ops.matmul_bwd_dact(
152
+ dout, weight2, preact, activation=ctx.activation
153
+ )
154
+ elif any_grad:
155
+ # Only dW2 needed: recompute postact from preact cheaply (no gemm_dact)
156
+ preact = recompute_fwd(x_flat, weight1.T)
157
+ postact = ops.recompute_postact(preact, ctx.activation)
158
+ dpreact = None
159
+ else:
160
+ dpreact, postact = None, None
161
+ # dW2 = dout.T @ postact
162
+ dweight2 = _compute_weight_grad(
163
+ ctx,
164
+ dout,
165
+ postact,
166
+ weight2_og,
167
+ ops.matmul_bwd_dw,
168
+ ops.matmul_bwd_dw_inplace,
169
+ ctx.needs_input_grad[2],
170
+ )
171
+ # dx = dpreact @ W1
172
+ if ctx.needs_input_grad[0]:
173
+ dx = ops.matmul_bwd_dx(dpreact, weight1)
174
+ dx = dx.reshape(*batch_shape, dx.shape[-1])
175
+ else:
176
+ dx = None
177
+ # dW1 = dpreact.T @ x (use dw1 ops if available, e.g. concat layout)
178
+ dw1_fn = getattr(ops, "matmul_bwd_dw1", ops.matmul_bwd_dw)
179
+ dw1_inplace_fn = getattr(ops, "matmul_bwd_dw1_inplace", ops.matmul_bwd_dw_inplace)
180
+ dweight1 = _compute_weight_grad(
181
+ ctx,
182
+ dpreact,
183
+ x_flat,
184
+ weight1_og,
185
+ dw1_fn,
186
+ dw1_inplace_fn,
187
+ ctx.needs_input_grad[1],
188
+ )
189
+ return dx, dweight1, dweight2, None, None, None
190
+
191
+
192
+ def _compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn, needs_grad):
193
+ if not needs_grad:
194
+ return None
195
+ x = x.reshape(-1, x.shape[-1])
196
+ if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
197
+ return matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
198
+ else:
199
+ matmul_inplace_fn(dout.T, x, weight_og.grad)
200
+ dweight = weight_og.grad
201
+ weight_og.grad = None
202
+ return dweight
203
+
204
+
205
+ def mlp_func(
206
+ x,
207
+ weight1,
208
+ weight2,
209
+ activation: str,
210
+ bias1=None,
211
+ bias2=None,
212
+ fuse_grad_accum=False,
213
+ tuned=True,
214
+ recompute=False,
215
+ concat_layout=False,
216
+ ):
217
+ gated = activation in gate_fn_map
218
+ if concat_layout:
219
+ assert gated, "concat_layout is only supported for gated MLP"
220
+ if recompute:
221
+ if concat_layout:
222
+ ops = _MLPGatedConcatOps if tuned else _MLPGatedConcatUntunedOps
223
+ elif gated:
224
+ ops = _MLPGatedOps if tuned else _MLPGatedUntunedOps
225
+ else:
226
+ ops = _MLPOps if tuned else _MLPUntunedOps
227
+ return MLPRecomputeFunc.apply(x, weight1, weight2, activation, fuse_grad_accum, ops)
228
+ fc1_fn = linear_gated_func if gated else linear_act_func
229
+ fc2_fn = gated_linear_func if gated else act_linear_func
230
+ preact, postact = fc1_fn(
231
+ x,
232
+ weight1,
233
+ activation,
234
+ bias=bias1,
235
+ store_preact=torch.is_grad_enabled(),
236
+ fuse_grad_accum=fuse_grad_accum,
237
+ tuned=tuned,
238
+ **({"concat_layout": concat_layout} if concat_layout and gated else {}),
239
+ )
240
+ out = fc2_fn(
241
+ preact,
242
+ weight2,
243
+ postact,
244
+ activation=activation,
245
+ bias=bias2,
246
+ fuse_grad_accum=fuse_grad_accum,
247
+ tuned=tuned,
248
+ )
249
+ return out
250
+
251
+
252
+ class MLP(nn.Module):
253
+ def __init__(
254
+ self,
255
+ in_features,
256
+ hidden_features=None,
257
+ out_features=None,
258
+ bias1=False,
259
+ bias2=False,
260
+ activation: Activation = "gelu_tanh_approx",
261
+ multiple_of=1,
262
+ device=None,
263
+ dtype=None,
264
+ fuse_grad_accum: bool = False,
265
+ tuned: bool = True,
266
+ recompute: bool = False,
267
+ concat_layout: bool = False,
268
+ ):
269
+ factory_kwargs = {"device": device, "dtype": dtype}
270
+ super().__init__()
271
+ out_features = out_features if out_features is not None else in_features
272
+ self.activation = activation
273
+ self.gated = activation in gate_fn_map
274
+ assert not concat_layout or self.gated, "concat_layout is only supported for gated MLP"
275
+ if hidden_features is None:
276
+ hidden_features = int(8 / 3 * in_features) if self.gated else 4 * in_features
277
+ if multiple_of > 1:
278
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
279
+ fc1_out = 2 * hidden_features if self.gated else hidden_features
280
+ self.fc1 = nn.Linear(in_features, fc1_out, bias=bias1, **factory_kwargs)
281
+ if self.gated:
282
+ if concat_layout:
283
+ self.fc1.weight._muon_reshape_functions = (
284
+ lambda w: rearrange(w, "(two d) e -> two d e", two=2),
285
+ lambda w: rearrange(w, "two d e -> (two d) e"),
286
+ )
287
+ else:
288
+ self.fc1.weight._muon_reshape_functions = (
289
+ lambda w: rearrange(w, "(d two) e -> two d e", two=2),
290
+ lambda w: rearrange(w, "two d e -> (d two) e"),
291
+ )
292
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
293
+ self.fuse_grad_accum = fuse_grad_accum
294
+ self.tuned = tuned
295
+ self.recompute = recompute
296
+ self.concat_layout = concat_layout
297
+
298
+ def forward(self, input: Tensor) -> Tensor:
299
+ # Allow bias in the fused path during inference (fwd-only, no bwd).
300
+ bias_ok = not torch.is_grad_enabled() or (self.fc1.bias is None and self.fc2.bias is None)
301
+ if (
302
+ bias_ok
303
+ and input.is_cuda
304
+ and input.stride(-1) == 1
305
+ and self.fc1.in_features % 8 == 0
306
+ and self.fc1.out_features % (16 if self.gated else 8) == 0
307
+ and self.fc2.out_features % 8 == 0
308
+ ):
309
+ return mlp_func(
310
+ input,
311
+ self.fc1.weight,
312
+ self.fc2.weight,
313
+ activation=self.activation,
314
+ bias1=self.fc1.bias,
315
+ bias2=self.fc2.bias,
316
+ fuse_grad_accum=self.fuse_grad_accum,
317
+ tuned=self.tuned,
318
+ recompute=self.recompute,
319
+ concat_layout=self.concat_layout,
320
+ )
321
+ else:
322
+ y = self.fc1(input)
323
+ if self.gated:
324
+ if self.concat_layout:
325
+ gate, up = y.chunk(2, dim=-1)
326
+ y = gated_to_pytorch_fn_map[self.activation](gate, up)
327
+ else:
328
+ y = gated_to_pytorch_fn_map[self.activation](y[..., ::2], y[..., 1::2])
329
+ else:
330
+ y = act_to_pytorch_fn_map[self.activation](y)
331
+ return self.fc2(y)
build/torch-cuda/quack/mx_utils.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal MX / NVFP4 quantization + scale swizzling utilities.
2
+
3
+ Ported from torchao (BSD-3) to avoid the runtime dependency:
4
+ torchao/prototype/mx_formats/{mx_tensor, nvfp4_tensor, utils, constants}.py
5
+ torchao/prototype/custom_fp_utils.py
6
+ torchao/prototype/mx_formats/kernels.py
7
+
8
+ All quantizers are pure-PyTorch. Use the `to_mx_compiled` / `to_mxfp4_compiled` /
9
+ `to_nvfp4_compiled` module-level handles if you want torch.compile-generated
10
+ Triton kernels (much faster on big tensors; one-time compile overhead).
11
+
12
+ Only the FLOOR scaling mode is ported (torchao's default for MX formats).
13
+ """
14
+
15
+ import torch
16
+
17
+ F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
18
+ F8E4M3_MAX_POW2 = 8
19
+ E8M0_EXPONENT_BIAS = 127
20
+ E8M0_EXPONENT_NAN_VAL = 255
21
+ F32_EXP_BIAS = 127
22
+ F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) # 2**-126
23
+ MBITS_F32 = 23
24
+ EBITS_F32 = 8
25
+
26
+ # FP4 E2M1 constants
27
+ F4_E2M1_MAX = 6.0
28
+ F4_E2M1_MAX_POW2 = 2
29
+ F4_E2M1_MAX_INT = 7 # 3-bit magnitude mask
30
+ EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
31
+
32
+ E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
33
+
34
+
35
+ def _n_ones(n: int) -> int:
36
+ return (1 << n) - 1
37
+
38
+
39
+ def to_mx(data_hp: torch.Tensor, block_size: int = 32):
40
+ """MXFP8-e4m3 quantization with FLOOR scaling.
41
+
42
+ Args:
43
+ data_hp: (..., K) bf16 or fp32 tensor, contiguous, K % block_size == 0.
44
+ Returns:
45
+ qdata: (..., K) float8_e4m3fn
46
+ scale: (..., K // block_size) float8_e8m0fnu
47
+ """
48
+ assert data_hp.dtype in (torch.bfloat16, torch.float32)
49
+ assert data_hp.shape[-1] % block_size == 0
50
+ assert data_hp.is_contiguous()
51
+
52
+ orig_shape = data_hp.shape
53
+ data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
54
+ max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
55
+
56
+ data_hp = data_hp.to(torch.float32)
57
+ max_abs = max_abs.to(torch.float32)
58
+
59
+ # FLOOR scaling: extract biased exponent of max_abs via bit-shift
60
+ max_abs_int32 = max_abs.view(torch.int32)
61
+ extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS
62
+ scale_e8m0_unbiased = extracted_pow2 - F8E4M3_MAX_POW2
63
+ scale_e8m0_unbiased = torch.clamp(
64
+ scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
65
+ )
66
+ scale_e8m0_biased = (scale_e8m0_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8)
67
+ # restore NaN sentinel (uint8 cast drops NaN)
68
+ scale_e8m0_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_e8m0_biased)
69
+
70
+ # reconstruct fp32 scale from biased exponent
71
+ scale_fp32 = (torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)).view(
72
+ torch.float32
73
+ )
74
+ # avoid 2**-127 being flushed to 0 (pytorch #125557)
75
+ scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
76
+
77
+ data_lp = data_hp / scale_fp32
78
+ # eager fp8 cast is unsaturated; clamp explicitly
79
+ if not torch._dynamo.is_compiling():
80
+ data_lp = torch.clamp(data_lp, min=-F8E4M3_MAX, max=F8E4M3_MAX)
81
+
82
+ qdata = data_lp.to(torch.float8_e4m3fn).reshape(orig_shape)
83
+ scale = scale_e8m0_biased.view(torch.float8_e8m0fnu).squeeze(-1)
84
+ return qdata, scale
85
+
86
+
87
+ def _f32_to_floatx_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
88
+ """FP32 -> sub-byte float (uint8, code in low bits). Verbatim from torchao.
89
+
90
+ Round-to-nearest-even via magic-adder; saturation on overflow; no NaN.
91
+ """
92
+ assert x.dtype == torch.float
93
+ assert 1 + ebits + mbits <= 8
94
+ exp_bias = _n_ones(ebits - 1)
95
+ max_int = _n_ones(ebits + mbits)
96
+ sign_mask = 1 << (ebits + mbits)
97
+ magic_adder = _n_ones(MBITS_F32 - mbits - 1)
98
+ max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
99
+ min_normal = 2 ** (1 - exp_bias)
100
+ denorm_exp = (F32_EXP_BIAS - exp_bias) + (MBITS_F32 - mbits) + 1
101
+ denorm_mask_int = denorm_exp << MBITS_F32
102
+ denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
103
+
104
+ x = x.view(torch.int32)
105
+ sign = x & 0x80000000
106
+ x = x ^ sign
107
+ x = x.view(torch.float)
108
+ saturate_mask = x >= max_normal
109
+ denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
110
+ normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
111
+ denormal_x = x + denorm_mask_float
112
+ denormal_x = denormal_x.view(torch.int32)
113
+ denormal_x -= denorm_mask_int
114
+ denormal_x = denormal_x.to(torch.uint8)
115
+ normal_x = x.view(torch.int32)
116
+ mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
117
+ val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
118
+ normal_x += val_to_add
119
+ normal_x += mant_odd
120
+ normal_x = normal_x >> (MBITS_F32 - mbits)
121
+ normal_x = normal_x.to(torch.uint8)
122
+ x = torch.full_like(x, max_int, dtype=torch.uint8)
123
+ x = torch.where(denormal_mask, denormal_x, x)
124
+ x = torch.where(normal_mask, normal_x, x)
125
+ sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
126
+ sign_lp = sign_lp.to(torch.uint8)
127
+ sign_lp = sign_lp & sign_mask
128
+ x = x | sign_lp
129
+ return x.to(torch.uint8)
130
+
131
+
132
+ def _pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
133
+ """Pack 4-bit uint8 values in pairs: pair (a,b) -> byte (b<<4 | a)."""
134
+ shape = uint8_data.shape
135
+ assert shape[-1] % 2 == 0
136
+ uint8_data = uint8_data.contiguous().view(-1)
137
+ return (uint8_data[::2] | uint8_data[1::2] << 4).view(*shape[:-1], shape[-1] // 2)
138
+
139
+
140
+ def _compute_e8m0_scale_floor(max_abs: torch.Tensor, target_max_pow2: int) -> torch.Tensor:
141
+ """Return biased E8M0 scale (uint8) for FLOOR-mode MX quantization."""
142
+ max_abs_int32 = max_abs.view(torch.int32)
143
+ extracted_pow2 = ((torch.bitwise_right_shift(max_abs_int32, MBITS_F32)) & 0xFF) - F32_EXP_BIAS
144
+ scale_unbiased = extracted_pow2 - target_max_pow2
145
+ scale_unbiased = torch.clamp(
146
+ scale_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
147
+ )
148
+ scale_biased = (scale_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8)
149
+ scale_biased = torch.where(torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_biased)
150
+ return scale_biased
151
+
152
+
153
+ def to_mxfp4(x: torch.Tensor, block_size: int = 32):
154
+ """MXFP4 quantization: E2M1 data + E8M0 per-block scales, FLOOR scaling.
155
+
156
+ Args:
157
+ x: (..., K) bf16/fp16/fp32, contiguous, K % block_size == 0.
158
+ Returns:
159
+ qdata_packed: uint8, shape (..., K // 2). Two FP4 values per byte
160
+ (first -> low nibble, second -> high nibble).
161
+ scale: float8_e8m0fnu, shape (..., K // block_size).
162
+ """
163
+ assert x.dtype in (torch.bfloat16, torch.float16, torch.float32)
164
+ assert x.shape[-1] % block_size == 0
165
+ assert x.is_contiguous()
166
+
167
+ orig_shape = x.shape
168
+ data_hp = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
169
+ max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
170
+ data_hp = data_hp.to(torch.float32)
171
+ max_abs = max_abs.to(torch.float32)
172
+
173
+ scale_biased = _compute_e8m0_scale_floor(max_abs, F4_E2M1_MAX_POW2)
174
+ scale_fp32 = (torch.bitwise_left_shift(scale_biased.to(torch.int32), MBITS_F32)).view(
175
+ torch.float32
176
+ )
177
+ scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
178
+
179
+ data_lp = data_hp / scale_fp32
180
+ data_lp = data_lp.reshape(orig_shape)
181
+ data_lp = _f32_to_floatx_unpacked(data_lp.float(), EBITS_F4_E2M1, MBITS_F4_E2M1)
182
+ data_lp = _pack_uint4(data_lp)
183
+
184
+ scale = scale_biased.view(torch.float8_e8m0fnu).squeeze(-1)
185
+ return data_lp, scale
186
+
187
+
188
+ def nvfp4_per_tensor_scale(amax: torch.Tensor) -> torch.Tensor:
189
+ """NVFP4 per-tensor scale: amax / (F8E4M3_MAX * F4_E2M1_MAX) = amax / 2688."""
190
+ return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)
191
+
192
+
193
+ def to_nvfp4(x: torch.Tensor, block_size: int = 16, per_tensor_scale=None):
194
+ """NVFP4 quantization: E2M1 data + E4M3 per-block scales + optional fp32 per-tensor scale.
195
+
196
+ Args:
197
+ x: (..., K) bf16/fp32, contiguous, K % 16 == 0.
198
+ block_size: must be 16.
199
+ per_tensor_scale: scalar fp32 tensor, or None (uses 1.0 / returns unit).
200
+ Returns:
201
+ qdata_packed: uint8, shape (..., K // 2)
202
+ scale: float8_e4m3fn, shape (..., K // 16)
203
+ per_tensor_scale: scalar fp32 tensor (1.0 if None was passed)
204
+ """
205
+ assert x.dtype in (torch.bfloat16, torch.float32)
206
+ assert x.shape[-1] % block_size == 0
207
+ assert x.is_contiguous()
208
+ assert block_size == 16, "NVFP4 requires block_size=16"
209
+
210
+ orig_shape = x.shape
211
+ data_hp = x.float().reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)
212
+ max_abs = torch.amax(torch.abs(data_hp), dim=-1)
213
+ block_scale = max_abs / F4_E2M1_MAX
214
+
215
+ if per_tensor_scale is None:
216
+ block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
217
+ torch.float8_e4m3fn
218
+ )
219
+ recip = 1.0 / block_scale_fp8.to(torch.float32)
220
+ returned_pts = torch.tensor(1.0, dtype=torch.float32, device=x.device)
221
+ else:
222
+ scaled = block_scale.to(torch.float32) / per_tensor_scale
223
+ block_scale_fp8 = torch.clamp(scaled, min=E4M3_EPS, max=F8E4M3_MAX).to(torch.float8_e4m3fn)
224
+ recip = (1.0 / per_tensor_scale) / block_scale_fp8.to(torch.float32)
225
+ returned_pts = per_tensor_scale.to(torch.float32)
226
+
227
+ data_scaled = data_hp * recip.unsqueeze(-1)
228
+ data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
229
+ data_scaled = data_scaled.view(orig_shape)
230
+ data_lp = _f32_to_floatx_unpacked(data_scaled.float(), EBITS_F4_E2M1, MBITS_F4_E2M1)
231
+ data_lp = _pack_uint4(data_lp)
232
+ return data_lp, block_scale_fp8, returned_pts
233
+
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # torch.compile-wrapped fast paths. Generates fused Triton quant kernels via
237
+ # Inductor. dynamic=True avoids recompilation on shape changes.
238
+ # ---------------------------------------------------------------------------
239
+ to_mx_compiled = torch.compile(to_mx, dynamic=True)
240
+ to_mxfp4_compiled = torch.compile(to_mxfp4, dynamic=True)
241
+ to_nvfp4_compiled = torch.compile(to_nvfp4, dynamic=True)
242
+
243
+
244
+ def _ceil_div(a, b):
245
+ return (a + b - 1) // b
246
+
247
+
248
+ def to_blocked(input_matrix: torch.Tensor) -> torch.Tensor:
249
+ """Swizzle a (H, W) e8m0 scale tensor into the 128x4 blocked layout
250
+ cuBLAS expects for MXFP8 _scaled_mm. Returns a 1-D flat tensor of size
251
+ 32*ceil(H/128) * 16*ceil(W/4)."""
252
+ rows, cols = input_matrix.shape
253
+ n_row_blocks = _ceil_div(rows, 128)
254
+ n_col_blocks = _ceil_div(cols, 4)
255
+ padded_rows = n_row_blocks * 128
256
+ padded_cols = n_col_blocks * 4
257
+
258
+ padded = input_matrix
259
+ if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
260
+ padded = torch.zeros(
261
+ (padded_rows, padded_cols),
262
+ device=input_matrix.device,
263
+ dtype=input_matrix.dtype,
264
+ )
265
+ padded[:rows, :cols] = input_matrix
266
+
267
+ blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
268
+ rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
269
+ return rearranged.flatten()
build/torch-cuda/quack/nvmmh_heuristic.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ """nvMatmulHeuristics-based config selection for GEMM.
3
+
4
+ Queries NVIDIA's analytic heuristic library to pick tile/cluster dims based on
5
+ problem shape, then selects swap_ab by comparing estimated runtimes for both
6
+ orientations.
7
+ """
8
+
9
+ import logging
10
+ import torch
11
+
12
+ from .gemm_config import GemmConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ _nvmmh_available = None
17
+ _iface = None
18
+ _hw_descriptors = {} # gpu_enum -> hw descriptor
19
+
20
+
21
+ def _get_iface():
22
+ """Lazily initialize the nvMatmulHeuristics interface."""
23
+ global _nvmmh_available, _iface
24
+ if _nvmmh_available is not None:
25
+ return _iface
26
+ try:
27
+ from nvMatmulHeuristics import (
28
+ NvMatmulHeuristicsInterface,
29
+ NvMatmulHeuristicsTarget,
30
+ )
31
+
32
+ _iface = NvMatmulHeuristicsInterface(
33
+ backend=NvMatmulHeuristicsTarget.CUTLASS3,
34
+ precision="BSB", # overridden per-call
35
+ )
36
+ _nvmmh_available = True
37
+ except Exception as e:
38
+ logger.debug(f"nvMatmulHeuristics not available: {e}")
39
+ _nvmmh_available = False
40
+ _iface = None
41
+ return _iface
42
+
43
+
44
+ def _get_hw(device_capacity):
45
+ """Get or create a hardware descriptor for the given SM version."""
46
+ global _hw_descriptors
47
+ if device_capacity in _hw_descriptors:
48
+ return _hw_descriptors[device_capacity]
49
+ try:
50
+ from nvMatmulHeuristics import (
51
+ NvMatmulHeuristicsNvidiaGpu,
52
+ NvMatmulHeuristicsMatmulLayout,
53
+ )
54
+
55
+ iface = _get_iface()
56
+ if iface is None:
57
+ return None
58
+ gpu_map = {
59
+ 9: NvMatmulHeuristicsNvidiaGpu.H100_SXM,
60
+ 10: NvMatmulHeuristicsNvidiaGpu.B200,
61
+ }
62
+ gpu = gpu_map.get(device_capacity)
63
+ if gpu is None:
64
+ return None
65
+ hw = iface.createHardwareDescriptor()
66
+ iface.setHardwarePredefinedGpu(hw, gpu)
67
+ # Load discovery sets for TN_ROW_MAJOR and TN_COL_MAJOR
68
+ for layout in [
69
+ NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR,
70
+ NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR,
71
+ ]:
72
+ iface.loadInternalDiscoverySet(layout, hw)
73
+ _hw_descriptors[device_capacity] = hw
74
+ return hw
75
+ except Exception as e:
76
+ logger.debug(f"Failed to create hardware descriptor: {e}")
77
+ _hw_descriptors[device_capacity] = None
78
+ return None
79
+
80
+
81
+ _TORCH_DTYPE_TO_NVMMH_PRECISION = {
82
+ torch.bfloat16: "BSB",
83
+ torch.float16: "HSH",
84
+ torch.float32: "SSS",
85
+ }
86
+
87
+
88
+ def _query_top1(iface, hw, m, n, k, layout, precision):
89
+ """Query nvMMH for top-1 config. Returns (tile_m, tile_n, cl_m, cl_n, est_runtime) or None."""
90
+ try:
91
+ original_precision = iface.precision
92
+ iface.precision = precision
93
+ results = iface.get_with_mnk(
94
+ m=m,
95
+ n=n,
96
+ k=k,
97
+ matmulLayout=layout,
98
+ count=1,
99
+ hardware_descriptor=hw,
100
+ )
101
+ iface.precision = original_precision
102
+ if not results:
103
+ return None
104
+ cfg = results[0]["kernel"]
105
+ return cfg.cta_tile_m, cfg.cta_tile_n, cfg.cluster_m, cfg.cluster_n, results[0]["runtime"]
106
+ except Exception:
107
+ return None
108
+
109
+
110
+ def nvmmh_default_config(A, B, device_capacity):
111
+ """Use nvMatmulHeuristics to pick a GemmConfig based on problem shape.
112
+
113
+ Queries both normal (M,N,K) with row-major output and swapped (N,M,K) with
114
+ col-major output, picks the orientation with lower estimated runtime.
115
+
116
+ Returns None if nvMatmulHeuristics is unavailable, letting the caller fall
117
+ back to the hardcoded default.
118
+ """
119
+ from nvMatmulHeuristics import NvMatmulHeuristicsMatmulLayout
120
+
121
+ iface = _get_iface()
122
+ if iface is None:
123
+ return None
124
+ hw = _get_hw(device_capacity)
125
+ if hw is None:
126
+ return None
127
+
128
+ precision = _TORCH_DTYPE_TO_NVMMH_PRECISION.get(A.dtype)
129
+ if precision is None:
130
+ return None
131
+
132
+ # Extract M, N, K from tensor shapes
133
+ # A: (M, K) or (L, M, K), B: (K, N) or (L, K, N)
134
+ m = A.shape[-2] if A.ndim >= 2 else A.shape[0]
135
+ k = A.shape[-1]
136
+ n = B.shape[-1]
137
+
138
+ # Query normal orientation: D(M,N) row-major
139
+ normal = _query_top1(iface, hw, m, n, k, NvMatmulHeuristicsMatmulLayout.TN_ROW_MAJOR, precision)
140
+ # Query swapped orientation: D(N,M) col-major
141
+ swapped = _query_top1(
142
+ iface, hw, n, m, k, NvMatmulHeuristicsMatmulLayout.TN_COL_MAJOR, precision
143
+ )
144
+
145
+ if normal is None and swapped is None:
146
+ return None
147
+
148
+ # Pick orientation with lower estimated runtime
149
+ normal_rt = normal[4] if normal else float("inf")
150
+ swapped_rt = swapped[4] if swapped else float("inf")
151
+
152
+ if swapped_rt < normal_rt and swapped is not None:
153
+ tile_m, tile_n, cl_m, cl_n = swapped[:4]
154
+ swap_ab = True
155
+ else:
156
+ tile_m, tile_n, cl_m, cl_n = normal[:4]
157
+ swap_ab = False
158
+
159
+ # SM90: pingpong only works with tile_m <= 128
160
+ # SM100: no pingpong
161
+ pingpong = (device_capacity == 9) and (tile_m <= 128)
162
+
163
+ return GemmConfig(
164
+ tile_m=tile_m,
165
+ tile_n=tile_n,
166
+ pingpong=pingpong,
167
+ cluster_m=cl_m,
168
+ cluster_n=cl_n,
169
+ swap_ab=swap_ab,
170
+ max_swizzle_size=8,
171
+ device_capacity=device_capacity,
172
+ )
build/torch-cuda/quack/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2025, Tri Dao.
2
 
3
  from typing import Optional
4
  from dataclasses import dataclass
@@ -6,9 +6,51 @@ from dataclasses import dataclass
6
  import cutlass.cute as cute
7
  from cutlass import Boolean, Int32, const_expr
8
  from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
9
- from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
10
- from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
11
- from cutlass.pipeline import PipelineTmaUmma
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class PipelineStateWAdvance(PipelineState):
@@ -33,99 +75,236 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
33
  Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
34
  """
35
  if type is PipelineUserType.Producer:
36
- return PipelineStateWAdvance(
37
- stages,
38
- Int32(0),
39
- Int32(0),
40
- Int32(1),
41
- )
42
  elif type is PipelineUserType.Consumer:
43
- return PipelineStateWAdvance(
44
- stages,
45
- Int32(0),
46
- Int32(0),
47
- Int32(0),
48
- )
49
  else:
50
  assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @dataclass(frozen=True)
54
- class PipelineTmaCpAsync(PipelineTmaAsync):
55
  """
56
- PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
 
 
 
 
 
 
 
 
 
 
57
  """
58
 
 
 
 
 
 
59
  @staticmethod
60
  def create(
61
- *,
62
- num_stages: int,
63
- producer_group: CooperativeGroup,
64
- consumer_group: CooperativeGroup,
65
- tx_count: int,
66
- barrier_storage: cute.Pointer = None,
67
- cta_layout_vmnk: Optional[cute.Layout] = None,
68
- tidx: Optional[Int32] = None,
69
  ):
70
- """
71
- This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
72
- :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
73
- :type barrier_storage: cute.Pointer
74
- :param num_stages: Number of buffer stages for this pipeline
75
- :type num_stages: Int32
76
- :param producer_group: CooperativeGroup for the producer agent
77
- :type producer_group: CooperativeGroup
78
- :param consumer_group: CooperativeGroup for the consumer agent
79
- :type consumer_group: CooperativeGroup
80
- :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
81
- :type tx_count: int
82
- :param cta_layout_vmnk: Layout of the cluster shape
83
- :type cta_layout_vmnk: cute.Layout | None
84
- :param tidx: thread index to consumer async threads
85
- :type tidx: Int32 | None
86
- """
87
- if not isinstance(barrier_storage, cute.Pointer):
88
- raise ValueError(
89
- f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
90
- )
91
 
92
- producer_type = PipelineOp.TmaLoad
93
- consumer_type = PipelineOp.AsyncThread
 
 
 
 
 
 
 
 
 
94
 
95
- producer = (producer_type, producer_group)
96
- consumer = (consumer_type, consumer_group)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- sync_object_full = PipelineAsync._make_sync_object(
99
- barrier_storage.align(min_align=8), num_stages, producer, tx_count
 
 
 
 
 
 
 
 
100
  )
101
- sync_object_empty = PipelineAsync._make_sync_object(
102
- barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
- if tidx is None:
105
- tidx, _, _ = cute.arch.thread_idx()
106
- if cta_layout_vmnk is None:
107
- cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
108
- (
109
- dst_rank,
110
- is_signalling_thread,
111
- ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
112
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
113
- dst_rank = None
114
  else:
115
- dst_rank = dst_rank
 
116
 
117
- producer_mask = None
118
 
119
- pipeline_init_wait(cta_layout_vmnk)
120
 
121
- return PipelineTmaCpAsync(
122
- sync_object_full,
123
- sync_object_empty,
124
- num_stages,
125
- producer_mask,
126
- dst_rank,
127
- is_signalling_thread,
128
- )
129
 
130
  @dsl_user_op
131
  def producer_acquire(
@@ -133,30 +312,115 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
133
  state: PipelineState,
134
  try_acquire_token: Optional[Boolean] = None,
135
  is_tma_warp: Optional[Boolean] = True,
 
136
  *,
137
  loc=None,
138
  ip=None,
139
  ):
140
  """
141
- TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
142
  """
143
  if_generate(
144
  try_acquire_token is None or try_acquire_token == 0,
145
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
  # This is the difference between this and PipelineTmaAsync: we could have multiple
148
  # warps calling this, but only 1 warp should do the arrive on the full barrier
149
  if_generate(
150
  is_tma_warp,
151
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
 
 
152
  )
153
 
154
  @dsl_user_op
155
  def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
156
- """
157
- We need the mbarrier to track the completion of cp.async
158
- """
159
- cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
 
 
 
 
 
 
160
 
161
 
162
  class MbarrierArrayWDropCount(MbarrierArray):
@@ -204,13 +468,17 @@ class MbarrierArrayWDropCount(MbarrierArray):
204
  )
205
 
206
 
 
 
 
207
  @dataclass(frozen=True)
208
- class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
209
  """
210
  PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
211
  (e.g. Blackwell mainloops)
212
  """
213
 
 
214
  @staticmethod
215
  def create(
216
  *,
@@ -220,28 +488,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
220
  tx_count: int,
221
  barrier_storage: cute.Pointer = None,
222
  cta_layout_vmnk: Optional[cute.Layout] = None,
223
- producer_drop_count: Optional[Int32] = None,
224
  mcast_mode_mn: tuple[int, int] = (1, 1),
 
 
 
 
225
  ):
226
- """
227
- This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
228
- :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
229
- :type barrier_storage: cute.Pointer
230
  :param num_stages: Number of buffer stages for this pipeline
231
- :type num_stages: Int32
232
- :param producer_group: `CooperativeGroup` for the producer agent
233
  :type producer_group: CooperativeGroup
234
- :param consumer_group: `CooperativeGroup` for the consumer agent
235
  :type consumer_group: CooperativeGroup
236
  :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
237
  :type tx_count: int
 
 
238
  :param cta_layout_vmnk: Layout of the cluster shape
239
- :type cta_layout_vmnk: cute.Layout | None
240
  :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
241
  :type mcast_mode_mn: tuple[int, int], optional
 
 
 
242
  """
243
  if not isinstance(barrier_storage, cute.Pointer):
244
- raise ValueError(
245
  f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
246
  )
247
 
@@ -257,29 +531,44 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
257
  producer,
258
  tx_count,
259
  drop_count=producer_drop_count,
 
 
260
  )
261
- sync_object_empty = PipelineTmaUmma._make_sync_object(
262
- barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
 
 
 
 
263
  )
264
 
265
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
266
  # No mcast mask if not using clusters
267
  producer_mask = None
268
  # All threadblocks are leaders if not using clusters
269
  is_leader_cta = True
270
  else:
271
- producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
272
- is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
 
 
 
 
273
 
274
  cta_group = (
275
  cute.nvgpu.tcgen05.CtaGroup.ONE
276
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
277
  else cute.nvgpu.tcgen05.CtaGroup.TWO
278
  )
279
 
280
  consumer_mask = producer_mask
281
 
282
- pipeline_init_wait(cta_layout_vmnk)
 
 
 
 
 
283
 
284
  return PipelineTmaCpAsyncUmma(
285
  sync_object_full,
@@ -308,12 +597,16 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
308
  if_generate(
309
  try_acquire_token is None or try_acquire_token == 0,
310
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
 
 
311
  )
312
  # This is the difference between this and PipelineTmaAsync: we could have multiple
313
  # warps calling this, but only 1 warp should do the arrive on the full barrier
314
  if_generate(
315
  and_(self.is_leader_cta, is_tma_warp),
316
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
 
 
317
  )
318
 
319
  @dsl_user_op
@@ -321,4 +614,6 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
321
  """
322
  We need the mbarrier to track the completion of cp.async
323
  """
324
- cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
 
 
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
 
3
  from typing import Optional
4
  from dataclasses import dataclass
 
6
  import cutlass.cute as cute
7
  from cutlass import Boolean, Int32, const_expr
8
  from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
9
+ from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
10
+ from cutlass.pipeline import PipelineState, PipelineUserType
11
+ from cutlass.pipeline import Agent, agent_sync
12
+ from cutlass.pipeline import NamedBarrier as NamedBarrierOg
13
+ from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
14
+ from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg
15
+ from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
16
+ from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
17
+ from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
18
+ from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
19
+
20
+
21
+ # ── Shared helpers ───────────────────────────────────────────────────────────
22
+
23
+
24
+ def _override_create(parent_cls, child_cls):
25
+ """Create a static factory that constructs parent_cls then re-classes to child_cls."""
26
+
27
+ @staticmethod
28
+ def create(*args, **kwargs):
29
+ obj = parent_cls.create(*args, **kwargs)
30
+ # Can't assign to __class__ directly since the dataclass is frozen
31
+ object.__setattr__(obj, "__class__", child_cls)
32
+ return obj
33
+
34
+ return create
35
+
36
+
37
+ def _make_state(index: Int32, phase: Int32) -> PipelineState:
38
+ """Construct a PipelineState from index and phase (count/stages unused by callers)."""
39
+ return PipelineState(stages=0, count=Int32(0), index=index, phase=phase)
40
+
41
+
42
+ def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip):
43
+ """Optionally wrap a parent pipeline method call in sync_warp + elect_one."""
44
+ if const_expr(elect_one):
45
+ if const_expr(syncwarp):
46
+ cute.arch.sync_warp()
47
+ with cute.arch.elect_one():
48
+ parent_method(self, state, loc=loc, ip=ip)
49
+ else:
50
+ parent_method(self, state, loc=loc, ip=ip)
51
+
52
+
53
+ # ── Pipeline state ──────────────────────────────────────────────────────────
54
 
55
 
56
  class PipelineStateWAdvance(PipelineState):
 
75
  Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
76
  """
77
  if type is PipelineUserType.Producer:
78
+ return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1))
 
 
 
 
 
79
  elif type is PipelineUserType.Consumer:
80
+ return PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(0))
 
 
 
 
 
81
  else:
82
  assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
83
 
84
 
85
+ # ── Mixin: _w_index / _w_index_phase variants ───────────────────────────────
86
+
87
+
88
+ class _PipelineIndexPhaseMixin:
89
+ """Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents."""
90
+
91
+ @dsl_user_op
92
+ def producer_acquire_w_index_phase(
93
+ self,
94
+ index: Int32,
95
+ phase: Int32,
96
+ try_acquire_token: Optional[Boolean] = None,
97
+ *,
98
+ loc=None,
99
+ ip=None,
100
+ ):
101
+ state = _make_state(index, phase)
102
+ self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip)
103
+
104
+ @dsl_user_op
105
+ def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
106
+ state = _make_state(index, Int32(0))
107
+ self.producer_commit(state, loc=loc, ip=ip)
108
+
109
+ @dsl_user_op
110
+ def consumer_wait_w_index_phase(
111
+ self,
112
+ index: Int32,
113
+ phase: Int32,
114
+ try_wait_token: Optional[Boolean] = None,
115
+ *,
116
+ loc=None,
117
+ ip=None,
118
+ ):
119
+ state = _make_state(index, phase)
120
+ self.consumer_wait(state, try_wait_token, loc=loc, ip=ip)
121
+
122
+ @dsl_user_op
123
+ def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
124
+ state = _make_state(index, Int32(0))
125
+ self.consumer_release(state, loc=loc, ip=ip)
126
+
127
+
128
+ # ── NamedBarrier ─────────────────────────────────────────────────────────────
129
+
130
+
131
+ @dataclass(frozen=True)
132
+ class NamedBarrier(NamedBarrierOg):
133
+ create = _override_create(NamedBarrierOg, None) # patched below
134
+
135
+ @dsl_user_op
136
+ def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
137
+ """
138
+ The aligned flavor of arrive is used when all threads in the CTA will execute the
139
+ same instruction. See PTX documentation.
140
+ """
141
+ cute.arch.barrier_arrive(
142
+ barrier_id=self.barrier_id + index,
143
+ number_of_threads=self.num_threads,
144
+ loc=loc,
145
+ ip=ip,
146
+ )
147
+
148
+ @dsl_user_op
149
+ def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
150
+ cute.arch.barrier(
151
+ barrier_id=self.barrier_id + index,
152
+ number_of_threads=self.num_threads,
153
+ loc=loc,
154
+ ip=ip,
155
+ )
156
+
157
+
158
+ NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier)
159
+
160
+
161
+ # ── PipelineAsync ────────────────────────────────────────────────────────────
162
+
163
+
164
  @dataclass(frozen=True)
165
+ class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg):
166
  """
167
+ PipelineAsync with optional elect_one for producer_commit and consumer_release.
168
+
169
+ When elect_one_*=True (set at create time), only one elected thread per warp
170
+ signals the barrier arrive. This is useful when the mask count is set to 1 per warp.
171
+
172
+ Args (to create):
173
+ elect_one_commit: If True, only elected thread signals producer_commit.
174
+ syncwarp_before_commit: If True (default), issue syncwarp before elect_one.
175
+ elect_one_release: If True, only elected thread signals consumer_release.
176
+ syncwarp_before_release: If True (default), issue syncwarp before elect_one.
177
+ Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group).
178
  """
179
 
180
+ _elect_one_commit: bool = False
181
+ _syncwarp_before_commit: bool = True
182
+ _elect_one_release: bool = False
183
+ _syncwarp_before_release: bool = True
184
+
185
  @staticmethod
186
  def create(
187
+ *args,
188
+ elect_one_commit: bool = False,
189
+ syncwarp_before_commit: bool = True,
190
+ elect_one_release: bool = False,
191
+ syncwarp_before_release: bool = True,
192
+ **kwargs,
 
 
193
  ):
194
+ obj = PipelineAsyncOg.create(*args, **kwargs)
195
+ object.__setattr__(obj, "__class__", PipelineAsync)
196
+ object.__setattr__(obj, "_elect_one_commit", elect_one_commit)
197
+ object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit)
198
+ object.__setattr__(obj, "_elect_one_release", elect_one_release)
199
+ object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
200
+ return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ @dsl_user_op
203
+ def producer_commit(self, state: PipelineState, *, loc=None, ip=None):
204
+ _call_with_elect_one(
205
+ PipelineAsyncOg.producer_commit,
206
+ self,
207
+ state,
208
+ self._elect_one_commit,
209
+ self._syncwarp_before_commit,
210
+ loc,
211
+ ip,
212
+ )
213
 
214
+ @dsl_user_op
215
+ def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
216
+ _call_with_elect_one(
217
+ PipelineAsyncOg.consumer_release,
218
+ self,
219
+ state,
220
+ self._elect_one_release,
221
+ self._syncwarp_before_release,
222
+ loc,
223
+ ip,
224
+ )
225
+
226
+ # _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate
227
+ # to producer_commit / consumer_release above.
228
+
229
+
230
+ # ── PipelineCpAsync ──────────────────────────────────────────────────────────
231
+
232
+
233
+ @dataclass(frozen=True)
234
+ class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg):
235
+ _elect_one_release: bool = False
236
+ _syncwarp_before_release: bool = True
237
+
238
+ @staticmethod
239
+ def create(
240
+ *args,
241
+ elect_one_release: bool = False,
242
+ syncwarp_before_release: bool = True,
243
+ **kwargs,
244
+ ):
245
+ obj = PipelineCpAsyncOg.create(*args, **kwargs)
246
+ object.__setattr__(obj, "__class__", PipelineCpAsync)
247
+ object.__setattr__(obj, "_elect_one_release", elect_one_release)
248
+ object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
249
+ return obj
250
 
251
+ @dsl_user_op
252
+ def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
253
+ _call_with_elect_one(
254
+ PipelineCpAsyncOg.consumer_release,
255
+ self,
256
+ state,
257
+ self._elect_one_release,
258
+ self._syncwarp_before_release,
259
+ loc,
260
+ ip,
261
  )
262
+
263
+ # _w_index variants inherited from _PipelineIndexPhaseMixin.
264
+
265
+
266
+ # ── PipelineTmaAsync ────────────────────────────────────────────────────────
267
+
268
+
269
+ @dataclass(frozen=True)
270
+ class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
271
+ """Override producer_acquire to take in extra_tx_count parameter."""
272
+
273
+ @dsl_user_op
274
+ def producer_acquire(
275
+ self,
276
+ state: PipelineState,
277
+ try_acquire_token: Optional[Boolean] = None,
278
+ extra_tx_count: int = 0,
279
+ *,
280
+ loc=None,
281
+ ip=None,
282
+ ):
283
+ """
284
+ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
285
+ """
286
+ if_generate(
287
+ try_acquire_token is None or try_acquire_token == 0,
288
+ lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
289
+ loc=loc,
290
+ ip=ip,
291
  )
292
+ if const_expr(extra_tx_count == 0):
293
+ self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
 
 
 
 
 
 
 
 
294
  else:
295
+ tx_count = self.sync_object_full.tx_count + extra_tx_count
296
+ self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
297
 
 
298
 
299
+ PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync)
300
 
301
+
302
+ # ── PipelineTmaUmma ─────────────────────────────────────────────────────────
303
+
304
+
305
+ @dataclass(frozen=True)
306
+ class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg):
307
+ """Override producer_acquire to take in extra_tx_count parameter."""
 
308
 
309
  @dsl_user_op
310
  def producer_acquire(
 
312
  state: PipelineState,
313
  try_acquire_token: Optional[Boolean] = None,
314
  is_tma_warp: Optional[Boolean] = True,
315
+ extra_tx_count: int = 0,
316
  *,
317
  loc=None,
318
  ip=None,
319
  ):
320
  """
321
+ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
322
  """
323
  if_generate(
324
  try_acquire_token is None or try_acquire_token == 0,
325
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
326
+ loc=loc,
327
+ ip=ip,
328
+ )
329
+ # This is the difference between this and PipelineTmaAsync: we could have multiple
330
+ # warps calling this, but only 1 warp should do the arrive on the full barrier
331
+ if const_expr(extra_tx_count == 0):
332
+ if_generate(
333
+ and_(self.is_leader_cta, is_tma_warp),
334
+ lambda: self.sync_object_full.arrive(
335
+ state.index, self.producer_mask, loc=loc, ip=ip
336
+ ),
337
+ loc=loc,
338
+ ip=ip,
339
+ )
340
+ else:
341
+ tx_count = self.sync_object_full.tx_count + extra_tx_count
342
+ if_generate(
343
+ and_(self.is_leader_cta, is_tma_warp),
344
+ lambda: self.sync_object_full.arrive_and_expect_tx(
345
+ state.index, tx_count, loc=loc, ip=ip
346
+ ),
347
+ loc=loc,
348
+ ip=ip,
349
+ )
350
+
351
+
352
+ PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma)
353
+
354
+
355
+ # ── PipelineUmmaAsync ───────────────────────────────────────────────────────
356
+
357
+
358
+ @dataclass(frozen=True)
359
+ class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg):
360
+ pass
361
+
362
+
363
+ PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync)
364
+
365
+
366
+ # ── PipelineAsyncUmma ───────────────────────────────────────────────────────
367
+
368
+
369
+ @dataclass(frozen=True)
370
+ class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg):
371
+ pass
372
+
373
+
374
+ PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma)
375
+
376
+
377
+ # ── PipelineTmaCpAsync ──────────────────────────────────────────────────────
378
+
379
+
380
+ @dataclass(frozen=True)
381
+ class PipelineTmaCpAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
382
+ """
383
+ PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers.
384
+ Compared to PipelineTmaAsync, producer_acquire gates the full-barrier arrive on is_tma_warp.
385
+ """
386
+
387
+ @dsl_user_op
388
+ def producer_acquire(
389
+ self,
390
+ state: PipelineState,
391
+ try_acquire_token: Optional[Boolean] = None,
392
+ is_tma_warp: Optional[Boolean] = True,
393
+ *,
394
+ loc=None,
395
+ ip=None,
396
+ ):
397
+ if_generate(
398
+ try_acquire_token is None or try_acquire_token == 0,
399
+ lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
400
+ loc=loc,
401
+ ip=ip,
402
  )
403
  # This is the difference between this and PipelineTmaAsync: we could have multiple
404
  # warps calling this, but only 1 warp should do the arrive on the full barrier
405
  if_generate(
406
  is_tma_warp,
407
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
408
+ loc=loc,
409
+ ip=ip,
410
  )
411
 
412
  @dsl_user_op
413
  def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
414
+ """We need the mbarrier to track the completion of cp.async."""
415
+ cute.arch.cp_async_mbarrier_arrive_noinc(
416
+ self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
417
+ )
418
+
419
+
420
+ PipelineTmaCpAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaCpAsync)
421
+
422
+
423
+ # ── MbarrierArrayWDropCount ─────────────────────────────────────────────────
424
 
425
 
426
  class MbarrierArrayWDropCount(MbarrierArray):
 
468
  )
469
 
470
 
471
+ # ── PipelineTmaCpAsyncUmma ──────────────────────────────────────────────────
472
+
473
+
474
  @dataclass(frozen=True)
475
+ class PipelineTmaCpAsyncUmma(PipelineTmaUmmaOg):
476
  """
477
  PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers
478
  (e.g. Blackwell mainloops)
479
  """
480
 
481
+ @dsl_user_op
482
  @staticmethod
483
  def create(
484
  *,
 
488
  tx_count: int,
489
  barrier_storage: cute.Pointer = None,
490
  cta_layout_vmnk: Optional[cute.Layout] = None,
 
491
  mcast_mode_mn: tuple[int, int] = (1, 1),
492
+ defer_sync: bool = False,
493
+ producer_drop_count: Optional[Int32] = None,
494
+ loc=None,
495
+ ip=None,
496
  ):
497
+ """Creates and initializes a new PipelineTmaUmma instance.
498
+
 
 
499
  :param num_stages: Number of buffer stages for this pipeline
500
+ :type num_stages: int
501
+ :param producer_group: CooperativeGroup for the producer agent
502
  :type producer_group: CooperativeGroup
503
+ :param consumer_group: CooperativeGroup for the consumer agent
504
  :type consumer_group: CooperativeGroup
505
  :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
506
  :type tx_count: int
507
+ :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
508
+ :type barrier_storage: cute.Pointer, optional
509
  :param cta_layout_vmnk: Layout of the cluster shape
510
+ :type cta_layout_vmnk: cute.Layout, optional
511
  :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
512
  :type mcast_mode_mn: tuple[int, int], optional
513
+ :raises ValueError: If barrier_storage is not a cute.Pointer instance
514
+ :return: A new PipelineTmaUmma instance configured with the provided parameters
515
+ :rtype: PipelineTmaUmma
516
  """
517
  if not isinstance(barrier_storage, cute.Pointer):
518
+ raise TypeError(
519
  f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
520
  )
521
 
 
531
  producer,
532
  tx_count,
533
  drop_count=producer_drop_count,
534
+ loc=loc,
535
+ ip=ip,
536
  )
537
+ sync_object_empty = PipelineTmaUmmaOg._make_sync_object(
538
+ barrier_storage.align(min_align=8) + num_stages,
539
+ num_stages,
540
+ consumer,
541
+ loc=loc,
542
+ ip=ip,
543
  )
544
 
545
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
546
  # No mcast mask if not using clusters
547
  producer_mask = None
548
  # All threadblocks are leaders if not using clusters
549
  is_leader_cta = True
550
  else:
551
+ producer_mask = PipelineTmaUmmaOg._compute_mcast_arrival_mask(
552
+ cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
553
+ )
554
+ is_leader_cta = PipelineTmaUmmaOg._compute_is_leader_cta(
555
+ cta_layout_vmnk, loc=loc, ip=ip
556
+ )
557
 
558
  cta_group = (
559
  cute.nvgpu.tcgen05.CtaGroup.ONE
560
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
561
  else cute.nvgpu.tcgen05.CtaGroup.TWO
562
  )
563
 
564
  consumer_mask = producer_mask
565
 
566
+ if not defer_sync:
567
+ cute.arch.mbarrier_init_fence()
568
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
569
+ agent_sync(Agent.ThreadBlock)
570
+ else:
571
+ agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
572
 
573
  return PipelineTmaCpAsyncUmma(
574
  sync_object_full,
 
597
  if_generate(
598
  try_acquire_token is None or try_acquire_token == 0,
599
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
600
+ loc=loc,
601
+ ip=ip,
602
  )
603
  # This is the difference between this and PipelineTmaAsync: we could have multiple
604
  # warps calling this, but only 1 warp should do the arrive on the full barrier
605
  if_generate(
606
  and_(self.is_leader_cta, is_tma_warp),
607
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
608
+ loc=loc,
609
+ ip=ip,
610
  )
611
 
612
  @dsl_user_op
 
614
  """
615
  We need the mbarrier to track the completion of cp.async
616
  """
617
+ cute.arch.cp_async_mbarrier_arrive_noinc(
618
+ self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
619
+ )
build/torch-cuda/quack/reduce.py CHANGED
@@ -196,9 +196,9 @@ def online_softmax_reduce(
196
  )
197
  cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
198
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
199
- max_x_single_warp = cute.make_fragment(num_iter, Float32)
200
  max_x_single_warp.fill(-Float32.inf)
201
- sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
202
  sum_exp_x_single_warp.fill(0.0)
203
  for i in cutlass.range_constexpr(num_iter):
204
  idx = lane_idx + i * cute.arch.WARP_SIZE
 
196
  )
197
  cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
198
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
199
+ max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
200
  max_x_single_warp.fill(-Float32.inf)
201
+ sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
202
  sum_exp_x_single_warp.fill(0.0)
203
  for i in cutlass.range_constexpr(num_iter):
204
  idx = lane_idx + i * cute.arch.WARP_SIZE
build/torch-cuda/quack/rms_final_reduce.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025-2026, Tri Dao.
2
+ # Given a 2D array of partial squared sums, compute rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps).
3
+ # This is the second kernel in a gemm_rms fused pipeline where the first GEMM kernel
4
+ # writes per-tile partial sums of squares.
5
+
6
+ import math
7
+ from typing import Type
8
+
9
+ import cuda.bindings.driver as cuda
10
+
11
+ import cutlass
12
+ import cutlass.cute as cute
13
+ from cutlass import Float32, const_expr
14
+
15
+ import torch
16
+ from ._ops_compat import add_quack_op_namespace_prefix
17
+ from torch import Tensor
18
+
19
+ from . import copy_utils as copy_utils
20
+ from .compile_utils import make_fake_tensor as fake_tensor
21
+ from .reduce import row_reduce
22
+ from .reduction_base import ReductionBase
23
+ from .cache_utils import jit_cache
24
+ from .cute_dsl_utils import torch2cute_dtype_map
25
+
26
+
27
+ class RmsFinalReduce(ReductionBase):
28
+ """Reduce partial squared sums and compute rstd: rstd[m] = rsqrt(sum_n(x[m,n]) * scale + eps).
29
+
30
+ Inherits from ReductionBase for tiled copy, reduction buffer, and cluster support.
31
+ """
32
+
33
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
34
+ super().__init__(dtype, N, stage=1)
35
+
36
+ def _threads_per_row(self):
37
+ N = self.N
38
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
39
+ if N <= limit:
40
+ return threads
41
+ return 256
42
+
43
+ def _set_cluster_n(self):
44
+ self.cluster_n = 1
45
+
46
+ @cute.jit
47
+ def __call__(
48
+ self,
49
+ mX: cute.Tensor,
50
+ mRstd: cute.Tensor,
51
+ scale: Float32,
52
+ eps: Float32,
53
+ stream: cuda.CUstream,
54
+ ):
55
+ assert mX.element_type == self.dtype
56
+ self._set_cluster_n()
57
+ vecsize = math.gcd(self.N, 128 // self.dtype.width)
58
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
59
+ num_threads = tiled_copy.size
60
+ self.kernel(mX, mRstd, scale, eps, tiler_mn, tiled_copy, threads_per_row).launch(
61
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
62
+ block=[num_threads, 1, 1],
63
+ stream=stream,
64
+ )
65
+
66
+ @cute.kernel
67
+ def kernel(
68
+ self,
69
+ mX: cute.Tensor,
70
+ mRstd: cute.Tensor,
71
+ scale: Float32,
72
+ eps: Float32,
73
+ tiler_mn: cute.Shape,
74
+ tiled_copy: cute.TiledCopy,
75
+ threads_per_row: cutlass.Constexpr[int],
76
+ ):
77
+ tidx, _, _ = cute.arch.thread_idx()
78
+ bidx, _, _ = cute.arch.block_idx()
79
+ tv_layout = tiled_copy.layout_tv_tiled
80
+
81
+ smem = cutlass.utils.SmemAllocator()
82
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
83
+
84
+ shape = mX.shape
85
+ idX = cute.make_identity_tensor(shape)
86
+ gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
87
+ cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
88
+
89
+ thr_copy = tiled_copy.get_slice(tidx)
90
+ tXgX = thr_copy.partition_S(gX)
91
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
92
+
93
+ tXrX = cute.make_rmem_tensor_like(tXgX)
94
+ cute.filter_zeros(tXrX).fill(0)
95
+
96
+ is_even_N = const_expr(shape[1] == tiler_mn[1])
97
+ tXpX = (
98
+ copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
99
+ if not is_even_N
100
+ else None
101
+ )
102
+
103
+ row = tXcX[0][0]
104
+ if row < shape[0]:
105
+ copy_utils.copy(tXgX, tXrX, pred=tXpX)
106
+ x = tXrX.load().to(Float32)
107
+
108
+ sum_x = row_reduce(
109
+ x,
110
+ cute.ReductionOp.ADD,
111
+ threads_per_row,
112
+ reduction_buffer[None, None, 0],
113
+ mbar_ptr,
114
+ init_val=0.0,
115
+ )
116
+ rstd = cute.math.rsqrt(sum_x * scale + eps, fastmath=True)
117
+ if tXcX[0][1] == 0 and row < shape[0]:
118
+ mRstd[row] = rstd
119
+
120
+
121
+ @jit_cache
122
+ def _compile_rms_final_reduce(dtype, N):
123
+ batch_sym = cute.sym_int()
124
+ div = math.gcd(N, 128 // dtype.width)
125
+ x_cute = fake_tensor(dtype, (batch_sym, N), div)
126
+ rstd_cute = fake_tensor(Float32, (batch_sym,))
127
+ return cute.compile(
128
+ RmsFinalReduce(dtype, N),
129
+ x_cute,
130
+ rstd_cute,
131
+ Float32(0), # scale
132
+ Float32(0), # eps
133
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
134
+ options="--enable-tvm-ffi",
135
+ )
136
+
137
+
138
+ @torch.library.custom_op(
139
+ add_quack_op_namespace_prefix("rms_final_reduce_out"),
140
+ mutates_args=("rstd",),
141
+ device_types="cuda",
142
+ )
143
+ def _rms_final_reduce_out(
144
+ x: Tensor,
145
+ rstd: Tensor,
146
+ scale: float,
147
+ eps: float,
148
+ ) -> None:
149
+ """Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps)."""
150
+ x_dtype = torch2cute_dtype_map[x.dtype]
151
+ N = x.shape[1]
152
+ compiled_fn = _compile_rms_final_reduce(x_dtype, N)
153
+ compiled_fn(x, rstd, scale, eps)
154
+
155
+
156
+ @_rms_final_reduce_out.register_fake
157
+ def _rms_final_reduce_out_fake(x, rstd, scale, eps):
158
+ from .cache_utils import COMPILE_ONLY
159
+
160
+ if COMPILE_ONLY and not isinstance(x.shape[0], torch.SymInt):
161
+ x_dtype = torch2cute_dtype_map[x.dtype]
162
+ _compile_rms_final_reduce(x_dtype, x.shape[1])
163
+
164
+
165
+ def rms_final_reduce(
166
+ x: Tensor, # (M, N) partial squared sums
167
+ scale: float, # typically 1.0 / total_columns
168
+ eps: float = 1e-6,
169
+ ) -> Tensor:
170
+ """Compute rstd[m] = rsqrt(sum_n(x[m, n]) * scale + eps)."""
171
+ assert x.ndim == 2
172
+ M = x.shape[0]
173
+ rstd = torch.empty(M, dtype=torch.float32, device=x.device)
174
+
175
+ from .cache_utils import COMPILE_ONLY
176
+
177
+ if COMPILE_ONLY:
178
+ return rstd
179
+
180
+ _rms_final_reduce_out(x, rstd, scale, eps)
181
+ return rstd