danieldk HF Staff committed on
Commit
4298e26
·
verified ·
1 Parent(s): 828edc5

Build uploaded using `kernels`.

Browse files
Files changed (44) hide show
  1. build/torch-cuda/__init__.py +24 -0
  2. build/torch-cuda/_ops.py +8 -0
  3. build/torch-cuda/ampere_helpers.py +103 -0
  4. build/torch-cuda/barrier.py +71 -0
  5. build/torch-cuda/benchmark.py +268 -0
  6. build/torch-cuda/blackwell_helpers.py +1089 -0
  7. build/torch-cuda/block_info.py +108 -0
  8. build/torch-cuda/block_sparse_utils.py +1476 -0
  9. build/torch-cuda/block_sparsity.py +440 -0
  10. build/torch-cuda/cache_utils.py +307 -0
  11. build/torch-cuda/compute_block_sparsity.py +378 -0
  12. build/torch-cuda/copy_utils.py +372 -0
  13. build/torch-cuda/cute_dsl_ptxas.py +151 -0
  14. build/torch-cuda/cute_dsl_utils.py +167 -0
  15. build/torch-cuda/fast_math.py +21 -0
  16. build/torch-cuda/flash_attn4/__init__.py +26 -0
  17. build/torch-cuda/flash_bwd.py +1264 -0
  18. build/torch-cuda/flash_bwd_postprocess.py +585 -0
  19. build/torch-cuda/flash_bwd_preprocess.py +361 -0
  20. build/torch-cuda/flash_bwd_sm100.py +0 -0
  21. build/torch-cuda/flash_bwd_sm90.py +1591 -0
  22. build/torch-cuda/flash_fwd.py +0 -0
  23. build/torch-cuda/flash_fwd_combine.py +692 -0
  24. build/torch-cuda/flash_fwd_sm100.py +0 -0
  25. build/torch-cuda/interface.py +1855 -0
  26. build/torch-cuda/mask.py +653 -0
  27. build/torch-cuda/metadata.json +8 -0
  28. build/torch-cuda/mma_sm100_desc.py +296 -0
  29. build/torch-cuda/named_barrier.py +32 -0
  30. build/torch-cuda/pack_gqa.py +165 -0
  31. build/torch-cuda/paged_kv.py +214 -0
  32. build/torch-cuda/pipeline.py +440 -0
  33. build/torch-cuda/quack/__init__.py +0 -0
  34. build/torch-cuda/quack/activation.py +568 -0
  35. build/torch-cuda/quack/compile_utils.py +19 -0
  36. build/torch-cuda/quack/copy_utils.py +1007 -0
  37. build/torch-cuda/quack/cute_dsl_utils.py +165 -0
  38. build/torch-cuda/quack/layout_utils.py +297 -0
  39. build/torch-cuda/quack/sm90_utils.py +161 -0
  40. build/torch-cuda/seqlen_info.py +138 -0
  41. build/torch-cuda/softmax.py +592 -0
  42. build/torch-cuda/testing.py +456 -0
  43. build/torch-cuda/tile_scheduler.py +727 -0
  44. build/torch-cuda/utils.py +698 -0
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flash Attention CUTE (CUDA Template Engine) implementation."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ # Update when syncing again.
6
+ __version__ = "4.0.0.beta4"
7
+
8
+ import cutlass.cute as cute
9
+
10
+ from .interface import (
11
+ flash_attn_func,
12
+ flash_attn_varlen_func,
13
+ )
14
+
15
+ from .cute_dsl_utils import cute_compile_patched
16
+
17
+ # Patch cute.compile to optionally dump SASS
18
+ cute.compile = cute_compile_patched
19
+
20
+
21
+ __all__ = [
22
+ "flash_attn_func",
23
+ "flash_attn_varlen_func",
24
+ ]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import torch

# Handle to the torch custom-op namespace registered for this specific build
# (the hash suffix keeps different builds from colliding).
ops = torch.ops._flash_attn4_c07a63b


def add_op_namespace_prefix(op_name: str) -> str:
    """
    Prefix op by namespace.

    Returns the fully qualified op name "_flash_attn4_c07a63b::<op_name>"
    used when registering/looking up torch custom ops for this build.
    """
    return f"_flash_attn4_c07a63b::{op_name}"
build/torch-cuda/ampere_helpers.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Type, Callable, Optional
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+
8
def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build a swizzled shared-memory layout atom for a tile with K extent ``k_dim``.

    The K block size (in elements) is the largest of 128/64/32/16 bytes that
    divides the row size, and the swizzle parameters are derived from that
    block size and the element width.
    """
    dtype_byte = cutlass.const_expr(dtype.width // 8)
    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
    # Pick the largest byte granularity (128/64/32/16) dividing the row, then
    # convert back to an element count.
    smem_k_block_size = (
        cutlass.const_expr(
            128
            if bytes_per_row % 128 == 0
            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
        )
        // dtype_byte
    )
    # Wider K blocks get more swizzle bits: 128 -> 4, 64 -> 3, 32 -> 2, else 1.
    swizzle_bits = (
        4
        if smem_k_block_size == 128
        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
    )
    # Swizzle base depends on the element width (4B -> 2, 2B -> 3, else 4).
    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
    return cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
        0,
        # K-major ordered layout; 8 rows when k_dim is a multiple of 32, else 16.
        cute.make_ordered_layout(
            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
        ),
    )
32
+
33
+
34
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Accumulate ``acc += A @ B`` over the K blocks of the given tiles.

    Operands are staged from shared memory into registers one K block ahead of
    the MMA that consumes them (software pipelining), unless already in
    registers (``A_in_regs`` / ``B_in_regs``). ``hook_fn``, if given, is
    invoked once after the first K-block MMA is issued.
    """
    if cutlass.const_expr(swap_AB):
        # Re-enter with the A/B operands (and their copy atoms / in-regs flags)
        # exchanged; swap_AB=False terminates the recursion.
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        # Prologue: stage K block 0 before entering the main loop.
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            # Prefetch K block k+1 while the MMA below consumes block k.
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()
84
+
85
+
86
@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    """Register-A x smem-B gemm: A is already in registers, B is staged from
    shared memory one K block ahead of the consuming MMA.

    ``hook_fn``, if given, is invoked once after the first K-block MMA.
    """
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    # Prologue: stage B's K block 0.
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Prefetch B's next K block while the MMA consumes block k.
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()
build/torch-cuda/barrier.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cutlass
2
+ import cutlass.cute as cute
3
+ from cutlass import Int32
4
+ from cutlass.cutlass_dsl import T, dsl_user_op
5
+ from cutlass._mlir.dialects import llvm
6
+
7
+
8
@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Acquire-load a 32-bit value from global memory at ``lock_ptr``.

    Emits ``ld.global.acquire.gpu.b32``, i.e. a load with acquire ordering at
    GPU scope, and returns the loaded value as Int32.
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",  # "=r": 32-bit result register; "l": 64-bit address operand
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)
21
+
22
+
23
@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to the s32 at ``lock_ptr`` with relaxed ordering.

    Emits ``red.relaxed.gpu.global.add.s32`` (a reduction op: no value is
    returned to the thread).
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.relaxed.gpu.global.add.s32 [$0], $1;",
        "l,r",  # "l": 64-bit address; "r": 32-bit addend
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
37
+
38
+
39
@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to the s32 at ``lock_ptr`` with release ordering.

    Emits ``red.release.gpu.global.add.s32``; the release ordering pairs with
    ``ld_acquire`` on the reader side.
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.release.gpu.global.add.s32 [$0], $1;",
        "l,r",  # "l": 64-bit address; "r": 32-bit addend
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
53
+
54
+
55
@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """Spin (thread 0 only) until the flag at ``lock_ptr + flag_offset`` equals ``val``.

    Uses acquire loads so that writes made before the matching release
    increment are visible after the wait completes. Threads other than
    thread 0 return immediately; no barrier is issued here.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        while read_val != val:
            read_val = ld_acquire(flag_ptr)
62
+
63
+
64
@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """Signal arrival: thread 0 adds ``val`` to the flag at ``lock_ptr + flag_offset``.

    The release-ordered reduction pairs with ``wait_eq``'s acquire loads.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
        # red_relaxed(flag_ptr, val)
build/torch-cuda/benchmark.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+ """Useful functions for writing test code."""
3
+
4
+ import torch
5
+ import torch.utils.benchmark as benchmark
6
+
7
+
8
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function.

    Returns the (Timer, Measurement) pair from torch.utils.benchmark.
    """
    if verbose:
        print(desc, "- Forward pass")

    def _wrapped(*call_args, **call_kwargs):
        # Run under autocast so AMP timings can be measured when requested;
        # with amp=False the context is a no-op.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*call_args, **call_kwargs)

    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": _wrapped, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
28
+
29
+
30
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function.

    Runs one forward pass (outside the timer) to get an output, then times
    repeated ``backward`` calls against it. Returns (Timer, Measurement).
    """
    if verbose:
        print(desc, "- Backward pass")
    # Single forward pass to obtain a graph output to backprop through.
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        out = fn(*inputs, **kwinputs)
        if type(out) is tuple:
            out = out[0]
    if grad is None:
        grad = torch.randn_like(out)
    elif grad.shape != out.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def _run_backward(*fn_inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for tensor in fn_inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": _run_backward, "inputs": inputs, "y": out, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
70
+
71
+
72
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Each timed iteration re-runs the forward pass and then backprops through
    it. Returns (Timer, Measurement).
    """
    if verbose:
        print(desc, "- Forward + Backward pass")
    # One untimed forward pass to derive the gradient shape.
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        out = fn(*inputs, **kwinputs)
        if type(out) is tuple:
            out = out[0]
    if grad is None:
        grad = torch.randn_like(out)
    elif grad.shape != out.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def _fwd_bwd_step(grad, *fn_inputs, **fn_kwinputs):
        # Clear stale .grad so accumulation cost is not measured.
        for tensor in fn_inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*fn_inputs, **fn_kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": _fwd_bwd_step, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeeit(repeats) if False else timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
115
+
116
+
117
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Returns a pair: the forward (Timer, Measurement) and the backward one.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    fwd = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    return fwd, bwd
152
+
153
+
154
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Returns three (Timer, Measurement) pairs: forward only, backward only,
    and combined forward+backward.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    fwd = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    bwd = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    both = benchmark_combined(fn, *inputs, grad=grad, **shared, **kwinputs)
    return fwd, bwd, both
200
+
201
+
202
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Warms up with 30 untimed iterations, then runs one profiled
    forward (and optionally backward) pass. Prints the aggregated table when
    ``verbose`` and exports a Chrome trace when ``trace_filename`` is given.
    """
    if backward:
        # One forward pass just to size the gradient tensor for backward.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            # Clear stale grads so accumulation doesn't pollute the profile.
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    # Always profile CUDA; CPU activity only on request.
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
256
+
257
+
258
+ def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
259
+ torch.cuda.empty_cache()
260
+ torch.cuda.reset_peak_memory_stats()
261
+ torch.cuda.synchronize()
262
+ fn(*inputs, **kwinputs)
263
+ torch.cuda.synchronize()
264
+ mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
265
+ if verbose:
266
+ print(f"{desc} max memory: {mem}GB")
267
+ torch.cuda.empty_cache()
268
+ return mem
build/torch-cuda/blackwell_helpers.py ADDED
@@ -0,0 +1,1089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Optional, Tuple
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass import Int32, Boolean, const_expr
7
+ from cutlass.cute.nvgpu import tcgen05
8
+ from cutlass._mlir.dialects import llvm
9
+
10
+ from . import mma_sm100_desc as sm100_desc
11
+
12
+
13
@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
    num_unroll_groups: int = 1,
) -> None:
    """Accumulate ``acc += A @ B`` over K blocks, optionally selecting a stage
    of A/B via ``A_idx`` / ``B_idx`` (4th tensor mode).

    ``zero_init`` clears the accumulator on the first K block;
    ``num_unroll_groups`` splits the K loop unrolling into that many groups.
    """
    if const_expr(swap_AB):
        # Fix: forward num_unroll_groups — previously dropped here, silently
        # resetting the caller's unroll grouping to the default of 1.
        return gemm_w_idx(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            B_idx,
            A_idx,
            zero_init=zero_init,
            swap_AB=False,
            num_unroll_groups=num_unroll_groups,
        )
    else:
        # Select the current pipeline stage when an index is provided.
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]

        mma_atom = cute.make_mma_atom(tiled_mma.op)
        for k in cutlass.range(
            cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
        ):
            # ACCUMULATE is off only for the very first K block when zero_init.
            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)
39
+
40
+
41
@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
    **kwargs,
) -> None:
    """Stage-indexed wrapper over ``gemm_ptx_partial``.

    Selects the current pipeline stage of the register/smem operands (via the
    4th tensor mode when ``A_idx``/``B_idx`` are given) and dispatches to the
    hand-written PTX MMA path. Extra ``kwargs`` are forwarded unchanged.
    """
    rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
    rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
    sA_cur = None
    if const_expr(sA is not None):
        sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
    sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    # TMEM address of the accumulator, passed as a raw integer to the PTX path.
    acc_tmem_addr = acc.iterator.toint()
    gemm_ptx_partial(
        mma_atom.op,
        acc_tmem_addr,
        rA,
        rB,
        sA_cur,
        sB_cur,
        zero_init=zero_init,
        cta_group=cta_group,
        **kwargs,
    )
+
75
+
76
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Accumulate ``acc += A @ B`` over all K blocks with a tcgen05 MMA atom.

    When ``zero_init`` is set, ACCUMULATE is disabled for the first K block so
    the accumulator is overwritten rather than added to.
    """
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
        cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
88
+
89
+
90
def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split a 64-bit integer into its (low, high) 32-bit halves."""
    mask = (1 << 32) - 1
    low = i & mask
    high = (i >> 32) & mask
    return low, high
93
+
94
+
95
@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue one tcgen05 MMA per K block via hand-written inline PTX.

    A may come from TMEM ("TS" form) or from shared memory via an SMEM
    descriptor; B always comes from shared memory. One tcgen05.mma instruction
    is emitted per K block by a single elected thread. ``zero_init`` makes the
    first K block overwrite the accumulator (ACCUMULATE predicate false).
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else None
    sB_layout = sB.layout
    # Instruction descriptor encoding the MMA shape/dtype, fixed at compile time.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        # Build the compile-time part of A's SMEM descriptor (swizzle + layout);
        # only the start address varies at run time.
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    # B's SMEM descriptor base, analogous to A's.
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Fold the runtime SMEM start address into the low word of each descriptor.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Per-K-block descriptor: add the byte offset of K block k (>> 4 since
        # descriptor addresses are in 16-byte units).
        if const_expr(not is_ts):
            smem_desc_a_lo = smem_desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            )
        smem_desc_b_lo = smem_desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        )
        # with cute.arch.elect_one():
        # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
        # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
        with cute.arch.elect_one():
            if const_expr(not is_ts):
                # SS form: both operands via 64-bit SMEM descriptors.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        smem_desc_a_lo.ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                # TS form: A addressed directly in TMEM, B via SMEM descriptor.
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
208
+
209
+
210
@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    """Issue tcgen05 MMAs for all K blocks from a single inline-PTX blob.

    Unlike ``gemm_ptx`` (one inline_asm per K block), this emits one asm blob
    that unrolls the whole K loop, advancing the SMEM descriptors by
    compile-time per-block offset deltas inside the assembly.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        # Compile-time part of A's SMEM descriptor (swizzle + layout).
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # Per-K-block operand offsets, computed at compile time. SMEM offsets are
    # in 16-byte descriptor units (>> 4); TMEM offsets in 32-bit words.
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    # Fold the runtime SMEM start address into the low descriptor word.
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # First MMA's ACCUMULATE predicate: runtime "p" when zero_init is dynamic,
    # else the literal "0"/"1"; all later K blocks always accumulate ("1").
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # SS form: both operands via SMEM descriptors, advanced inside the asm.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # TS form: A addressed in TMEM (per-block immediate offset), B via
        # SMEM descriptor advanced inside the asm.
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
+ )
371
+
372
+
373
@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    split_arrive: Optional[int] = None,
    zero_init: bool | Boolean = False,
    # sA_offset: Int32 = 0,
    # acc_offset: Int32 = 0,
    tA_addr: Optional[Int32] = None,
    cta_group: int = 1,
) -> None:
    """Issue a K-tiled sequence of ``tcgen05.mma`` instructions via raw inline PTX.

    Two code paths, selected by ``op.a_src``:

    - SS (A in SMEM): both A and B are addressed through 64-bit SMEM matrix
      descriptors whose low words are bumped by precomputed per-K-tile offsets.
    - TS (A in TMEM): A is addressed directly by TMEM address (``[tmem_a + off]``)
      and only B uses an SMEM descriptor.  Optionally, the TS path can busy-wait
      on an mbarrier in the middle of the K loop (after ``split_arrive`` elements
      of K have been issued) before issuing the remaining MMAs.

    Only a single elected leader thread (``elect.sync``) actually issues the MMA
    instructions; the whole warp executes the inline-asm block.

    Args:
        op: The tcgen05 MMA op; supplies dtypes, major modes and A's source.
        acc_tmem_addr: TMEM address of the accumulator.
        tCrA, tCrB: MMA-partitioned views of A/B; shape[2] is the K-tile count.
        sA: SMEM tensor for A (required unless A comes from TMEM).
        sB: SMEM tensor for B.
        mbar_ptr / mbar_phase / split_arrive: optional mid-loop mbarrier wait
            (TS path only); ``split_arrive`` is in units of K elements and is
            converted to a K-tile index via ``op.shape_mnk[2]``.
        zero_init: if a runtime ``Boolean``, the accumulate flag of the first MMA
            comes from a predicate register; if a Python bool it is baked in.
        tA_addr: explicit TMEM address for A (see workaround note below).
        cta_group: CTA group for the MMA instruction (1 or 2).
    """
    # acc_tmem_addr += acc_offset
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    # Instruction descriptor is fully determined by the op -> compile-time constant.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        # Base (address-free) SMEM descriptor for A; the start address is OR'd in later.
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    # For TMEM-sourced A, offsets are computed in 32-bit TMEM words.
    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    # Per-K-tile linear offsets for A/B and their successive differences
    # (all compile-time constants that get baked into the PTX text).
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        # Full descriptor low word = constant base OR runtime start address bits.
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
        # ) + sA_offset
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    # First-MMA accumulate flag: predicate register if runtime, else literal 0/1.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # --- SS path: A and B both via SMEM descriptors; no mbarrier support here.
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            # $0 = A desc lo, $1 = B desc lo, $2 = accumulate flag, $3 = acc TMEM addr
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # Offsets are added to the *start* lo word (not chained) so the
                    # adds are independent and can be scheduled freely.
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # --- TS path: A read directly from TMEM, B via SMEM descriptor.
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        input_args = [
            # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            assert split_arrive is not None, (
                "split_arrive must be provided when mbar_ptr is not None"
            )
            # Convert K-element count to a K-tile index at which to insert the wait.
            split_arrive_idx = split_arrive // op.shape_mnk[2]
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            # Busy-wait loop on mbarrier parity ($4 = mbar addr, $5 = phase).
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            # $0 = A TMEM addr, $1 = B desc lo, $2 = accumulate flag, $3 = acc TMEM addr
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                # Issue up to split_arrive_idx K-tiles before the (optional) mbar wait.
                for k in range(
                    1,
                    cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx,
                )
            )
            + mbar_wait_str
            + (
                # Remaining K-tiles after the wait; B lo word advances by deltas here.
                "".join(
                    (
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(split_arrive_idx, cute.size(tCrA.shape[2]))
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
591
+
592
+
593
@cute.jit
def gemm_ptx_partial1(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: cutlass.Constexpr[int],
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA_base_addr_for_desc: Int32,
    sA_addr_offset_for_desc: cutlass.Constexpr[int],
    sA_stage: Int32,
    sB_base_addr_for_desc: Int32,
    sB_addr_offset_for_desc: cutlass.Constexpr[int],
    sB_stage: Int32,
    sA_layout: Optional[cute.Layout],
    sB_layout: Optional[cute.Layout],
    sA_swizzle: Optional[cute.Swizzle],
    sB_swizzle: cute.Swizzle,
    zero_init: bool | Boolean = False,
) -> None:
    """Variant of :func:`gemm_ptx_partial` with masked ``tcgen05.mma`` and
    stage-indexed SMEM descriptor addresses.

    Instead of a runtime descriptor low word, the SMEM start addresses are
    computed inside the asm as ``base + stage * addr_offset`` (``mad.lo.u32``),
    so the caller can pass a pipeline stage index directly.  The MMA form used
    here carries a 4-register mask operand (all zeros here).

    Args mirror gemm_ptx_partial; additionally:
        sA_base_addr_for_desc / sB_base_addr_for_desc: descriptor-encoded base
            addresses (runtime).
        sA_addr_offset_for_desc / sB_addr_offset_for_desc: per-stage address
            stride in descriptor units (compile-time).
        sA_stage / sB_stage: pipeline stage indices.
    """
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
        assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
    # Instruction descriptor is fully determined by the op -> compile-time constant.
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)
    # Mask operand of the masked MMA form; all-zero here (no masking).
    mask = [Int32(0)] * 4

    # Per-K-tile offsets in descriptor units (16B granules, hence >> 4) for SMEM,
    # or 32-bit TMEM words for TMEM-sourced A.
    # NOTE(review): the TS branch also indexes sA_layout, which is Optional and
    # only asserted non-None for the SS path — confirm callers always pass it.
    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
            for k in range(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in range(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
        for k in range(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        # Start lo word is just the constant base; the runtime address part is
        # folded in by the mad.lo.u32 inside the asm.
        # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))
        smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)
    else:
        smem_desc_start_a_lo = None
    # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))
    smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)
    # First-MMA accumulate flag: predicate register if runtime, else literal 0/1.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # --- SS path: descriptor lo words built as base + stage * stride + runtime base.
        llvm.inline_asm(
            None,
            [
                # $0 = A base, $1 = A stage, $2 = B base, $3 = B stage,
                # $4 = accumulate flag, $5-$8 = mask registers
                Int32(sA_base_addr_for_desc).ir_value(),
                Int32(sA_stage).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            # lo = stage * per-stage stride + base (address part of the descriptor).
            f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $4, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # --- TS path: A read directly from TMEM, B via SMEM descriptor.
        # BUGFIX: the accumulator used to be expected as operand $0 of the asm
        # template, but it was dropped from the input list without renumbering the
        # remaining operands: the template referenced a nonexistent $7, the
        # constraint string declared 8 operands for 7 inputs, and tmem_a /
        # smem_desc_b_lo / the predicate were read from the wrong registers.
        # Since acc_tmem_addr is a compile-time constant here, embed it as an
        # immediate (as the SS branch above does) and renumber $0-$6.
        llvm.inline_asm(
            None,
            [
                # $0 = A TMEM addr, $1 = B desc lo, $2 = accumulate flag,
                # $3-$6 = mask registers
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            "mov.b32 tmem_a, $0;\n\t"
            "mov.b32 smem_desc_b_lo, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {{$3, $4, $5, $6}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {{$3, $4, $5, $6}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
770
+
771
+
772
@cute.jit
def gemm_ptx_precomputed(
    acc_tmem_addr: Int32,
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_start_b: Int32,
    idesc: int,
    smem_desc_base_a: Optional[int],
    smem_desc_base_b: int,
    tCrA_layout: cute.Layout,
    tCrB_layout: cute.Layout,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    """Issue a K-tiled ``tcgen05.mma`` sequence from precomputed descriptor parts.

    Like :func:`gemm_ptx_partial`, but the caller supplies the instruction
    descriptor and the base/start words of the SMEM descriptors directly,
    instead of tensors/ops — so no layout inspection of A/B tensors happens
    here.  A is TMEM-sourced ("TS") iff ``smem_desc_base_a is None``.

    In the TS+mbarrier case the wait is inserted at a fixed 3/4 point of the
    K loop (``num_k_tile // 4 * 3``) rather than at a caller-chosen split.

    Args:
        acc_tmem_addr: TMEM address of the accumulator.
        smem_desc_start_a: address bits for A's descriptor lo word (or A's TMEM
            start address in the TS case — see inline parameter comment).
        smem_desc_start_b: address bits for B's descriptor lo word.
        idesc: precomputed MMA instruction descriptor.
        smem_desc_base_a / smem_desc_base_b: address-free descriptor bases;
            ``smem_desc_base_a is None`` selects the TS path.
        tCrA_layout / tCrB_layout: partitioned layouts; shape[2] is the K-tile
            count (assumed equal for A and B).
        mbar_ptr / mbar_phase: optional mid-loop mbarrier wait (TS path only).
        zero_init: if a runtime ``Boolean``, first-MMA accumulate flag comes
            from a predicate register; otherwise baked in as a literal.
        cta_group: CTA group for the MMA instruction (1 or 2).
    """
    # acc_tmem_addr += acc_offset
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    else:
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)

    # For TMEM-sourced A, offsets are computed in 32-bit TMEM words.
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    # Per-K-tile offsets and successive deltas, baked into the PTX text.
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)]

    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        # Full descriptor lo word = constant base OR runtime start-address bits.
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
        # smem_desc_start_a_lo = smem_desc_start_a
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # First-MMA accumulate flag: predicate register if runtime, else literal 0/1.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        # --- SS path: A and B both via SMEM descriptors; no mbarrier support here.
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            # $0 = A desc lo, $1 = B desc lo, $2 = accumulate flag, $3 = acc TMEM addr
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # Offsets are added to the *start* lo word so the adds are
                    # independent rather than a serial chain.
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # --- TS path: A read directly from TMEM, B via SMEM descriptor.
        input_args = [
            # $0 = A TMEM addr, $1 = B desc lo, $2 = accumulate flag, $3 = acc TMEM addr
            Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            # Busy-wait loop on mbarrier parity ($4 = mbar addr, $5 = phase).
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                # With an mbarrier, issue the first 3/4 of the K loop before the wait.
                for k in range(
                    1,
                    num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3,
                )
            )
            + mbar_wait_str
            + (
                # Remaining quarter of the K loop, issued after the wait.
                "".join(
                    (
                        # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(num_k_tile // 4 * 3, num_k_tile)
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
949
+
950
+
951
@cute.jit
def declare_ptx_smem_desc(
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_base_a: Optional[int],
    tCrA_layout: cute.Layout,
    var_name_prefix: str = "smem_desc",
) -> None:
    """Pre-declare and initialize one 64-bit SMEM descriptor PTX register per K-tile.

    Emits ``.reg .b64 {prefix}_<num_k_tile>`` plus the moves/adds that build
    ``{prefix}_0 .. {prefix}_{num_k_tile-1}`` from the start lo word and the
    constant per-K-tile offsets.  A later inline-asm block that uses the same
    register names (e.g. :func:`gemm_ptx_precomputed_varname`) can then consume
    them without recomputing.

    NOTE(review): when A is TMEM-sourced (``smem_desc_base_a is None``) this
    function emits nothing at all — confirm that TS callers never rely on the
    registers being declared.

    Args:
        smem_desc_start_a: runtime address bits of the descriptor lo word.
        smem_desc_base_a: constant address-free descriptor base; ``None``
            selects the (currently no-op) TS path.
        tCrA_layout: partitioned layout; shape[2] is the K-tile count.
        var_name_prefix: PTX register name prefix to declare.
    """
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    smem_desc_base_a_lo, smem_desc_a_hi = None, None
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    # For TMEM-sourced A, offsets would be in 32-bit TMEM words.
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        # Full descriptor lo word = constant base OR runtime start-address bits.
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],
            # $0 = descriptor lo word for K-tile 0; each further register pairs
            # ($0 + offset_a[k]) with the constant hi word.
            f".reg .b32 {var_name_prefix}_lo;\n\t"
            f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t"
            f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t"
            + "".join(
                (
                    f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t"
                    f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t"
                )
                for k in range(1, num_k_tile)
            ),
            "r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
993
+
994
+
995
@cute.jit
def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None:
    """Declare a PTX ``.b32`` register named ``var_name`` and preload it with the
    MMA instruction descriptor derived from *op*.

    The register lives outside any scoping braces, so subsequent inline-asm
    blocks that refer to the same name (e.g. ``gemm_ptx_precomputed_varname``)
    can reuse it without re-materializing the constant.
    """
    # The descriptor is a pure function of the op -> a compile-time constant.
    instr_desc = const_expr(sm100_desc.mma_op_to_idesc(op))
    decl_and_init = (
        f".reg .b32 {var_name};\n\t"  # noqa
        f"mov.b32 {var_name}, {hex(instr_desc)};\n\t"
    )
    llvm.inline_asm(
        None,
        [],
        decl_and_init,
        constraints="",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
1008
+
1009
+
1010
@cute.jit
def gemm_ptx_precomputed_varname(
    acc_tmem_addr: Int32,
    smem_desc_start_b: Int32,
    # idesc: int,
    smem_desc_base_b: int,
    tCrB_layout: cute.Layout,
    smem_var_name_prefix: str,
    idesc_var_name: str,
    smem_offset: int,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    """Issue a K-tiled ``tcgen05.mma`` sequence reusing pre-declared PTX registers.

    Consumes, by name, PTX registers created by earlier inline-asm blocks:

    - ``{smem_var_name_prefix}_k``: one 64-bit A descriptor per K-tile (from
      :func:`declare_ptx_smem_desc`); each is advanced in place by
      ``smem_offset`` (decimal immediate) before use, so repeated calls walk A
      through pipeline stages.
    - ``{idesc_var_name}``: the instruction descriptor register (from
      :func:`declare_ptx_idesc`).

    B descriptors are built locally per K-tile into ``smem_desc_b_<num_k_tile>``
    from the runtime start lo word plus constant offsets.  All descriptor setup
    is hoisted ahead of the MMA sequence so the MMAs issue back-to-back.

    Only the SS path exists here: ``is_ts`` is hard-coded to ``False``, so A is
    always SMEM-descriptor addressed (no TMEM-A variant).

    Args:
        acc_tmem_addr: TMEM address of the accumulator.
        smem_desc_start_b: runtime address bits of B's descriptor lo word.
        smem_desc_base_b: constant address-free descriptor base for B.
        tCrB_layout: partitioned B layout; shape[2] is the K-tile count.
        smem_var_name_prefix / idesc_var_name: names of the pre-declared regs.
        smem_offset: per-call increment applied to each A descriptor lo word.
        zero_init: if a runtime ``Boolean``, first-MMA accumulate flag comes
            from a predicate register; otherwise baked in as a literal.
        cta_group: CTA group for the MMA instruction (1 or 2).
    """
    is_ts = False
    num_k_tile = cute.size(tCrB_layout.shape[2])
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    # Per-K-tile offsets for B, baked into the PTX text.
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]

    # Full descriptor lo word = constant base OR runtime start-address bits.
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    # First-MMA accumulate flag: predicate register if runtime, else literal 0/1.
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                # $0 = B desc lo, $1 = accumulate flag, $2 = acc TMEM addr
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            # ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            # ".reg .b64 smem_desc_b;\n\t"
            f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            # f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $2;\n\t"
            "mov.b32 smem_desc_b_lo_start, $0;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            # Advance the pre-declared A descriptor for K-tile 0 in place by
            # smem_offset (unpack lo/hi, add, repack).
            f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t"
            f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
            f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            + "".join(
                (
                    # Same in-place A advance for K-tile k, plus building B's
                    # per-K-tile descriptor — all before any MMA issues.
                    f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "setp.ne.b32 p, $1, 0;\n\t"
            # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t"
            + "".join(
                (
                    # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
build/torch-cuda/block_info.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ from typing import Tuple, Optional
3
+ from dataclasses import dataclass
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ from cutlass import Int32, const_expr
8
+
9
+ from .seqlen_info import SeqlenInfoQK
10
+
11
+
12
@dataclass(frozen=True)
class BlockInfo:
    """Compile-time description of the attention tiling.

    Given a query tile index (m_block) or key/value tile index (n_block),
    the methods below compute the half-open range of opposite-side tiles
    that can contain unmasked elements under causal / local (sliding-window)
    masking, optionally partitioned for split-KV.

    All ``Constexpr`` fields are baked into the generated kernel; the window
    sizes are runtime ``Int32`` values (or ``None`` when that side is unused).
    """

    # Tile extents along the query (M) and key/value (N) dimensions.
    tile_m: cutlass.Constexpr[int]
    tile_n: cutlass.Constexpr[int]
    # Compile-time masking-mode flags.
    is_causal: cutlass.Constexpr[bool]
    is_local: cutlass.Constexpr[bool] = False
    # Whether the KV block range is further partitioned across splits.
    is_split_kv: cutlass.Constexpr[bool] = False
    # Sliding-window extents for local attention; None disables that side.
    window_size_left: Optional[Int32] = None
    window_size_right: Optional[Int32] = None
    # Pack-GQA factor: number of query heads packed into one KV head's M dim.
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1

    @cute.jit
    def get_n_block_min_max(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        split_idx: cutlass.Int32 = 0,
        num_splits: cutlass.Int32 = 1,
    ) -> Tuple[Int32, Int32]:
        """Return ``(n_block_min, n_block_max)`` — the half-open range of KV
        tiles that query tile ``m_block`` can attend to, after applying
        causal / right-window and left-window bounds and (optionally) the
        split-KV partition for ``split_idx``."""
        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                # m indices are in packed (head-interleaved) space; convert back
                # to actual query-row indices.
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            # Align the query rows with the key columns: bottom-right diagonal.
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right
            n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))
        n_block_min = 0
        if const_expr(self.is_local and self.window_size_left is not None):
            m_idx_min = m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
            n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_left = n_idx - self.window_size_left
            n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
        if cutlass.const_expr(self.is_split_kv):
            # Split [n_block_min, n_block_max) evenly (ceil) across num_splits
            # and keep only this split's slice.
            num_n_blocks_per_split = (
                cutlass.Int32(0)
                if n_block_max <= n_block_min
                else (n_block_max - n_block_min + num_splits - 1) // num_splits
            )
            n_block_min = n_block_min + split_idx * num_n_blocks_per_split
            n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
        return n_block_min, n_block_max

    @cute.jit
    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
        """Return ``(m_block_min, m_block_max)`` — the half-open range of Q
        tiles whose rows can attend to KV tile ``n_block`` (the transpose of
        ``get_n_block_min_max``)."""
        m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
        m_block_min = 0
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            n_idx_min = n_block * self.tile_n
            m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right
            m_block_min = max(m_block_min, m_idx_right // self.tile_m)
        if const_expr(self.is_local and self.window_size_left is not None):
            n_idx_max = (n_block + 1) * self.tile_n
            m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            m_idx_left = m_idx + self.window_size_left
            m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
        return m_block_min, m_block_max

    @cute.jit
    def get_n_block_min_causal_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with causal or local masking at the start, where do we stop.

        Returns the first n_block (counting down from n_block_max) that is
        fully below the diagonal / right window, i.e. needs no masking."""
        m_idx_min = m_block * self.tile_m
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        n_idx_right = (
            n_idx
            if const_expr(not self.is_local or self.window_size_right is None)
            else n_idx + self.window_size_right
        )
        return cutlass.max(n_block_min, n_idx_right // self.tile_n)

    @cute.jit
    def get_n_block_min_before_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations.

        Only relevant for local attention with a left window; otherwise the
        incoming n_block_min is returned unchanged."""
        if const_expr(not self.is_local or self.window_size_left is None):
            return n_block_min
        else:
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_left = n_idx - self.window_size_left
            return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))
build/torch-cuda/block_sparse_utils.py ADDED
@@ -0,0 +1,1476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Block-sparse runtime utilities for CUTE DSL kernels.
3
+
4
+ This module contains runtime execution functions for block-sparse attention kernels.
5
+ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
6
+ """
7
+
8
+ from typing import Callable, Optional
9
+ from functools import partial
10
+ import math
11
+ import cutlass
12
+ import cutlass.cute as cute
13
+ from cutlass import Float32, Int32, const_expr
14
+
15
+ from .quack import copy_utils
16
+
17
+ # Import data structures from block_sparsity
18
+ from .block_sparsity import BlockSparseTensors
19
+ from .named_barrier import NamedBarrierBwd
20
+
21
+
22
+ # NOTE [SM100 block-sparse empty tiles: mbarrier contract]
23
+ #
24
+ # For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active
25
+ # KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so
26
+ # the softmax warp-group has no row stats to publish.
27
+ #
28
+ # The correction warp-group seeds fully-masked-row stats and runs the usual correction
29
+ # epilogue so output/LSE have well-defined values. Both warp-groups must still perform
30
+ # the softmax<->correction mbarrier handshake so phases advance correctly across
31
+ # empty->empty and empty->non-empty tile sequences.
32
+ #
33
+ # In the no-sink case, this corresponds to the usual fully-masked-row convention:
34
+ # output is zero and LSE is -inf.
35
+ #
36
+ # Barrier contract (each is `mbar_ptr + <offset> + stage`):
37
+ #
38
+ # Producer/consumer pairs:
39
+ # - `mbar_softmax_corr_full` : softmax arrive -> correction wait
40
+ # - `mbar_softmax_corr_empty` : correction arrive -> softmax wait
41
+ # - `mbar_P_full_O_rescaled` : softmax arrive (+ correction arrive) -> MMA wait
42
+ # - `mbar_P_full_2` : softmax arrive -> MMA wait
43
+ # - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate)
44
+ #
45
+ # Empty tile (`total_block_cnt == 0`):
46
+ # - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`).
47
+ # It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal.
48
+ # At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty`
49
+ # before each tile (when block-sparse) to drain a prior correction arrival and keep
50
+ # phases aligned across non-empty -> empty transitions.
51
+ # - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`,
52
+ # and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable).
53
+ # - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction
54
+ # (and correction<->epilogue) handshakes advance phases.
55
+ #
56
+ # Non-empty tile:
57
+ # - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to
58
+ # publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases;
59
+ # arrives `mbar_P_full_*` when P is stored.
60
+ # - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty`
61
+ # to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed.
62
+ #
63
+ # Backward (SM100):
64
+ # - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute.
65
+ # - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles
66
+ # skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward).
67
+ # - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros
68
+ # even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`).
69
+
70
+
71
@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Iterate over the sparse blocks and load K, V (and Q) into the pipeline.
    for the intra_wg_overlap case, we overlap the loads of K and V. And this
    means we need to pipeline the last V load from the partial block case,
    with the loads for the full blocks. Set first_block_preloaded when the
    caller has already issued the first K load for the list.

    Note:
        we iterate along the block_n indices in reverse.

    Returns:
        Updated kv_producer_state after processing the block list.

    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Peel first iteration: the first block may need to load Q alongside K,
            # Parameters are already Constexpr, so no need to wrap in const_expr()
            n_block_first = block_indices[block_count - 1]
            # Fold Q's TMA byte count into the K stage's expected transaction
            # count so a single barrier tracks both the Q and K copies.
            extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
            pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

            if const_expr(load_q_with_first and use_tma_q):
                # Q completion is signaled on the same barrier as the first K.
                load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

            load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            pipeline_v.producer_acquire(kv_producer_state)
            load_V(src_idx=n_block_first, producer_state=kv_producer_state)
            kv_producer_state.advance()

            # Remaining blocks: plain K-then-V per pipeline stage.
            for offset in cutlass.range(1, block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            # Overlapped schedule: K for block i+1 is issued before V for
            # block i; the list's final V load is intentionally left pending
            # for the caller (see finish_overlap_v_load).
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                extra_tx = (
                    tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
                )
                pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)

                if const_expr(load_q_with_first and use_tma_q):
                    load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))

                load_K(src_idx=n_block_first, producer_state=kv_producer_state)

            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                # Keep the previous stage alive for V while advancing to the
                # next stage for K.
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)

    return kv_producer_state
147
+
148
+
149
@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Drain the pending V copy left behind by an overlapped K/V block list.

    The overlapped schedule in ``load_block_list`` defers the final V load;
    this issues it for the list's last block (index 0, since the list is
    walked in reverse) and steps the producer state past that stage.

    Returns:
        Updated kv_producer_state.
    """
    if block_count > 0:
        tail_n_block = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=tail_n_block, producer_state=kv_producer_state)
        kv_producer_state.advance()

    return kv_producer_state
165
+
166
+
167
@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
    q_subtile_factor: cutlass.Constexpr[int],
):
    """Translate a (possibly pack-GQA'd / subtiled) m_block index into the
    index used by the block-sparse count/index tensors.

    Both divisions are compile-time guarded, so the common 1x/1x case
    compiles to a no-op.
    """
    sparse_block = m_block
    if const_expr(qhead_per_kvhead != 1):
        # Undo pack-GQA head interleaving along M.
        sparse_block = sparse_block // qhead_per_kvhead
    if const_expr(q_subtile_factor != 1):
        # Multiple compute subtiles share one sparse-metadata row.
        sparse_block = sparse_block // q_subtile_factor
    return sparse_block
180
+
181
+
182
@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Iterate over the mask and full block lists for a single tile.

    The masked (partial) list may leave the last V load pending when intra-warp-group
    overlap is enabled. The first full block must consume that pending V while
    issuing its own K load on the next pipeline stage.

    In the intra-wg-overlap path, the last masked block leaves its V copy in flight
    while we advance the producer state to start the next full K. Either the full list
    overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    # Per-(batch, head, m_block) counters and reverse-ordered block lists.
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No separate full-block table: treat everything as masked blocks.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    if mask_empty:
        # No masked blocks: the full list owns the initial Q+K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        # Overlap leaves the last V load pending; drain it here.
        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present: load Q together with the first masked K so consumers can
        # start immediately. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )

        if full_empty:
            # No full list to absorb the pending masked V: drain it explicitly.
            if const_expr(intra_wg_overlap):
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)

                # First full K already issued above, so first_block_preloaded=True.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally (skipping the Q
                # reload because the masked list already issued it).
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )

    return kv_producer_state
342
+
343
+
344
@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
    the same sparse tensor indexing. Blocks are consumed in reverse list
    order, matching the producer. Masked blocks apply `mask_mod`; full
    blocks run with `mask_mod=None` since they are known to be unmasked.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.

    Returns:
        (kv_consumer_state, O_should_accumulate, processed_any)
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # Whether this tile had any work at all (caller may skip the epilogue path).
    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

    if const_expr(not intra_wg_overlap):
        # Non-overlap schedule: bracket all MMA work for this tile between
        # warp_scheduler_barrier_sync() and warp_scheduler_barrier_arrive().
        if curr_mask_block_cnt > 0:
            # First masked block also applies the seqlen bound.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            # No full list follows, so close the scheduler bracket here.
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # Full list opens the tile itself.
                warp_scheduler_barrier_sync()
                # NOTE(review): unlike the branch below, mask_mod is left at
                # mask_fn's default here rather than forced to None — confirm
                # this is intentional for the first full block.
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                # Masked list already ran: full blocks skip mask_mod entirely.
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                    is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        # Overlapped schedule: the first block is split into half-block
        # processing, and the tile is closed with process_last_half_block.
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True

    return kv_consumer_state, O_should_accumulate, processed_any
526
+
527
+
528
@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).

    K and V share one pipeline (pipeline_kv) and each occupies its own stage,
    so the producer state advances twice per block. Blocks are walked in
    reverse, like the SM90 variant.

    Returns:
        Updated kv_producer_state after processing the block list.
    """
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]

        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=1, stage=1)

        # SM100 doesn't use producer_acquire for pipeline_kv in load path
        # The pipeline barriers are handled inside load_KV
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()

        # Remaining blocks
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()

    return kv_producer_state
567
+
568
+
569
# SM100-specific tile processor using SM100 helpers
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """SM100 entry point for sparse block iteration.

    SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use
    simplified block processing that just calls producer_acquire without extras.

    Q is loaded with the first non-empty list only; q_producer_phase is
    flipped exactly when a Q load was actually issued (i.e. the tile had at
    least one block), so an entirely empty tile leaves the phase untouched.

    Args:
        m_block: which tile of m we are processing
        qhead_per_kvhead: Constexpr pack factor

    Returns:
        (kv_producer_state, q_producer_phase)
    """
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No separate full-block table: only the masked list is used.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    q_phase_flipped = False

    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        # Q was only issued if the full list was non-empty.
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True

        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )

    if q_phase_flipped:
        q_producer_phase ^= 1

    return kv_producer_state, q_producer_phase
661
+
662
+
663
+ @cute.jit
664
+ def get_total_block_count(
665
+ blocksparse_tensors: BlockSparseTensors,
666
+ batch_idx,
667
+ head_idx,
668
+ m_block,
669
+ qhead_per_kvhead: cutlass.Constexpr,
670
+ q_subtile_factor: cutlass.Constexpr,
671
+ ):
672
+ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
673
+
674
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
675
+ if const_expr(full_block_cnt is not None):
676
+ return (
677
+ mask_block_cnt[batch_idx, head_idx, m_block_sparse]
678
+ + full_block_cnt[batch_idx, head_idx, m_block_sparse]
679
+ )
680
+ else:
681
+ return mask_block_cnt[batch_idx, head_idx, m_block_sparse]
682
+
683
+
684
+ @cute.jit
685
+ def handle_block_sparse_empty_tile_correction_sm100(
686
+ tidx: Int32,
687
+ q_stage: cutlass.Constexpr,
688
+ m_block_size: cutlass.Constexpr,
689
+ qhead_per_kvhead,
690
+ pack_gqa: cutlass.Constexpr,
691
+ is_split_kv: cutlass.Constexpr,
692
+ learnable_sink,
693
+ mLSE,
694
+ seqlen,
695
+ m_block: Int32,
696
+ head_idx: Int32,
697
+ batch_idx: Int32,
698
+ split_idx: Int32,
699
+ sScale: cute.Tensor,
700
+ stats: list,
701
+ correction_epilogue: Callable,
702
+ thr_mma_pv: cute.core.ThrMma,
703
+ tOtO: cute.Tensor,
704
+ sO: cute.Tensor,
705
+ pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
706
+ sm_stats_barrier: cutlass.pipeline.NamedBarrier,
707
+ pipeline_o_epi: cutlass.pipeline.PipelineAsync,
708
+ sm_stats_consumer_phase: Int32,
709
+ o_corr_consumer_phase: Int32,
710
+ corr_epi_producer_phase: Int32,
711
+ softmax_scale_log2: Float32,
712
+ mO_cur: Optional[cute.Tensor] = None,
713
+ gO: Optional[cute.Tensor] = None,
714
+ gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
715
+ ):
716
+ """Handle SM100 forward block-sparse tiles with no active KV blocks.
717
+
718
+ This path is taken when `total_block_cnt == 0`. The softmax warp-group still
719
+ arrives `mbar_softmax_corr_full` (synthetic "no work") so the correction
720
+ warp-group can:
721
+
722
+ - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE
723
+ - run `correction_epilogue` with `scale=0` so the output tile is written as zeros
724
+ (independent of any prior tmem contents)
725
+ - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty`
726
+ (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles
727
+
728
+ This helper intentionally does not touch `mbar_P_full_*` since no P is produced.
729
+ See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
730
+ """
731
+ LOG2_E = Float32(math.log2(math.e))
732
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
733
+
734
+ for stage in cutlass.range_constexpr(q_stage):
735
+ row_sum_value = Float32(1.0)
736
+ row_max_value = (
737
+ -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
738
+ )
739
+ if const_expr(learnable_sink is not None):
740
+ sink_val = -Float32.inf
741
+ if const_expr(not pack_gqa):
742
+ sink_val = Float32(learnable_sink[head_idx])
743
+ elif tidx < m_block_size:
744
+ q_head_idx = (
745
+ (q_stage * m_block + stage) * m_block_size + tidx
746
+ ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
747
+ sink_val = Float32(learnable_sink[q_head_idx])
748
+ if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
749
+ if row_max_value == -Float32.inf:
750
+ row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
751
+ row_sum_value = Float32(1.0)
752
+ else:
753
+ row_sum_value = row_sum_value + cute.math.exp2(
754
+ sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
755
+ )
756
+ if tidx < m_block_size:
757
+ scale_row_idx = tidx + stage * m_block_size
758
+ sScale[scale_row_idx] = row_sum_value
759
+ if const_expr(mLSE is not None or learnable_sink is not None):
760
+ sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
761
+ acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
762
+ stats[stage] = (row_sum_value, row_max_value, acc_flag)
763
+
764
+ # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
765
+ # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)
766
+ sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)
767
+ pipeline_sm_stats.consumer_release_w_index(stage)
768
+
769
+ if const_expr(gmem_tiled_copy_O is None):
770
+ pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
771
+ correction_epilogue(
772
+ thr_mma_pv,
773
+ tOtO[None, None, None, stage],
774
+ tidx,
775
+ stage,
776
+ m_block,
777
+ seqlen.seqlen_q,
778
+ Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs
779
+ sO[None, None, stage],
780
+ mO_cur,
781
+ gO[None, None, stage],
782
+ gmem_tiled_copy_O,
783
+ )
784
+ if const_expr(gmem_tiled_copy_O is None):
785
+ pipeline_o_epi.producer_commit_w_index(stage)
786
+
787
+ sm_stats_consumer_phase ^= 1
788
+ corr_epi_producer_phase ^= 1
789
+
790
+ return (
791
+ sm_stats_consumer_phase,
792
+ o_corr_consumer_phase,
793
+ corr_epi_producer_phase,
794
+ )
795
+
796
+
797
+ @cute.jit
798
+ def softmax_block_sparse_sm100(
799
+ blocksparse_tensors: BlockSparseTensors,
800
+ batch_idx,
801
+ head_idx,
802
+ m_block,
803
+ softmax_step: Callable,
804
+ mask_fn: Callable,
805
+ mask_fn_none: Callable,
806
+ mma_si_consumer_phase: Int32,
807
+ si_corr_producer_phase: Int32,
808
+ s0_s1_sequence_phase: Int32,
809
+ pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
810
+ sm_stats_barrier: cutlass.pipeline.NamedBarrier,
811
+ q_stage: cutlass.Constexpr,
812
+ stage_idx: Int32,
813
+ check_m_boundary: bool,
814
+ qhead_per_kvhead: cutlass.Constexpr,
815
+ q_subtile_factor: cutlass.Constexpr[int] = 1,
816
+ ):
817
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
818
+ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
819
+
820
+ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
821
+
822
+ curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
823
+ curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
824
+
825
+ if const_expr(full_block_cnt is not None):
826
+ curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
827
+ curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
828
+ else:
829
+ curr_full_block_cnt = Int32(0)
830
+ curr_full_block_idx = None
831
+
832
+ total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt
833
+
834
+ if total_block_cnt == 0:
835
+ # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
836
+ # pipeline_sm_stats.producer_commit_w_index(stage_idx)
837
+ sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
838
+ else:
839
+ if curr_mask_block_cnt > 0:
840
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
841
+ (
842
+ mma_si_consumer_phase,
843
+ si_corr_producer_phase,
844
+ s0_s1_sequence_phase,
845
+ ) = softmax_step(
846
+ mma_si_consumer_phase,
847
+ si_corr_producer_phase,
848
+ s0_s1_sequence_phase,
849
+ mask_n_block,
850
+ is_first=True,
851
+ mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
852
+ )
853
+ for i in cutlass.range(1, curr_mask_block_cnt):
854
+ mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
855
+ (
856
+ mma_si_consumer_phase,
857
+ si_corr_producer_phase,
858
+ s0_s1_sequence_phase,
859
+ ) = softmax_step(
860
+ mma_si_consumer_phase,
861
+ si_corr_producer_phase,
862
+ s0_s1_sequence_phase,
863
+ mask_n_block,
864
+ mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
865
+ )
866
+
867
+ if curr_full_block_cnt > 0:
868
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
869
+ if curr_mask_block_cnt == 0:
870
+ (
871
+ mma_si_consumer_phase,
872
+ si_corr_producer_phase,
873
+ s0_s1_sequence_phase,
874
+ ) = softmax_step(
875
+ mma_si_consumer_phase,
876
+ si_corr_producer_phase,
877
+ s0_s1_sequence_phase,
878
+ full_n_block,
879
+ is_first=True,
880
+ mask_fn=partial(
881
+ mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
882
+ ),
883
+ )
884
+ else:
885
+ (
886
+ mma_si_consumer_phase,
887
+ si_corr_producer_phase,
888
+ s0_s1_sequence_phase,
889
+ ) = softmax_step(
890
+ mma_si_consumer_phase,
891
+ si_corr_producer_phase,
892
+ s0_s1_sequence_phase,
893
+ full_n_block,
894
+ is_first=False,
895
+ mask_fn=partial(
896
+ mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
897
+ ),
898
+ )
899
+ for i in cutlass.range(1, curr_full_block_cnt):
900
+ full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
901
+ (
902
+ mma_si_consumer_phase,
903
+ si_corr_producer_phase,
904
+ s0_s1_sequence_phase,
905
+ ) = softmax_step(
906
+ mma_si_consumer_phase,
907
+ si_corr_producer_phase,
908
+ s0_s1_sequence_phase,
909
+ full_n_block,
910
+ mask_fn=partial(
911
+ mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
912
+ ),
913
+ )
914
+
915
+ return (
916
+ mma_si_consumer_phase,
917
+ si_corr_producer_phase,
918
+ s0_s1_sequence_phase,
919
+ total_block_cnt == 0,
920
+ )
921
+
922
+
923
+ # =============================================================================
924
+ # Backward-specific block-sparse helpers (SM100)
925
+ # =============================================================================
926
+ #
927
+ # In backward, iteration is transposed compared to forward:
928
+ # - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
929
+ # - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
930
+ #
931
+ # The backward block-sparse tensors use "Q direction" indexing:
932
+ # - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
933
+ # - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
934
+ #
935
+
936
+
937
+ @cute.jit
938
+ def get_total_q_block_count_bwd(
939
+ blocksparse_tensors: BlockSparseTensors,
940
+ batch_idx,
941
+ head_idx,
942
+ n_block,
943
+ subtile_factor: cutlass.Constexpr = 1,
944
+ m_block_max: int = 0,
945
+ ):
946
+ """Count total tile iterations for given n_block (KV tile) in backward."""
947
+ q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
948
+ total = q_block_cnt[batch_idx, head_idx, n_block]
949
+ if const_expr(full_block_cnt is not None):
950
+ total = total + full_block_cnt[batch_idx, head_idx, n_block]
951
+ return total * subtile_factor
952
+
953
+
954
+ @cute.jit
955
+ def produce_block_sparse_q_loads_bwd_sm100(
956
+ blocksparse_tensors: BlockSparseTensors,
957
+ batch_idx,
958
+ head_idx,
959
+ n_block,
960
+ # Pipeline states (will be returned after advancing)
961
+ producer_state_Q_LSE,
962
+ producer_state_dO_dPsum,
963
+ # Pipelines
964
+ pipeline_Q,
965
+ pipeline_LSE,
966
+ pipeline_dO,
967
+ pipeline_dPsum,
968
+ # Load functions
969
+ load_K,
970
+ load_V,
971
+ load_Q,
972
+ load_dO,
973
+ copy_stats,
974
+ # Global tensors for LSE/dPsum
975
+ gLSE,
976
+ sLSE,
977
+ gdPsum,
978
+ sdPsum,
979
+ # TMA copy bytes for extra_tx_count
980
+ tma_copy_bytes_K,
981
+ tma_copy_bytes_V,
982
+ # Flags for which loads to perform
983
+ should_load_Q: cutlass.Constexpr,
984
+ should_load_dO: cutlass.Constexpr,
985
+ # Subtiling factor and bounds
986
+ subtile_factor: cutlass.Constexpr = 1,
987
+ m_block_max: int = 0,
988
+ ):
989
+ """SM100 backward block sparse loading with subtiling.
990
+
991
+ Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
992
+ First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
993
+ """
994
+ (
995
+ curr_q_cnt,
996
+ curr_q_idx,
997
+ curr_full_cnt,
998
+ curr_full_idx,
999
+ loop_count,
1000
+ ) = get_block_sparse_iteration_info_bwd(
1001
+ blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
1002
+ )
1003
+
1004
+ for iter_idx in cutlass.range(loop_count, unroll=1):
1005
+ m_block, _ = get_m_block_from_iter_bwd(
1006
+ iter_idx,
1007
+ curr_q_cnt,
1008
+ curr_q_idx,
1009
+ curr_full_cnt,
1010
+ curr_full_idx,
1011
+ subtile_factor,
1012
+ m_block_max,
1013
+ )
1014
+ m_block_safe = m_block
1015
+ if m_block_max > 0:
1016
+ m_block_safe = cutlass.min(m_block, m_block_max - 1)
1017
+
1018
+ if iter_idx == 0:
1019
+ # First block: load K/V alongside Q/dO
1020
+ if const_expr(should_load_Q):
1021
+ pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
1022
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
1023
+ load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
1024
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
1025
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
1026
+ with cute.arch.elect_one():
1027
+ copy_stats(
1028
+ gLSE[None, m_block_safe],
1029
+ sLSE[None, producer_state_Q_LSE.index],
1030
+ mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
1031
+ )
1032
+ producer_state_Q_LSE.advance()
1033
+ if const_expr(should_load_dO):
1034
+ pipeline_dO.producer_acquire(
1035
+ producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
1036
+ )
1037
+ load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
1038
+ load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
1039
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1040
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1041
+ with cute.arch.elect_one():
1042
+ copy_stats(
1043
+ gdPsum[None, m_block_safe],
1044
+ sdPsum[None, producer_state_dO_dPsum.index],
1045
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
1046
+ )
1047
+ producer_state_dO_dPsum.advance()
1048
+ else:
1049
+ # Subsequent blocks: just load Q/dO (K/V already loaded)
1050
+ if const_expr(should_load_Q):
1051
+ pipeline_Q.producer_acquire(producer_state_Q_LSE)
1052
+ load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
1053
+ pipeline_Q.producer_commit(producer_state_Q_LSE)
1054
+ pipeline_LSE.producer_acquire(producer_state_Q_LSE)
1055
+ with cute.arch.elect_one():
1056
+ copy_stats(
1057
+ gLSE[None, m_block_safe],
1058
+ sLSE[None, producer_state_Q_LSE.index],
1059
+ mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
1060
+ )
1061
+ producer_state_Q_LSE.advance()
1062
+ if const_expr(should_load_dO):
1063
+ pipeline_dO.producer_acquire(producer_state_dO_dPsum)
1064
+ load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
1065
+ pipeline_dO.producer_commit(producer_state_dO_dPsum)
1066
+ pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
1067
+ with cute.arch.elect_one():
1068
+ copy_stats(
1069
+ gdPsum[None, m_block_safe],
1070
+ sdPsum[None, producer_state_dO_dPsum.index],
1071
+ mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
1072
+ )
1073
+ producer_state_dO_dPsum.advance()
1074
+
1075
+ return producer_state_Q_LSE, producer_state_dO_dPsum
1076
+
1077
+
1078
+ @cute.jit
1079
+ def get_block_sparse_iteration_info_bwd(
1080
+ blocksparse_tensors: BlockSparseTensors,
1081
+ batch_idx,
1082
+ head_idx,
1083
+ n_block,
1084
+ subtile_factor: cutlass.Constexpr = 1,
1085
+ m_block_max: int = 0,
1086
+ ):
1087
+ """Extract block-sparse iteration info for backward pass.
1088
+
1089
+ Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
1090
+ """
1091
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1092
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1093
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1094
+
1095
+ if const_expr(full_cnt is not None):
1096
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1097
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1098
+ else:
1099
+ curr_full_cnt = Int32(0)
1100
+ curr_full_idx = None
1101
+
1102
+ sparse_block_count = curr_q_cnt
1103
+ if const_expr(full_cnt is not None):
1104
+ sparse_block_count = sparse_block_count + curr_full_cnt
1105
+ total_count = sparse_block_count * subtile_factor
1106
+
1107
+ return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count
1108
+
1109
+
1110
+ @cute.jit
1111
+ def get_m_block_from_iter_bwd(
1112
+ iter_idx,
1113
+ curr_q_cnt,
1114
+ curr_q_idx: cute.Tensor,
1115
+ curr_full_cnt,
1116
+ curr_full_idx: Optional[cute.Tensor],
1117
+ subtile_factor: cutlass.Constexpr = 1,
1118
+ m_block_max: int = 0,
1119
+ ):
1120
+ """Derive m_block index and is_full_block flag from iteration index.
1121
+
1122
+ Returns (m_block, is_full_block):
1123
+ - m_block: The actual Q-tile block index
1124
+ - is_full_block: True if this is a full block (no mask_mod needed)
1125
+ """
1126
+ sparse_iter_idx = iter_idx // subtile_factor
1127
+ subtile_offset = iter_idx % subtile_factor
1128
+
1129
+ sparse_m_block = Int32(0)
1130
+ is_full_block = False
1131
+ if const_expr(curr_full_idx is not None):
1132
+ if sparse_iter_idx < curr_q_cnt:
1133
+ sparse_m_block = curr_q_idx[sparse_iter_idx]
1134
+ else:
1135
+ sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
1136
+ is_full_block = True
1137
+ else:
1138
+ sparse_m_block = curr_q_idx[sparse_iter_idx]
1139
+
1140
+ return sparse_m_block * subtile_factor + subtile_offset, is_full_block
1141
+
1142
+
1143
+ @cute.jit
1144
+ def _load_q_do_block_sm90(
1145
+ m_block,
1146
+ producer_state_Q,
1147
+ producer_state_dO,
1148
+ pipeline_Q,
1149
+ pipeline_dO,
1150
+ load_K,
1151
+ load_V,
1152
+ load_Q,
1153
+ load_dO,
1154
+ load_LSE,
1155
+ load_dPsum,
1156
+ tma_copy_bytes_K,
1157
+ tma_copy_bytes_V,
1158
+ Q_stage_eq_dO_stage: cutlass.Constexpr,
1159
+ load_kv: bool,
1160
+ ):
1161
+ """Load one Q/dO block, optionally loading K/V on first iteration."""
1162
+ if load_kv:
1163
+ pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
1164
+ load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
1165
+ else:
1166
+ pipeline_Q.producer_acquire(producer_state_Q)
1167
+ load_Q(m_block, producer_state=producer_state_Q)
1168
+ load_LSE(m_block, producer_state=producer_state_Q)
1169
+
1170
+ producer_state_dO_cur = (
1171
+ producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
1172
+ )
1173
+ if load_kv:
1174
+ pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
1175
+ load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
1176
+ else:
1177
+ pipeline_dO.producer_acquire(producer_state_dO_cur)
1178
+ load_dO(m_block, producer_state=producer_state_dO_cur)
1179
+ load_dPsum(m_block, producer_state=producer_state_dO_cur)
1180
+
1181
+ producer_state_Q.advance()
1182
+ producer_state_dO.advance()
1183
+ return producer_state_Q, producer_state_dO
1184
+
1185
+
1186
+ @cute.jit
1187
+ def produce_block_sparse_q_loads_bwd_sm90(
1188
+ blocksparse_tensors: BlockSparseTensors,
1189
+ batch_idx,
1190
+ head_idx,
1191
+ n_block,
1192
+ producer_state_Q,
1193
+ producer_state_dO,
1194
+ pipeline_Q,
1195
+ pipeline_dO,
1196
+ load_K,
1197
+ load_V,
1198
+ load_Q,
1199
+ load_dO,
1200
+ load_LSE,
1201
+ load_dPsum,
1202
+ tma_copy_bytes_K,
1203
+ tma_copy_bytes_V,
1204
+ Q_stage_eq_dO_stage: cutlass.Constexpr,
1205
+ subtile_factor: cutlass.Constexpr,
1206
+ m_block_max: int,
1207
+ ):
1208
+ """SM90 backward block sparse loading with separate partial/full loops.
1209
+
1210
+ K/V are loaded with the first valid block. Iterates partial blocks first,
1211
+ then full blocks, matching consumer order.
1212
+
1213
+ Returns updated (producer_state_Q, producer_state_dO).
1214
+ """
1215
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1216
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1217
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1218
+
1219
+ if const_expr(full_cnt is not None):
1220
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1221
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1222
+ else:
1223
+ curr_full_cnt = Int32(0)
1224
+ curr_full_idx = None
1225
+
1226
+ kv_loaded = False
1227
+
1228
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1229
+ sparse_idx = iter_idx // subtile_factor
1230
+ subtile_offset = iter_idx % subtile_factor
1231
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1232
+
1233
+ if m_block < m_block_max:
1234
+ producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
1235
+ m_block,
1236
+ producer_state_Q,
1237
+ producer_state_dO,
1238
+ pipeline_Q,
1239
+ pipeline_dO,
1240
+ load_K,
1241
+ load_V,
1242
+ load_Q,
1243
+ load_dO,
1244
+ load_LSE,
1245
+ load_dPsum,
1246
+ tma_copy_bytes_K,
1247
+ tma_copy_bytes_V,
1248
+ Q_stage_eq_dO_stage,
1249
+ load_kv=not kv_loaded,
1250
+ )
1251
+ kv_loaded = True
1252
+
1253
+ if const_expr(full_cnt is not None):
1254
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1255
+ sparse_idx = iter_idx // subtile_factor
1256
+ subtile_offset = iter_idx % subtile_factor
1257
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1258
+
1259
+ if m_block < m_block_max:
1260
+ producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
1261
+ m_block,
1262
+ producer_state_Q,
1263
+ producer_state_dO,
1264
+ pipeline_Q,
1265
+ pipeline_dO,
1266
+ load_K,
1267
+ load_V,
1268
+ load_Q,
1269
+ load_dO,
1270
+ load_LSE,
1271
+ load_dPsum,
1272
+ tma_copy_bytes_K,
1273
+ tma_copy_bytes_V,
1274
+ Q_stage_eq_dO_stage,
1275
+ load_kv=not kv_loaded,
1276
+ )
1277
+ kv_loaded = True
1278
+
1279
+ return producer_state_Q, producer_state_dO
1280
+
1281
+
1282
+ @cute.jit
1283
+ def consume_block_sparse_mma_bwd_sm90(
1284
+ blocksparse_tensors: BlockSparseTensors,
1285
+ batch_idx,
1286
+ head_idx,
1287
+ n_block,
1288
+ consumer_state_Q,
1289
+ consumer_state_dO,
1290
+ mma_one_m_block_fn,
1291
+ mask,
1292
+ mask_mod,
1293
+ is_causal: cutlass.Constexpr,
1294
+ is_local: cutlass.Constexpr,
1295
+ thr_mma_SdP,
1296
+ score_mod_fn=None,
1297
+ score_mod_bwd_fn=None,
1298
+ subtile_factor: cutlass.Constexpr = 1,
1299
+ m_block_max: int = 0,
1300
+ aux_tensors=None,
1301
+ fastdiv_mods=(None, None),
1302
+ ):
1303
+ """SM90 backward block sparse MMA consumption with separate partial/full loops.
1304
+
1305
+ Partial blocks are processed first (with mask_mod applied), then full blocks
1306
+ (without mask_mod). This ensures mask_mod is only applied where needed.
1307
+
1308
+ Returns updated (consumer_state_Q, consumer_state_dO).
1309
+ """
1310
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1311
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1312
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1313
+
1314
+ if const_expr(full_cnt is not None):
1315
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1316
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1317
+ else:
1318
+ curr_full_cnt = Int32(0)
1319
+ curr_full_idx = None
1320
+
1321
+ dKV_accumulate = False
1322
+
1323
+ mask_fn_partial = partial(
1324
+ mask.apply_mask,
1325
+ batch_idx=batch_idx,
1326
+ head_idx=head_idx,
1327
+ n_block=n_block,
1328
+ thr_mma=thr_mma_SdP,
1329
+ mask_seqlen=True,
1330
+ mask_causal=is_causal,
1331
+ mask_local=is_local,
1332
+ mask_mod=mask_mod,
1333
+ aux_tensors=aux_tensors,
1334
+ fastdiv_mods=fastdiv_mods,
1335
+ )
1336
+
1337
+ mask_fn_full = partial(
1338
+ mask.apply_mask,
1339
+ batch_idx=batch_idx,
1340
+ head_idx=head_idx,
1341
+ n_block=n_block,
1342
+ thr_mma=thr_mma_SdP,
1343
+ mask_seqlen=True,
1344
+ mask_causal=is_causal,
1345
+ mask_local=is_local,
1346
+ aux_tensors=aux_tensors,
1347
+ fastdiv_mods=fastdiv_mods,
1348
+ )
1349
+
1350
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1351
+ sparse_idx = iter_idx // subtile_factor
1352
+ subtile_offset = iter_idx % subtile_factor
1353
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1354
+
1355
+ if m_block < m_block_max:
1356
+ consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
1357
+ m_block,
1358
+ consumer_state_Q,
1359
+ consumer_state_dO,
1360
+ mask_fn=mask_fn_partial,
1361
+ score_mod_fn=score_mod_fn,
1362
+ score_mod_bwd_fn=score_mod_bwd_fn,
1363
+ dKV_accumulate=dKV_accumulate,
1364
+ )
1365
+ dKV_accumulate = True
1366
+
1367
+ if const_expr(full_cnt is not None):
1368
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1369
+ sparse_idx = iter_idx // subtile_factor
1370
+ subtile_offset = iter_idx % subtile_factor
1371
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1372
+
1373
+ if m_block < m_block_max:
1374
+ consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
1375
+ m_block,
1376
+ consumer_state_Q,
1377
+ consumer_state_dO,
1378
+ mask_fn=mask_fn_full,
1379
+ score_mod_fn=score_mod_fn,
1380
+ score_mod_bwd_fn=score_mod_bwd_fn,
1381
+ dKV_accumulate=dKV_accumulate,
1382
+ )
1383
+ dKV_accumulate = True
1384
+
1385
+ return consumer_state_Q, consumer_state_dO
1386
+
1387
+
1388
+ @cute.jit
1389
+ def _store_one_dQaccum_sm90(
1390
+ m_block,
1391
+ sdQaccum: cute.Tensor,
1392
+ gdQaccum: cute.Tensor,
1393
+ num_mma_warp_groups: cutlass.Constexpr,
1394
+ num_threads_per_warp_group: cutlass.Constexpr,
1395
+ tma_copy_bytes_dQ,
1396
+ ):
1397
+ """Store dQaccum for a single m_block."""
1398
+ for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1399
+ cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
1400
+ cute.arch.barrier_arrive(
1401
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1402
+ number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1403
+ )
1404
+ for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1405
+ cute.arch.barrier(
1406
+ barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1407
+ number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1408
+ )
1409
+ with cute.arch.elect_one():
1410
+ copy_utils.cpasync_reduce_bulk_add_f32(
1411
+ sdQaccum[None, warp_group_idx].iterator,
1412
+ gdQaccum[None, warp_group_idx, m_block].iterator,
1413
+ tma_copy_bytes_dQ,
1414
+ )
1415
+ cute.arch.cp_async_bulk_commit_group()
1416
+
1417
+
1418
+ @cute.jit
1419
+ def dQaccum_store_block_sparse_bwd_sm90(
1420
+ blocksparse_tensors: BlockSparseTensors,
1421
+ batch_idx,
1422
+ head_idx,
1423
+ n_block,
1424
+ sdQaccum: cute.Tensor,
1425
+ gdQaccum: cute.Tensor,
1426
+ subtile_factor: cutlass.Constexpr,
1427
+ m_block_max: int,
1428
+ num_mma_warp_groups: cutlass.Constexpr,
1429
+ num_threads_per_warp_group: cutlass.Constexpr,
1430
+ tma_copy_bytes_dQ,
1431
+ ):
1432
+ """SM90 backward block sparse dQaccum store with separate partial/full loops.
1433
+
1434
+ Iterates partial blocks first, then full blocks, matching producer/consumer order.
1435
+ """
1436
+ q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
1437
+ curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
1438
+ curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
1439
+
1440
+ if const_expr(full_cnt is not None):
1441
+ curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
1442
+ curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
1443
+ else:
1444
+ curr_full_cnt = Int32(0)
1445
+ curr_full_idx = None
1446
+
1447
+ for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
1448
+ sparse_idx = iter_idx // subtile_factor
1449
+ subtile_offset = iter_idx % subtile_factor
1450
+ m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
1451
+
1452
+ if m_block < m_block_max:
1453
+ _store_one_dQaccum_sm90(
1454
+ m_block,
1455
+ sdQaccum,
1456
+ gdQaccum,
1457
+ num_mma_warp_groups,
1458
+ num_threads_per_warp_group,
1459
+ tma_copy_bytes_dQ,
1460
+ )
1461
+
1462
+ if const_expr(full_cnt is not None):
1463
+ for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
1464
+ sparse_idx = iter_idx // subtile_factor
1465
+ subtile_offset = iter_idx % subtile_factor
1466
+ m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
1467
+
1468
+ if m_block < m_block_max:
1469
+ _store_one_dQaccum_sm90(
1470
+ m_block,
1471
+ sdQaccum,
1472
+ gdQaccum,
1473
+ num_mma_warp_groups,
1474
+ num_threads_per_warp_group,
1475
+ tma_copy_bytes_dQ,
1476
+ )
build/torch-cuda/block_sparsity.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Block-sparsity utilities for FlexAttention
3
+ """
4
+
5
+ from typing import Callable, NamedTuple, Tuple
6
+
7
+ import cutlass.cute as cute
8
+ import torch
9
+
10
+ from .cute_dsl_utils import get_broadcast_dims, to_cute_tensor
11
+
12
+
13
def ceildiv(a: int, b: int) -> int:
    """Round the integer quotient a / b toward positive infinity."""
    numerator = a + b - 1
    return numerator // b
15
+
16
+
17
class BlockSparseTensors(NamedTuple):
    """Device-side block-sparsity metadata as CuTe tensors.

    The ``full_*`` pair is optional: when None, only the masked block lists
    are used by the kernels.
    """

    mask_block_cnt: cute.Tensor  # per (batch, head, block): number of masked blocks
    mask_block_idx: cute.Tensor  # per (batch, head, block): indices of masked blocks
    full_block_cnt: cute.Tensor | None  # optional count of fully-unmasked blocks
    full_block_idx: cute.Tensor | None  # optional indices of fully-unmasked blocks

    def __new_from_mlir_values__(self, values):
        # Reconstruction hook used by the CuTe DSL when rebuilding this tuple
        # from MLIR values; pad the optional full_* slots when only the masked
        # pair was materialized.
        if len(values) == 2:
            values = (*values, None, None)
        return BlockSparseTensors(*values)
27
+
28
+
29
class BlockSparseTensorsTorch(NamedTuple):
    """Host-side (torch) block-sparsity metadata before conversion to CuTe."""

    mask_block_cnt: torch.Tensor  # counts of masked blocks
    mask_block_idx: torch.Tensor  # indices of masked blocks
    full_block_cnt: torch.Tensor | None = None  # optional counts of fully-unmasked blocks
    full_block_idx: torch.Tensor | None = None  # optional indices of fully-unmasked blocks
    block_size: tuple[int, int] | None = None  # (q, kv) block granularity of the metadata
35
+
36
+
37
def _expand_sparsity_tensor(
    tensor: torch.Tensor,
    expected_shape: Tuple[int, ...],
    tensor_name: str,
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> torch.Tensor:
    """Broadcast *tensor* to ``expected_shape`` (size-1 dims only), or raise.

    Tensors already at the expected shape are returned unchanged; otherwise
    every dimension must either match or be 1 (broadcastable).
    """
    if tensor.shape == expected_shape:
        # Nothing to do: already the target shape.
        return tensor
    broadcastable = all(
        cur == tgt or cur == 1 for cur, tgt in zip(tensor.shape, expected_shape)
    )
    if broadcastable:
        return tensor.expand(*expected_shape)
    context_clause = f" ({context})" if context else ""
    resolved_hint = hint() if callable(hint) else hint
    hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
    raise ValueError(
        f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
        f"{hint_clause}"
    )
58
+
59
+
60
def _check_and_expand_block(
    name: str,
    cnt: torch.Tensor | None,
    idx: torch.Tensor | None,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
    """Validate a (count, index) sparsity-tensor pair and broadcast each.

    Returns (None, None) when the pair is absent; raises ValueError on any
    dtype/device/pairing violation.
    """
    # A count without its index list (or vice versa) is malformed input.
    if (cnt is None) != (idx is None):
        raise ValueError(
            f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
        )
    if cnt is None or idx is None:
        return None, None
    if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
        raise ValueError(f"{name}_block tensors must have dtype torch.int32")
    if cnt.device != idx.device:
        raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
    if not cnt.is_cuda or not idx.is_cuda:
        raise ValueError(f"{name}_block tensors must live on CUDA")
    return (
        _expand_sparsity_tensor(cnt, expected_count_shape, f"{name}_block_cnt", context, hint),
        _expand_sparsity_tensor(idx, expected_index_shape, f"{name}_block_idx", context, hint),
    )
88
+
89
+
90
def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for block sparse normalization."""
    # Forward kernels consume q_stage m-tiles per sparse row, so the effective
    # sparse Q block is q_stage * m_block_size.
    n_m_blocks = ceildiv(seqlen_q, q_stage * m_block_size)
    n_n_blocks = ceildiv(seqlen_k, n_block_size)
    counts = (batch_size, num_head, n_m_blocks)
    return counts, counts + (n_n_blocks,)
106
+
107
+
108
def infer_block_sparse_expected_shapes(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
    context: str,
    sparse_block_size_q: int | None = None,
    sparse_block_size_kv: int | None = None,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int], int]:
    """Infer shapes and scaling for block-sparse tensors.

    Expectations:
    - mask_block_cnt is (B, H, M) and mask_block_idx is (B, H, M, N).
    - Batch/head dims may be 1 for broadcast, or match the requested sizes.
    - sparse_block_size_kv must match tile_n.
    - sparse_block_size_q must be a multiple of q_stage * tile_m.
    - If sparse_block_size_q is omitted and seqlen_q/num_m_blocks is ambiguous,
      the caller must provide block_size to disambiguate. TODO will make this required in a future PR.

    Returns:
        (expected_count_shape, expected_index_shape, q_subtile_factor), where
        q_subtile_factor = sparse_block_size_q // (q_stage * m_block_size).
    """
    # Smallest legal sparse Q block: the kernel consumes q_stage m-tiles at a time.
    base_m_block = q_stage * m_block_size
    base_n_block = n_block_size
    if sparse_block_size_kv is None:
        sparse_block_size_kv = base_n_block
    if sparse_block_size_kv != base_n_block:
        raise ValueError(f"Block sparse tensors{context} require BLOCK_SIZE_KV={base_n_block}.")
    if tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    num_m_blocks = tensors.mask_block_idx.shape[2]

    if sparse_block_size_q is None:
        # Recover the Q block size from seqlen_q and the tensor's m-block
        # count. Any block size in [min_block_size, max_block_size] yields
        # num_m_blocks blocks; when that interval is wider than one value the
        # choice is ambiguous and the caller must pass it explicitly.
        min_block_size = ceildiv(seqlen_q, num_m_blocks)
        if num_m_blocks == 1:
            max_block_size = seqlen_q
        else:
            max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
        if max_block_size != min_block_size and base_m_block != 1:
            raise ValueError(
                f"Block sparse tensors{context} require explicit sparse_block_size[0] "
                f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
            )
        sparse_block_size_q = min_block_size

    if sparse_block_size_q % base_m_block != 0:
        raise ValueError(
            f"Block sparse tensors{context} have block size {sparse_block_size_q}, "
            f"which must be a multiple of {base_m_block}."
        )

    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
    expected_n_blocks = ceildiv(seqlen_k, sparse_block_size_kv)
    # Number of kernel m-tiles covered by one sparse Q block.
    q_subtile_factor = sparse_block_size_q // base_m_block
    expected_count_shape = (batch_size, num_head, expected_m_blocks)
    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)

    mask_block_cnt = tensors.mask_block_cnt
    mask_block_idx = tensors.mask_block_idx
    if mask_block_cnt is None or mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    if mask_block_cnt.ndim != 3 or mask_block_idx.ndim != 4:
        raise ValueError(
            f"Block sparse tensors{context} must have shapes (B, H, M) and (B, H, M, N)."
        )
    # Batch/head dims must match exactly or be 1 (broadcastable later).
    for dim_name, cur, tgt in (
        ("batch", mask_block_cnt.shape[0], expected_count_shape[0]),
        ("head", mask_block_cnt.shape[1], expected_count_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    for dim_name, cur, tgt in (
        ("batch", mask_block_idx.shape[0], expected_index_shape[0]),
        ("head", mask_block_idx.shape[1], expected_index_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
        raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
    if mask_block_idx.shape[3] != expected_n_blocks:
        raise ValueError(
            f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
        )
    if expected_m_blocks != num_m_blocks:
        raise ValueError(
            f"Block sparse tensors{context} m-block dimension {num_m_blocks} does not match "
            f"sparse_block_size_q={sparse_block_size_q}. "
            f"Set BlockSparseTensorsTorch.block_size to match the BlockMask BLOCK_SIZE."
        )
    return expected_count_shape, expected_index_shape, q_subtile_factor
200
+
201
+
202
def get_block_sparse_expected_shapes_bwd(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    subtile_factor: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.

    The backward pass walks the mask transposed relative to forward: counts
    are indexed per N-block and the index tensor lists M-blocks per N-block.
    The effective sparse Q block is subtile_factor * m_block_size.
    """
    q_block = subtile_factor * m_block_size
    n_m_blocks = ceildiv(seqlen_q, q_block)
    n_n_blocks = ceildiv(seqlen_k, n_block_size)
    counts = (batch_size, num_head, n_n_blocks)
    return counts, counts + (n_m_blocks,)
223
+
224
+
225
def normalize_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch,
    *,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None = None,
    hint: str | Callable[[], str] | None = None,
) -> BlockSparseTensorsTorch:
    """Validate and broadcast all block-sparse tensors to their expected shapes.

    The mask pair is mandatory; the full pair is optional but, when present,
    must share the mask pair's device.
    """
    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    norm_mask_cnt, norm_mask_idx = _check_and_expand_block(
        "mask",
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    # Unreachable given the guard above; kept for type narrowing.
    if norm_mask_cnt is None or norm_mask_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    norm_full_cnt, norm_full_idx = _check_and_expand_block(
        "full",
        tensors.full_block_cnt,
        tensors.full_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if norm_full_cnt is not None and norm_mask_cnt.device != norm_full_cnt.device:
        raise ValueError("All block sparse tensors must be on the same device")

    return BlockSparseTensorsTorch(
        mask_block_cnt=norm_mask_cnt,
        mask_block_idx=norm_mask_idx,
        full_block_cnt=norm_full_cnt,
        full_block_idx=norm_full_idx,
        block_size=tensors.block_size,
    )
267
+
268
+
269
def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
    """True when at least one count tensor (mask or full) is present."""
    return tensors.full_block_cnt is not None or tensors.mask_block_cnt is not None
271
+
272
+
273
def get_block_sparse_broadcast_pattern(
    tensors: BlockSparseTensorsTorch,
) -> Tuple[Tuple[bool, ...], ...] | None:
    """Return broadcast pattern for block sparse tensors by checking actual strides.

    One entry per tensor, in (mask_cnt, mask_idx, full_cnt, full_idx) order;
    each entry is a tuple of bools marking stride-0 dims (None for absent
    tensors). Used in compile keys: CuTe's mark_layout_dynamic() keeps
    stride=0 static, so a changed broadcast pattern must trigger recompile.
    Tensors are expected to be already expanded/normalized.

    Returns None when block sparsity is not enabled.
    """
    if not is_block_sparsity_enabled(tensors):
        return None

    ordered = (
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        tensors.full_block_cnt,
        tensors.full_block_idx,
    )
    return tuple(None if t is None else get_broadcast_dims(t) for t in ordered)
303
+
304
+
305
def normalize_block_sparse_config(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    q_stage: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
    """Normalize forward-pass block-sparse tensors against the kernel tile config.

    Returns (normalized tensors, broadcast pattern for the compile key,
    q subtile factor).
    """
    tile_m, tile_n = block_size
    declared = tensors.block_size
    if declared is None:
        # Default to the kernel's effective tile sizes.
        sparse_q, sparse_kv = q_stage * tile_m, tile_n
    else:
        sparse_q, sparse_kv = declared
    if sparse_kv != tile_n:
        raise ValueError(
            f"Block sparsity requires sparse_block_size[1]={tile_n} to match tile_n."
        )
    count_shape, index_shape, q_subtile_factor = infer_block_sparse_expected_shapes(
        tensors,
        batch_size=batch_size,
        num_head=num_head,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        m_block_size=tile_m,
        n_block_size=tile_n,
        q_stage=q_stage,
        context="forward",
        sparse_block_size_q=sparse_q,
        sparse_block_size_kv=sparse_kv,
    )
    normalized = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=count_shape,
        expected_index_shape=index_shape,
    )
    pattern = get_block_sparse_broadcast_pattern(normalized)
    return normalized, pattern, q_subtile_factor
349
+
350
+
351
def normalize_block_sparse_config_bwd(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    subtile_factor: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None]:
    """Normalize backward-pass (Q-direction) block-sparse tensors.

    Returns (normalized tensors, broadcast pattern for the compile key).
    """
    tile_m, tile_n = block_size
    expected_q_block = subtile_factor * tile_m
    declared = tensors.block_size
    if declared is None:
        sparse_q, sparse_kv = expected_q_block, tile_n
    else:
        sparse_q, sparse_kv = declared
    if sparse_q != expected_q_block:
        raise ValueError(
            f"Block sparsity expects sparse_block_size_q={expected_q_block} "
            f"for subtile_factor={subtile_factor}."
        )
    if sparse_kv != tile_n:
        raise ValueError(
            f"Block sparsity expects sparse_block_size[1]={tile_n} to match tile_n."
        )
    count_shape, index_shape = get_block_sparse_expected_shapes_bwd(
        batch_size,
        num_head,
        seqlen_q,
        seqlen_k,
        tile_m,
        tile_n,
        subtile_factor,
    )
    normalized = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=count_shape,
        expected_index_shape=index_shape,
        context="_flash_attn_bwd",
        hint=lambda: (
            f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, "
            f"and optionally full_q_cnt/full_q_idx). Regenerate the backward BlockMask with "
            f"BLOCK_SIZE=({expected_q_block}, {tile_n})."
        ),
    )
    return normalized, get_block_sparse_broadcast_pattern(normalized)
396
+
397
+
398
def to_cute_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> BlockSparseTensors | None:
    """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
    if not is_block_sparsity_enabled(tensors):
        return None

    def _convert(t: torch.Tensor) -> cute.Tensor:
        # assumed_align=4 matches int32; last dim is the contiguous one.
        return to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)

    mask_cnt, mask_idx, full_cnt, full_idx, *_ = tensors
    return BlockSparseTensors(
        _convert(mask_cnt),
        _convert(mask_idx),
        _convert(full_cnt) if full_cnt is not None else None,
        _convert(full_idx) if full_idx is not None else None,
    )
435
+
436
+
437
def fast_sampling(mask_mod):
    """Convenience decorator to mark mask_mod as safe for 5-point fast sampling"""
    setattr(mask_mod, "use_fast_sampling", True)
    return mask_mod
build/torch-cuda/cache_utils.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Manage Ahead-of-Time (AOT) compiled kernels
2
+ import fcntl
3
+ import hashlib
4
+ import logging
5
+ import os
6
+ import pickle
7
+ import sys
8
+ import tempfile
9
+ import time
10
+ from distutils.ccompiler import CCompiler, new_compiler
11
+ from functools import lru_cache
12
+ from getpass import getuser
13
+ from pathlib import Path
14
+ from typing import Hashable, TypeAlias
15
+
16
+ import cutlass
17
+ import cutlass.cute as cute
18
+ import tvm_ffi
19
+ from cutlass.cutlass_dsl import JitCompiledFunction
20
+
21
# A compile key is an arbitrary hashable tuple identifying one kernel specialization.
CompileKeyType: TypeAlias = tuple[Hashable, ...]
# Cached entries are either freshly JIT-compiled functions or AOT .so entry
# points reloaded through tvm_ffi.
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function

logger = logging.getLogger(__name__)
# NOTE(review): attaching a handler and level to a library logger overrides the
# application's logging configuration; consider leaving this to callers.
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.WARNING)


# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"


# Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`, default is
# `/tmp/${USER}/flash_attention_cute_dsl_cache``
CUTE_DSL_CACHE_DIR: str | None = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR", None)
36
+
37
+
38
def get_cache_path() -> Path:
    """Resolve (and create) the on-disk cache directory.

    Honors FLASH_ATTENTION_CUTE_DSL_CACHE_DIR when set; otherwise defaults to
    <system tmp>/<user>/flash_attention_cute_dsl_cache.
    """
    override = CUTE_DSL_CACHE_DIR
    if override is not None:
        cache_dir = Path(override)
    else:
        cache_dir = Path(tempfile.gettempdir()) / getuser() / "flash_attention_cute_dsl_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
45
+
46
+
47
@lru_cache(maxsize=1)
def _compute_source_fingerprint() -> str:
    """
    Hash all CuTe Python sources plus runtime ABI stamps into a short fingerprint.

    The fingerprint changes whenever:
    - Any .py file under flash_attn/cute is added, removed, renamed, or modified.
    - The Python minor version changes (e.g. 3.13 -> 3.14).
    - The cutlass or tvm_ffi package version changes.

    Computed once per process and cached.
    """
    package_root = Path(__file__).resolve().parent
    digest = hashlib.sha256()

    # ABI stamps: interpreter minor version plus key dependency versions.
    digest.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode())
    digest.update(f"cutlass={cutlass.__version__}".encode())
    digest.update(f"tvm_ffi={tvm_ffi.__version__}".encode())

    # Every .py source, in stable sorted order: relative path, length, bytes.
    # Length-prefixing keeps the byte stream unambiguous across files.
    for source_file in sorted(package_root.rglob("*.py")):
        digest.update(source_file.relative_to(package_root).as_posix().encode())
        payload = source_file.read_bytes()
        digest.update(len(payload).to_bytes(8, "little"))
        digest.update(payload)

    return digest.hexdigest()
73
+
74
+
75
class FileLock:
    """Context manager for advisory file locks using fcntl.flock.

    Supports exclusive (write) and shared (read) locks.
    Always blocks with polling until the lock is acquired or timeout is reached.

    Usage:
        with FileLock(lock_path, exclusive=True, timeout=15, label="abc"):
            # do work under lock
    """

    # Seconds between non-blocking acquisition attempts.
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(
        self,
        lock_path: Path,
        exclusive: bool,
        timeout: float = 15,
        label: str = "",
    ):
        """
        Args:
            lock_path: Path to the lock file on disk.
            exclusive: True for exclusive (write) lock, False for shared (read) lock.
            timeout: Max seconds to wait for lock acquisition before raising RuntimeError.
            label: Optional human-readable label for error messages.
        """
        self.lock_path: Path = lock_path
        self.exclusive: bool = exclusive
        self.timeout: float = timeout
        self.label: str = label
        # File descriptor of the held lock; -1 means "not holding a lock".
        # (Previously the timeout path assigned None while __exit__ checked
        # `is not None`, so a never-entered lock would flock(-1); use a single
        # -1 sentinel consistently.)
        self._fd: int = -1

    @property
    def _lock_label(self) -> str:
        kind = "exclusive" if self.exclusive else "shared"
        return f"{kind} {self.label}" if self.label else kind

    def __enter__(self) -> "FileLock":
        open_flags = (
            os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT
        )
        lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH

        fd = os.open(str(self.lock_path), open_flags)

        # Always make at least one attempt, even for timeout <= 0 (the old
        # `while monotonic() < deadline` loop could skip the attempt entirely).
        deadline = time.monotonic() + self.timeout
        while True:
            try:
                # Non-blocking attempt; poll so the timeout stays enforceable.
                fcntl.flock(fd, lock_type | fcntl.LOCK_NB)
                break
            except OSError:
                if time.monotonic() >= deadline:
                    os.close(fd)
                    raise RuntimeError(
                        f"Timed out after {self.timeout}s waiting for "
                        f"{self._lock_label} lock: {self.lock_path}"
                    )
                time.sleep(self._POLL_INTERVAL_SECONDS)

        # Only publish the fd once the lock is actually held.
        self._fd = fd
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self._fd >= 0:
            fcntl.flock(self._fd, fcntl.LOCK_UN)
            os.close(self._fd)
            self._fd = -1
143
+
144
+
145
class JITCache:
    """
    In-memory cache for compiled functions.

    Maps compile keys to compiled callables via a plain dict; subclasses
    (e.g. the persistent cache) call these base methods directly, so the
    `cache` attribute is part of the internal contract.
    """

    def __init__(self):
        # key -> compiled function
        self.cache: dict[CompileKeyType, CallableFunction] = {}

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        self.cache[key] = fn

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        # Raises KeyError for unknown keys, like a dict.
        return self.cache[key]

    def __contains__(self, key: CompileKeyType) -> bool:
        return key in self.cache

    def clear(self) -> None:
        """
        Clear in-memory cache of compiled functions
        """
        self.cache.clear()
167
+
168
+
169
class JITPersistentCache(JITCache):
    """
    In-memory cache for compiled functions, which is also backed by persistent storage.
    Use cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True

    Disk layout: one <sha256(key)>.so per entry plus a sibling <hash>.lock file
    used for cross-process coordination (shared lock for reads, exclusive for
    writes).
    """

    # Symbol name the compiled function is exported under in each .so.
    EXPORT_FUNCTION_PREFIX = "func"
    # Max seconds to wait on a per-entry file lock.
    LOCK_TIMEOUT_SECONDS = 15

    # Lazily created, process-wide compiler used to link .o -> .so.
    # NOTE(review): distutils was removed from the stdlib in Python 3.12 —
    # confirm the `distutils.ccompiler` import at the top of this file still
    # resolves (e.g. via setuptools) on the supported interpreters.
    _compiler: CCompiler | None = None

    def __init__(self, cache_path: Path):
        super().__init__()
        cache_path.mkdir(parents=True, exist_ok=True)
        self.cache_path: Path = cache_path

    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
        # Populate memory first, then best-effort persist to disk.
        JITCache.__setitem__(self, key, fn)
        self._try_export_to_storage(key, fn)

    def __getitem__(self, key: CompileKeyType) -> CallableFunction:
        # Use __contains__ to try populating in-memory cache with persistent storage
        self.__contains__(key)
        return JITCache.__getitem__(self, key)

    def __contains__(self, key: CompileKeyType) -> bool:
        # Checks in-memory cache first, then tries loading from storage.
        # When returning True, guarantees the in-memory cache is populated.
        if JITCache.__contains__(self, key):
            return True
        return self._try_load_from_storage(key)

    def _try_load_from_storage(self, key: CompileKeyType) -> bool:
        """
        Try to load a function from persistent storage into in-memory cache.
        Returns True if loaded successfully, False if not found on disk.
        Holds a shared lock during loading to prevent concurrent writes.
        """
        sha256_hex = self._key_to_hash(key)
        so_path = self.cache_path / f"{sha256_hex}.so"
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=False,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            if so_path.exists():
                logger.debug(
                    "Loading compiled function from disk: %s", so_path
                )
                m = cute.runtime.load_module(
                    str(so_path), enable_tvm_ffi=True
                )
                fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
                # Call the base setter directly so loading does not re-export.
                JITCache.__setitem__(self, key, fn)
                return True
            else:
                logger.debug(
                    "Cache miss on disk for key hash %s", sha256_hex
                )
                return False

    def _try_export_to_storage(
        self, key: CompileKeyType, fn: JitCompiledFunction
    ) -> None:
        """Export a compiled function to persistent storage under exclusive lock."""
        sha256_hex = self._key_to_hash(key)
        with FileLock(
            self._lock_path(sha256_hex),
            exclusive=True,
            timeout=self.LOCK_TIMEOUT_SECONDS,
            label=sha256_hex,
        ):
            so_path = self.cache_path / f"{sha256_hex}.so"
            if so_path.exists():
                # Another process already exported.
                logger.debug(
                    "Skipping export, already on disk: %s", so_path
                )
                return
            obj_path = self.cache_path / f"{sha256_hex}.o"
            logger.debug(
                "Exporting compiled function to disk: %s", so_path
            )
            fn.export_to_c(
                object_file_path=str(obj_path),
                function_name=self.EXPORT_FUNCTION_PREFIX,
            )
            # TODO: as of cutedsl 4.4.0, `export_to_c` only supports exporting
            # "relocatable" .o files. But tvm_ffi expects "shared library" .so
            # files. Link ourselves to workaround.
            if JITPersistentCache._compiler is None:
                JITPersistentCache._compiler = new_compiler()
            JITPersistentCache._compiler.link_shared_object(
                [str(obj_path)], str(so_path)
            )
            # The intermediate .o is no longer needed once linked.
            obj_path.unlink()
            logger.debug(
                "Successfully exported compiled function to disk: %s", so_path
            )

    def _key_to_hash(self, key: CompileKeyType) -> str:
        # NOTE(review): pickled key bytes must be stable across processes for
        # cross-process cache hits; unpicklable key parts will raise here.
        return hashlib.sha256(pickle.dumps(key)).hexdigest()

    def _lock_path(self, sha256_hex: str) -> Path:
        return self.cache_path / f"{sha256_hex}.lock"

    def clear(self) -> None:
        """
        Not only clear the in-memory cache. Also purge persistent compilation cache.
        """
        logger.debug(
            "Clearing persistent cache at %s", self.cache_path
        )
        super().clear()
        # Removes .so and .lock files; assumes the cache dir contains only
        # files (unlink would fail on subdirectories).
        for child in self.cache_path.iterdir():
            child.unlink()
287
+
288
def get_jit_cache(name: str | None = None) -> JITCache:
    """
    JIT cache factory.
    `name` is an optional identifier to create subdirectories to manage cache.

    When persistent caching is enabled, artifacts are namespaced under a
    source fingerprint directory so that code or dependency changes
    automatically invalidate stale entries.
    """
    if not CUTE_DSL_CACHE_ENABLED:
        logger.debug("Persistent cache disabled, using in-memory JIT cache")
        return JITCache()
    path = get_cache_path() / _compute_source_fingerprint()
    if name:
        path = path / name
    logger.debug(
        "Creating persistent JIT cache at %s", path
    )
    return JITPersistentCache(path)
build/torch-cuda/compute_block_sparsity.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Callable, Optional, Tuple
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ import torch
7
+ from cutlass import Boolean, Int8, Int32, const_expr
8
+
9
+ from .block_sparsity import (
10
+ BlockSparseTensors,
11
+ BlockSparseTensorsTorch,
12
+ to_cute_block_sparse_tensors,
13
+ )
14
+ from .utils import hash_callable, scalar_to_ssa, ssa_to_scalar
15
+ from .seqlen_info import SeqlenInfoQK
16
+
17
+
18
+ class BlockSparsityKernel:
19
+ """Block sparsity kernel for FlexAttention.
20
+
21
+ This kernel computes `mask_mod` for every token of each block
22
+ to determine if an n block is full, masked, or neither.
23
+
24
+ Writes block counts and indices to a BlockSparseTensors object.
25
+
26
+ When use_fast_sampling=True, uses 5-point sampling (4 corners + center)
27
+ which is much faster but only suitable for masks where this is sufficient.
28
+
29
+ TODO:
30
+ - optimize mask_mod evaluation
31
+ - varlen support
32
+ - transposed tensors for bwd pass
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ mask_mod: Callable,
38
+ tile_mn: Tuple[int, int],
39
+ compute_full_blocks: bool = True,
40
+ use_aux_tensors: bool = False,
41
+ use_fast_sampling: bool = False,
42
+ ):
43
+ self.mask_mod = mask_mod
44
+ self.tile_mn = tile_mn
45
+ self.compute_full_blocks = compute_full_blocks
46
+ self.use_aux_tensors = use_aux_tensors
47
+ self.use_fast_sampling = use_fast_sampling
48
+
49
    @cute.jit
    def __call__(
        self,
        blocksparse_tensors: BlockSparseTensors,
        seqlen_q: Int32,
        seqlen_k: Int32,
        aux_tensors: Optional[list] = None,
    ):
        """Launch the block-classification kernel.

        Writes per-m-block counts/indices into ``blocksparse_tensors`` in
        place; grid is (num_m_blocks, num_heads, batch_size).
        """
        self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors

        if const_expr(self.compute_full_blocks):
            assert self.full_cnt is not None and self.full_idx is not None, (
                "full block tensors must be provided when computing full blocks"
            )

        # mask_idx is (B, H, M, N); the grid covers (M, H, B).
        batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape
        # launch 1 CTA per m block
        grid = [num_m_blocks, num_heads, batch_size]

        if const_expr(self.use_fast_sampling):
            # Fast path: exactly one thread per sample point (4 corners + center).
            num_threads = 5
            self.num_warps = 1
        else:
            # Full path: one thread per row of the m-tile; warps rounded up.
            num_threads = self.tile_mn[0]
            self.num_warps = (num_threads + 32 - 1) // 32

        self.kernel(
            self.mask_cnt,
            self.mask_idx,
            self.full_cnt,
            self.full_idx,
            num_n_blocks,
            seqlen_q,
            seqlen_k,
            aux_tensors,
        ).launch(grid=grid, block=[num_threads, 1, 1])
85
+
86
+ @cute.kernel
87
+ def kernel(
88
+ self,
89
+ mask_cnt: cute.Tensor,
90
+ mask_idx: cute.Tensor,
91
+ full_cnt: cute.Tensor,
92
+ full_idx: cute.Tensor,
93
+ num_n_blocks: Int32,
94
+ seqlen_q: Int32,
95
+ seqlen_k: Int32,
96
+ aux_tensors: Optional[list] = None,
97
+ ):
98
+ tidx, _, _ = cute.arch.thread_idx()
99
+ warp_idx = cute.arch.warp_idx()
100
+ lane_id = cute.arch.lane_idx()
101
+ m_block, head_idx, batch_idx = cute.arch.block_idx()
102
+
103
+ ssa = partial(scalar_to_ssa, dtype=Int32)
104
+
105
+ seqlen = SeqlenInfoQK.create(
106
+ batch_idx,
107
+ seqlen_q,
108
+ seqlen_k,
109
+ mCuSeqlensQ=None,
110
+ mCuSeqlensK=None,
111
+ mSeqUsedQ=None,
112
+ mSeqUsedK=None,
113
+ )
114
+
115
+ @cute.struct
116
+ class SharedStorage:
117
+ reduction_buffer_smem: cute.struct.Align[
118
+ cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024
119
+ ]
120
+
121
+ smem = cutlass.utils.SmemAllocator()
122
+ storage = smem.allocate(SharedStorage, 16)
123
+
124
+ reduction_buffer = storage.reduction_buffer_smem.get_tensor(
125
+ cute.make_layout((self.num_warps, 2))
126
+ )
127
+
128
+ num_mask_blocks = Int32(0)
129
+ num_full_blocks = Int32(0)
130
+
131
+ for n_block in cutlass.range(num_n_blocks, unroll_full=True):
132
+ m_base = m_block * self.tile_mn[0]
133
+ n_base = n_block * self.tile_mn[1]
134
+
135
+ if const_expr(self.use_fast_sampling):
136
+ # Fast path: 5-point sampling (4 corners + center)
137
+ # Clamps OOB indices to nearest in bounds.
138
+ thread_result = Boolean(False)
139
+ thread_is_valid = Boolean(False)
140
+ q_idx = Int32(0)
141
+ kv_idx = Int32(0)
142
+
143
+ if tidx == 0:
144
+ # Top-left corner (0, 0); always in bounds
145
+ q_idx = m_base
146
+ kv_idx = n_base
147
+ elif tidx == 1:
148
+ # Top-right corner
149
+ q_idx = m_base
150
+ kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
151
+ elif tidx == 2:
152
+ # Bottom-left corner
153
+ q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
154
+ kv_idx = n_base
155
+ elif tidx == 3:
156
+ # Bottom-right corner
157
+ q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)
158
+ kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)
159
+ elif tidx == 4:
160
+ # Center point
161
+ q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2
162
+ kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2
163
+ else:
164
+ thread_is_valid = Boolean(False)
165
+
166
+ # Check bounds and determine if this thread has a valid index pair
167
+ if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k:
168
+ thread_is_valid = Boolean(True)
169
+ q_idx_ssa = ssa(q_idx)
170
+ kv_idx_ssa = ssa(kv_idx)
171
+ thread_result = ssa_to_scalar(
172
+ self.mask_mod(
173
+ ssa(batch_idx),
174
+ ssa(head_idx),
175
+ q_idx_ssa,
176
+ kv_idx_ssa,
177
+ seqlen,
178
+ aux_tensors,
179
+ )
180
+ )
181
+ else:
182
+ thread_is_valid = Boolean(False)
183
+
184
+ # Use vote_any_sync to see if any valid thread found unmasked or masked
185
+ # Only count results from threads that checked valid indices
186
+ has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid)
187
+ has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid)
188
+
189
+ else:
190
+ # Full path: check all elements in the block
191
+ # Track if this thread's row has any masked or unmasked elements
192
+ thread_has_unmasked = Boolean(False)
193
+ thread_has_masked = Boolean(False)
194
+ thread_is_valid = Boolean(False)
195
+
196
+ # Each thread handles 1 row
197
+ q_idx = m_base + tidx
198
+ kv_idx = Int32(0)
199
+ if tidx < self.tile_mn[0] and q_idx < seqlen_q:
200
+ thread_is_valid = Boolean(True)
201
+ q_idx_ssa = ssa(q_idx)
202
+
203
+ # Loop over all columns in this row
204
+ for c in cutlass.range(self.tile_mn[1], unroll_full=True):
205
+ kv_idx = n_base + c
206
+ kv_idx_ssa = ssa(kv_idx)
207
+
208
+ # Only check elements within valid sequence bounds
209
+ if kv_idx < seqlen_k:
210
+ # Direct scalar call
211
+ mask_val = ssa_to_scalar(
212
+ self.mask_mod(
213
+ ssa(batch_idx),
214
+ ssa(head_idx),
215
+ q_idx_ssa,
216
+ kv_idx_ssa,
217
+ seqlen,
218
+ aux_tensors,
219
+ )
220
+ )
221
+
222
+ # Update tracking flags
223
+ if mask_val:
224
+ thread_has_unmasked = Boolean(True)
225
+ else:
226
+ thread_has_masked = Boolean(True)
227
+
228
+ # Block-level reduction to combine results across all threads
229
+ # Only count votes from threads that checked valid indices
230
+ warp_has_unmasked_mask = cute.arch.vote_any_sync(
231
+ thread_has_unmasked & thread_is_valid
232
+ )
233
+ warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid)
234
+
235
+ # lane 0 writes the ballot mask to shared memory
236
+ lane_id = tidx % 32
237
+ if lane_id == 0:
238
+ # Store as Int8
239
+ reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0)
240
+ reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0)
241
+
242
+ cute.arch.sync_threads()
243
+
244
+ # Thread 0 ORs all warp results together
245
+ has_unmasked = Boolean(False)
246
+ has_masked = Boolean(False)
247
+ if tidx == 0:
248
+ for w in cutlass.range(self.num_warps):
249
+ if reduction_buffer[w, 0]:
250
+ has_unmasked = Boolean(True)
251
+ if reduction_buffer[w, 1]:
252
+ has_masked = Boolean(True)
253
+
254
+ # Only thread 0 updates the output arrays (common to both paths)
255
+ if tidx == 0:
256
+ # Block classification based on what we found:
257
+ # - If has_masked and has_unmasked: partial block (needs masking)
258
+ # - If only has_unmasked: full block (no masking needed)
259
+ # - If only has_masked: skip this block entirely
260
+ is_partial = Boolean(has_masked and has_unmasked)
261
+ is_full = Boolean(has_unmasked and (not has_masked))
262
+
263
+ if is_partial:
264
+ mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block
265
+ num_mask_blocks += 1
266
+ elif is_full and const_expr(self.compute_full_blocks):
267
+ full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block
268
+ num_full_blocks += 1
269
+
270
+ # Only thread 0 writes back the counts
271
+ if tidx == 0:
272
+ mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks
273
+ if const_expr(self.compute_full_blocks):
274
+ full_cnt[batch_idx, head_idx, m_block] = num_full_blocks
275
+
276
+
277
def compute_block_sparsity(
    tile_m,
    tile_n,
    batch_size,
    num_heads,
    seqlen_q,
    seqlen_k,
    mask_mod: Callable,
    aux_tensors: Optional[list],  # list[cute.Tensor]
    device,
    compute_full_blocks: bool = True,
    use_fast_sampling: bool = False,
) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]:
    """Compute the block-sparsity structure induced by ``mask_mod``.

    Args:
        tile_m: Tile size along the m (query) dimension.
        tile_n: Tile size along the n (key) dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        seqlen_q: Query sequence length.
        seqlen_k: Key sequence length.
        mask_mod: Mask callable evaluated per (b, h, q_idx, kv_idx).
        aux_tensors: Optional auxiliary tensors forwarded to ``mask_mod``.
        device: Torch device the index/count tensors are allocated on.
        compute_full_blocks: When False only partially-masked blocks are recorded.
        use_fast_sampling: Use 5-point (corners + center) sampling; faster, but
            only valid for masks where that check is sufficient.

    Returns:
        A ``(BlockSparseTensors, BlockSparseTensorsTorch)`` pair.
    """
    # A mask_mod may opt itself into 5-point sampling via an attribute.
    use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling)

    m_blocks = (seqlen_q + tile_m - 1) // tile_m
    n_blocks = (seqlen_k + tile_n - 1) // tile_n

    def _zeros(shape):
        # All sparsity bookkeeping tensors are int32 and zero-initialized.
        return torch.zeros(shape, device=device, dtype=torch.int32)

    cnt_shape = (batch_size, num_heads, m_blocks)
    idx_shape = (batch_size, num_heads, m_blocks, n_blocks)

    blocksparse_tensors_torch = BlockSparseTensorsTorch(
        mask_block_cnt=_zeros(cnt_shape),
        mask_block_idx=_zeros(idx_shape),
        full_block_cnt=_zeros(cnt_shape) if compute_full_blocks else None,
        full_block_idx=_zeros(idx_shape) if compute_full_blocks else None,
        block_size=(tile_m, tile_n),
    )

    blocksparse_tensors = to_cute_block_sparse_tensors(
        blocksparse_tensors_torch, enable_tvm_ffi=True
    )

    # Everything that changes codegen participates in the compile key.
    compile_key = (
        tile_m,
        tile_n,
        hash_callable(mask_mod),
        compute_full_blocks,
        aux_tensors is not None,
        use_fast_sampling,
    )
    cache = compute_block_sparsity.compile_cache
    if compile_key not in cache:
        kernel = BlockSparsityKernel(
            mask_mod,
            tile_mn=(tile_m, tile_n),
            compute_full_blocks=compute_full_blocks,
            use_aux_tensors=aux_tensors is not None,
            use_fast_sampling=use_fast_sampling,
        )
        cache[compile_key] = cute.compile(
            kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors,
            options="--enable-tvm-ffi",
        )

    # Launch: only the first 4 fields (cnt/idx tensors) are kernel arguments.
    cache[compile_key](
        blocksparse_tensors_torch[:4],
        seqlen_q,
        seqlen_k,
        aux_tensors,
    )

    return blocksparse_tensors, blocksparse_tensors_torch


compute_block_sparsity.compile_cache = {}
build/torch-cuda/copy_utils.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ import math
4
+ from typing import Optional, Type, Callable
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Float32, Int32, const_expr
9
+ from cutlass.cute.nvgpu import cpasync
10
+ import cutlass.utils.blackwell_helpers as sm100_utils
11
+ from cutlass.cutlass_dsl import T, dsl_user_op
12
+ from cutlass._mlir.dialects import llvm
13
+ import cutlass.pipeline
14
+
15
+
16
@dsl_user_op
def cvt_copy(
    atom: cute.CopyAtom,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy a register-resident ``src`` into ``dst``, converting dtype first when needed.

    ``src`` must live in register memory; if its element type differs from
    ``dst``'s, the values are converted through a temporary fragment before
    the copy is issued.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    needs_conversion = const_expr(src.element_type != dst.element_type)
    if needs_conversion:
        converted = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip)
        converted.store(src.load().to(dst.element_type))
        src = converted
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
33
+
34
+
35
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Load ``src`` (e.g. shared memory) into a fresh register fragment and return it."""
    frag = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, frag, loc=loc, ip=ip)
    return frag
40
+
41
+
42
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Build a copy atom for ``dtype``, vectorized up to (at most) 128 bits per copy.

    When ``is_async`` is True a cp.async G2S op is used; otherwise a universal copy.
    """
    bits_per_copy = const_expr(min(128, num_copy_elems * dtype.width))
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(op, dtype, num_bits_per_copy=bits_per_copy)
49
+
50
+
51
@dsl_user_op
def make_tmem_copy(
    tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Tile a TMEM copy atom across ``num_wg`` warpgroups (128 threads each).

    Only atoms with 32 data paths of 32 bits are supported.
    """
    dp, bits, rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)
    assert dp == 32
    assert bits == 32
    tiler_mn = (cute.make_layout((128 * rep * num_wg // 32, 32), stride=(32, 1)),)
    layout_tv = cute.make_layout(
        ((32, 4, num_wg), (rep, 32)),
        stride=((0, 1, 4 * rep), (4, 4 * rep * num_wg)),
    )
    return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)
63
+
64
+
65
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    num_copy_elems: int = 1,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` with an atom derived from the source dtype and vector width."""
    atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
79
+
80
+
81
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy where each thread moves ``num_copy_elems`` contiguous elements."""
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=num_copy_elems * dtype.width)
    return cute.make_tiled_copy_tv(
        atom,
        cute.make_layout(num_threads),
        cute.make_layout(num_copy_elems),
    )
90
+
91
+
92
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 2-D tiled copy over a row-major tile of width ``major_mode_size``.

    The per-copy width is the largest vector (up to 128 bits) that divides the row.
    """
    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
    copy_elems = num_copy_bits // dtype.width
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=num_copy_bits)
    threads_per_row = major_mode_size // copy_elems
    assert num_threads % threads_per_row == 0
    thr_layout = cute.make_ordered_layout(
        (num_threads // threads_per_row, threads_per_row),
        order=(1, 0),
    )
    val_layout = cute.make_layout((1, copy_elems))
    return cute.make_tiled_copy_tv(atom, thr_layout, val_layout)
107
+
108
+
109
@dsl_user_op
def atomic_add_fp32x4(
    a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None
) -> None:
    """Issue a vectorized red.global.add.v4.f32 of (a, b, c, d) at ``gmem_ptr``."""
    addr = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    operands = [
        addr,
        Float32(a).ir_value(loc=loc, ip=ip),
        Float32(b).ir_value(loc=loc, ip=ip),
        Float32(c).ir_value(loc=loc, ip=ip),
        Float32(d).ir_value(loc=loc, ip=ip),
    ]
    # Pack the four scalars into a .v4 register then reduce-add in one instruction.
    llvm.inline_asm(
        None,
        operands,
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        "mov.f32 abcd.x, $1;\n\t"
        "mov.f32 abcd.y, $2;\n\t"
        "mov.f32 abcd.z, $3;\n\t"
        "mov.f32 abcd.w, $4;\n\t"
        "red.global.add.v4.f32 [$0], abcd;\n\t"
        "}\n",
        "l,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
145
+
146
+
147
@dsl_user_op
def set_block_rank(
    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
) -> Int32:
    """Map the given smem pointer to the address at another CTA rank in the cluster."""
    local_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mapped = llvm.inline_asm(
        T.i32(),
        [local_addr, peer_cta_rank_in_cluster.ir_value()],
        "mapa.shared::cluster.u32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Int32(mapped)
164
+
165
+
166
@dsl_user_op
def store_shared_remote_fp32x4(
    a: Float32,
    b: Float32,
    c: Float32,
    d: Float32,
    smem_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """st.async four f32 values into a peer CTA's smem, signaling its mbarrier."""
    # Both the destination and the barrier live in the peer CTA's shared memory.
    remote_smem = set_block_rank(
        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    remote_mbar = set_block_rank(
        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    operands = [
        remote_smem,
        remote_mbar,
        Float32(a).ir_value(loc=loc, ip=ip),
        Float32(b).ir_value(loc=loc, ip=ip),
        Float32(c).ir_value(loc=loc, ip=ip),
        Float32(d).ir_value(loc=loc, ip=ip),
    ]
    llvm.inline_asm(
        None,
        operands,
        "{\n\t"
        ".reg .v4 .f32 abcd;\n\t"
        "mov.f32 abcd.x, $2;\n\t"
        "mov.f32 abcd.y, $3;\n\t"
        "mov.f32 abcd.z, $4;\n\t"
        "mov.f32 abcd.w, $5;\n\t"
        "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t"
        "}\n",
        "r,r,f,f,f,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
208
+
209
+
210
@dsl_user_op
def cpasync_bulk_s2cluster(
    smem_src_ptr: cute.Pointer,
    smem_dst_ptr: cute.Pointer,
    mbar_ptr: cute.Pointer,
    size: int | Int32,
    peer_cta_rank_in_cluster: Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk-copy ``size`` bytes of local smem into a peer CTA's smem via cp.async.bulk.

    The destination and the completion mbarrier are remapped into the peer
    CTA's address space before issuing the copy.
    """
    src_addr = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()
    dst_addr = set_block_rank(
        smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
    ).ir_value()
    mbar_addr = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [dst_addr, src_addr, mbar_addr, Int32(size).ir_value(loc=loc, ip=ip)],
        "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];",
        "r,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
240
+
241
+
242
@dsl_user_op
def cpasync_bulk_g2s(
    gmem_ptr: cute.Pointer,
    smem_ptr: cute.Pointer,
    tma_bar_ptr: cute.Pointer,
    size: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk-copy ``size`` bytes from global to shared memory, completing on a TMA barrier."""
    gmem_addr = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    smem_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_addr = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [gmem_addr, smem_addr, mbar_addr, Int32(size).ir_value()],
        "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];",
        "l,r,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
264
+
265
+
266
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Bulk reduce-add ``store_bytes`` of f32 data from shared memory into global memory."""
    smem_addr = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_addr, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
289
+
290
+
291
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Return a closure issuing bulk G2S copies over (optionally staged) tensors.

    With staging, the last mode of each tensor indexes the stage and the
    returned closure takes (src_idx, dst_idx); with ``single_stage`` the whole
    tensors are copied and the closure takes no indices.
    """
    drop_last = 0 if single_stage else 1
    src_rank = const_expr(cute.rank(src_tensor) - drop_last)
    dst_rank = const_expr(cute.rank(dst_tensor) - drop_last)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, src_rank)
    dst = cute.group_modes(dst_tensor, 0, dst_rank)

    def copy_bulk(src_idx, dst_idx, **new_kwargs):
        # Bytes in one stage: all modes except the stage mode.
        nbytes = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)
        cpasync_bulk_g2s(
            src[None, src_idx].iterator,
            dst[None, dst_idx].iterator,
            size=nbytes,
            **new_kwargs,
            **kwargs,
        )

    def copy_bulk_single_stage(**new_kwargs):
        nbytes = const_expr(cute.size(src.shape) * src.element_type.width // 8)
        cpasync_bulk_g2s(src.iterator, dst.iterator, size=nbytes, **new_kwargs, **kwargs)

    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
322
+
323
+
324
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Partition src/dst for TMA and return (copy_fn, smem_part, gmem_part).

    Direction is inferred from which side lives in shared memory. When staged
    (``single_stage`` False), the returned closure takes (src_idx, dst_idx).
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    if src_is_smem:
        smem_tensor, gmem_tensor = src_tensor, dst_tensor
    else:
        smem_tensor, gmem_tensor = dst_tensor, src_tensor
    drop_last = 0 if single_stage else 1
    smem_rank = const_expr(cute.rank(smem_tensor) - drop_last)
    gmem_rank = const_expr(cute.rank(gmem_tensor) - drop_last)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, smem_rank),
        cute.group_modes(gmem_tensor, 0, gmem_rank),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    def copy_tma_single_stage(**new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs)

    fn = copy_tma if const_expr(not single_stage) else copy_tma_single_stage
    return fn, s, g
361
+
362
+
363
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a staged copy closure to a producer pipeline.

    The returned closure resolves the destination stage and TMA barrier from
    the producer's pipeline state.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        stage = producer_state.index
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(src_idx=src_idx, dst_idx=stage, tma_bar_ptr=barrier, **new_kwargs)

    return copy_fn
build/torch-cuda/cute_dsl_ptxas.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System ptxas replacement for CUTLASS DSL.
3
+ Environment variables:
4
+ CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
5
+ CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import re
11
+ import ctypes
12
+ import subprocess
13
+ from pathlib import Path
14
+
15
+ import cutlass
16
+
17
+
18
+ CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
19
+ VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
20
+
21
+ _original_load_cuda_library = None
22
+ _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
23
+
24
+
25
def _log(msg):
    """Emit a diagnostic line to stderr when CUTE_DSL_PTXAS_VERBOSE is enabled."""
    if not VERBOSE:
        return
    print(f"[ptxas] {msg}", file=sys.stderr)
28
+
29
+
30
+ def _get_ptx(compiled_func) -> tuple[str, Path] | None:
31
+ """Find and read PTX file, stripping null bytes."""
32
+ func_name = getattr(compiled_func, "function_name", None)
33
+ if not func_name:
34
+ return None
35
+
36
+ dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
37
+ for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
38
+ content = ptx_path.read_text().rstrip("\x00")
39
+ if ".entry " in content and content.rstrip().endswith("}"):
40
+ _log(f"Found PTX: {ptx_path}")
41
+ return content, ptx_path
42
+ return None
43
+
44
+
45
def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
    """Compile PTX to cubin using system ptxas.

    Raises:
        RuntimeError: if ptxas exits non-zero.
    """
    # Target architecture comes from the PTX itself; default if absent.
    match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
    arch = match.group(1) if match else "sm_90a"

    # Persist the NUL-stripped content if the on-disk copy differs.
    if ptx_path.read_text() != ptx_content:
        ptx_path.write_text(ptx_content)

    cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
    try:
        assert CUTE_DSL_PTXAS_PATH is not None
        proc = subprocess.run(
            [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"ptxas failed: {proc.stderr}")

        cubin_data = cubin_tmp.read_bytes()
        _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")

        # Optionally keep the cubin next to the PTX.
        if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
            cubin_out = ptx_path.with_suffix(".cubin")
            cubin_out.write_bytes(cubin_data)
            _log(f"Saved: {cubin_out}")

        return cubin_data
    finally:
        # The temp cubin is never left behind, even on failure.
        cubin_tmp.unlink(missing_ok=True)
79
+
80
+
81
def _patched_load_cuda_library(self):
    """Replacement for _load_cuda_library that uses system ptxas.

    Falls back to the original (embedded-ptxas) loader whenever any step of the
    PTX-locate / compile / load pipeline fails.
    """
    found = _get_ptx(self)
    if not found:
        _log("PTX not found, falling back to embedded ptxas")
        return _original_load_cuda_library(self)
    ptx_content, ptx_path = found

    try:
        cubin = _compile_ptx(ptx_path, ptx_content)
    except Exception as e:
        _log(f"Compilation failed ({e}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Load the cubin through the CUDA runtime library API.
    import cuda.bindings.runtime as cuda_runtime

    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
    if err != cuda_runtime.cudaError_t.cudaSuccess:
        _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
        return _original_load_cuda_library(self)

    # Register the kernels on every device via the DSL's loader trampoline,
    # which takes an array of three pointers: library handle, device id, error.
    _, cuda_load_to_device = self._get_cuda_init_and_load()
    lib_ptr = ctypes.c_void_p(int(library))
    dev_id = ctypes.c_int32(0)
    err_val = ctypes.c_int32(0)
    args = (ctypes.c_void_p * 3)(
        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
    )

    for dev in range(self.num_devices):
        dev_id.value = dev
        cuda_load_to_device(args)
        if err_val.value != 0:
            _log("cuda_load_to_device failed, falling back to embedded ptxas")
            return _original_load_cuda_library(self)

    _log(f"Loaded kernel from {ptx_path.name}")

    # Remove the PTX unless the user explicitly asked to keep it.
    if not _user_wanted_ptx:
        ptx_path.unlink(missing_ok=True)

    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
130
+
131
+
132
def patch():
    """Install system ptxas hook. Call before importing cutlass.

    Raises:
        RuntimeError: if CUTE_DSL_PTXAS_PATH does not point at an executable file.
    """
    global _original_load_cuda_library, _user_wanted_ptx

    assert CUTE_DSL_PTXAS_PATH is not None
    ptxas_usable = os.path.isfile(CUTE_DSL_PTXAS_PATH) and os.access(
        CUTE_DSL_PTXAS_PATH, os.X_OK
    )
    if not ptxas_usable:
        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

    # Remember whether the user wanted the PTX kept, independent of our hook.
    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
    assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
        "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
    )

    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
    _original_load_cuda_library = cls._load_cuda_library
    cls._load_cuda_library = _patched_load_cuda_library
    _log("Patch applied")
build/torch-cuda/cute_dsl_utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import os
4
+ import pathlib
5
+ from typing import Tuple
6
+ from functools import partial, lru_cache
7
+ from dataclasses import dataclass, fields
8
+
9
+ import torch
10
+
11
+ try:
12
+ from triton.tools.disasm import extract
13
+ except ImportError:
14
+ extract = None
15
+
16
+ import cutlass
17
+ import cutlass.cute as cute
18
+ from cutlass.base_dsl.typing import JitArgument
19
+ from cutlass.cutlass_dsl import NumericMeta
20
+ from cutlass.cute.runtime import from_dlpack
21
+
22
+ StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
23
+
24
+
25
+ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
26
+ cute_compile_og = cute.compile
27
+
28
+
29
+ torch2cute_dtype_map = {
30
+ torch.float16: cutlass.Float16,
31
+ torch.bfloat16: cutlass.BFloat16,
32
+ torch.float32: cutlass.Float32,
33
+ }
34
+
35
+
36
@lru_cache
def get_max_active_clusters(cluster_size):
    """Memoized query of the max active clusters for ``cluster_size``."""
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
39
+
40
+
41
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Memoized CUDA compute capability (major, minor) for ``device``."""
    capability = torch.cuda.get_device_capability(device)
    return capability
44
+
45
+
46
@dataclass
class ArgumentsBase(JitArgument):
    """Dataclass base implementing the JitArgument protocol.

    Fields with compile-time-static types (constexpr, numeric metatypes,
    plain Python scalars) are excluded from the runtime ABI; the remaining
    fields delegate to their own JitArgument hooks.
    """

    def _dynamic_values(self):
        # Field values that participate in the runtime ABI, in field order.
        return [
            value
            for value in (getattr(self, f.name) for f in fields(self))
            if not isinstance(value, StaticTypes)
        ]

    def __c_pointers__(self):
        ptrs = []
        for value in self._dynamic_values():
            if hasattr(value, "__c_pointers__"):
                ptrs.extend(value.__c_pointers__())
        return ptrs

    def __get_mlir_types__(self):
        # Also records, per dynamic field, how many MLIR values it contributes
        # so __new_from_mlir_values__ can slice the flat value list back apart.
        types = []
        self._values_pos = []
        for value in self._dynamic_values():
            if hasattr(value, "__get_mlir_types__"):
                value_types = value.__get_mlir_types__()
                types.extend(value_types)
                self._values_pos.append(len(value_types))
            else:
                self._values_pos.append(0)
        return types

    def __new_from_mlir_values__(self, values):
        static_fields = {}
        dynamic_fields = {}
        for f in fields(self):
            value = getattr(self, f.name)
            target = static_fields if isinstance(value, StaticTypes) else dynamic_fields
            target[f.name] = value
        for (name, value), count in zip(dynamic_fields.items(), self._values_pos):
            dynamic_fields[name] = cutlass.new_from_mlir_values(value, values[:count])
            values = values[count:]
        return self.__class__(**dynamic_fields, **static_fields)
80
+
81
+
82
def load_cubin_module_data_patched(cubin_data, filepath):
    """Persist the cubin bytes to ``filepath``, then load via the original hook."""
    out_path = pathlib.Path(filepath)
    out_path.write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
85
+
86
+
87
def cute_compile_patched(*args, **kwargs):
    """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set."""
    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
    dumping = cubin_path is not None
    if dumping:
        # Temporarily swap in the loader that also writes the cubin to disk.
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(
            load_cubin_module_data_patched, filepath=cubin_path
        )
    output = cute_compile_og(*args, **kwargs)
    if dumping:
        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og
        if extract is not None:
            # Disassemble the dumped cubin to annotated SASS alongside it.
            sass = extract(cubin_path, None)
            pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
    return output
101
+
102
+
103
def assume_strides_aligned(t):
    """Assume all strides except the last are divisible by 128 bits.

    Python int strides (e.g., stride=0 from GQA expand) are kept as-is
    since they're static and don't need alignment assumptions.
    """
    divby = 128 // t.element_type.width
    leading = []
    for s in t.stride[:-1]:
        leading.append(s if isinstance(s, int) else cute.assume(s, divby=divby))
    return (*leading, t.stride[-1])
112
+
113
+
114
def assume_tensor_aligned(t):
    """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None."""
    if t is None:
        return None
    aligned_layout = cute.make_layout(t.shape, stride=assume_strides_aligned(t))
    return cute.make_tensor(t.iterator, aligned_layout)
119
+
120
+
121
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
    """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
    cute_t = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
    if fully_dynamic:
        return cute_t.mark_layout_dynamic()
    dim = t.ndim - 1 if leading_dim == -1 else leading_dim
    return cute_t.mark_layout_dynamic(leading_dim=dim)
129
+
130
+
131
def to_cute_aux_tensor(t, enable_tvm_ffi=True):
    """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors.

    This allows the user to specify alignment and leading dimension for aux tensors used in
    custom score_mod callables, via optional ``__assumed_align__`` / ``__leading_dim__``
    attributes on the tensor. When ``__leading_dim__`` is absent the layout is marked
    fully dynamic.
    """
    # Fix: these may legitimately be None when the attribute is absent, so the
    # annotations are Optional (the previous bare `int` annotations were wrong).
    assumed_align: int | None = getattr(t, "__assumed_align__", None)
    leading_dim: int | None = getattr(t, "__leading_dim__", None)
    fully_dynamic: bool = leading_dim is None
    # NOTE(review): when __assumed_align__ is absent, None is forwarded and
    # overrides to_cute_tensor's default of 16 — presumably from_dlpack treats
    # None as "infer alignment"; confirm against the cutlass DSL docs.
    return to_cute_tensor(
        t,
        assumed_align=assumed_align,
        leading_dim=leading_dim,
        fully_dynamic=fully_dynamic,
        enable_tvm_ffi=enable_tvm_ffi,
    )
147
+
148
+
149
def get_aux_tensor_metadata(aux_tensors):
    """Per-aux-tensor static metadata: (assumed_align, leading_dim, has_leading_dim).

    Missing attributes fall back to 0 / -1; the third element records whether
    ``__leading_dim__`` was explicitly set.
    """

    def _meta(t):
        return (
            getattr(t, "__assumed_align__", 0),
            getattr(t, "__leading_dim__", -1),
            hasattr(t, "__leading_dim__"),
        )

    return tuple(_meta(t) for t in aux_tensors)
158
+
159
+
160
def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
    """Return tuple of bools indicating which dims have stride=0 (broadcast).

    This is useful for compile keys since CuTe's mark_layout_dynamic() keeps
    stride=0 as static, meaning kernels compiled with different broadcast
    patterns are not interchangeable.
    """
    return tuple(stride == 0 for stride in tensor.stride())
build/torch-cuda/fast_math.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import cutlass
4
+ import cutlass.cute as cute
5
+ from cutlass import Int32
6
+
7
+
8
@cute.jit
def clz(x: Int32) -> Int32:
    """Count leading zeros of a 32-bit value.

    Returns the index (from the most-significant bit) of the highest set
    bit, i.e. 32 when no bit of ``x`` is set.
    """
    # Reference implementation — not usable because early exit (`return`
    # inside the loop) is not supported by the DSL yet:
    # for i in cutlass.range_constexpr(32):
    #     if (1 << (31 - i)) & x:
    #         return Int32(i)
    # return Int32(32)
    res = Int32(32)
    done = False
    # Scan from MSB to LSB; `done` emulates the early exit so that only the
    # first (most significant) set bit records its index.
    for i in cutlass.range(32):
        if ((1 << (31 - i)) & x) and not done:
            res = Int32(i)
            done = True
    return res
build/torch-cuda/flash_attn4/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/flash_bwd.py ADDED
@@ -0,0 +1,1264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp
3
+ # from Cutlass C++ to Cute-DSL.
4
+ import math
5
+ from types import SimpleNamespace
6
+ from typing import Type, Callable, Optional
7
+ from functools import partial
8
+
9
+ import cuda.bindings.driver as cuda
10
+
11
+ import cutlass
12
+ import cutlass.cute as cute
13
+ from cutlass.cute.nvgpu import cpasync, warp
14
+ from cutlass import Float32, Int32
15
+ import cutlass.utils as utils_basic
16
+
17
+ from .quack import layout_utils
18
+ from . import ampere_helpers as sm80_utils
19
+ from .cute_dsl_utils import assume_tensor_aligned
20
+ from . import utils
21
+ from .mask import AttentionMask
22
+ from .seqlen_info import SeqlenInfoQK
23
+ from .quack.cute_dsl_utils import ParamsBase
24
+ from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
25
+
26
+
27
+ class FlashAttentionBackwardSm80:
28
    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        head_dim_v: Optional[int] = None,
        qhead_per_kvhead: int = 1,
        m_block_size: int = 64,
        n_block_size: int = 128,
        num_stages_Q: int = 2,
        num_stages_dO: int = 2,
        num_threads: int = 256,
        pack_gqa: bool = False,
        is_causal: bool = False,
        SdP_swapAB: bool = False,
        dKV_swapAB: bool = False,
        dQ_swapAB: bool = False,
        AtomLayoutMSdP: int = 1,
        AtomLayoutNdKV: int = 8,
        AtomLayoutMdQ: int = 1,
        V_in_regs: bool = False,
    ):
        """Initializes the configuration for the SM80 flash-attention backward kernel.

        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
        should be a multiple of 8.

        :param dtype: element type of Q/K/V/dO (Float16 or BFloat16)
        :param head_dim: head dimension of Q/K
        :type head_dim: int
        :param head_dim_v: head dimension of V/dO; defaults to ``head_dim`` when None
        :param qhead_per_kvhead: number of query heads per KV head (GQA ratio)
        :param m_block_size: m block size (tile over the query sequence)
        :type m_block_size: int
        :param n_block_size: n block size (tile over the key sequence)
        :type n_block_size: int
        :param num_stages_Q: pipeline stages for the Q smem tile
        :param num_stages_dO: pipeline stages for the dO smem tile
        :param num_threads: number of threads per CTA
        :type num_threads: int
        :param is_causal: is causal
        :param SdP_swapAB: swap A/B operands of the S/dP matmuls
        :param dKV_swapAB: swap A/B operands of the dK/dV matmuls
        :param dQ_swapAB: swap A/B operands of the dQ matmul
        :param AtomLayoutMSdP: MMA atom layout (M direction) for S/dP
        :param AtomLayoutNdKV: MMA atom layout (N direction) for dK/dV
        :param AtomLayoutMdQ: MMA atom layout (M direction) for dQ
        :param V_in_regs: keep V in registers; also enables sharing Q/V smem
        """
        self.dtype = dtype
        # Pad head dims up to a multiple of 32 (hdim_multiple_of) for the smem/k block size.
        hdim_multiple_of = 32
        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        head_dim_v = head_dim_v if head_dim_v is not None else head_dim
        self.same_hdim_kv = head_dim == head_dim_v
        self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
        # Can save registers (and hence be faster) if we don't have to check hdim predication
        self.check_hdim_oob = head_dim != self.head_dim_padded
        self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
        self.qhead_per_kvhead = qhead_per_kvhead
        self.m_block_size = m_block_size
        self.n_block_size = n_block_size
        self.num_threads = num_threads
        self.pack_gqa = pack_gqa
        self.is_causal = is_causal
        self.num_stages_Q = num_stages_Q
        self.num_stages_dO = num_stages_dO
        self.SdP_swapAB = SdP_swapAB
        self.dKV_swapAB = dKV_swapAB
        self.dQ_swapAB = dQ_swapAB
        self.AtomLayoutMSdP = AtomLayoutMSdP
        self.AtomLayoutNdKV = AtomLayoutNdKV
        self.AtomLayoutMdQ = AtomLayoutMdQ
        num_mma_warps = self.num_threads // cute.arch.WARP_SIZE
        # dK/dV MMA can consume P/dS straight from registers only for this
        # specific combination of atom layouts and swap flags.
        self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
        self.V_in_regs = V_in_regs
        # When V lives in registers its smem tile can be reused for Q.
        self.share_QV_smem = V_in_regs
92
+
93
+ @staticmethod
94
+ def can_implement(
95
+ dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO,
96
+ num_threads, is_causal,
97
+ V_in_regs=False
98
+ ) -> bool:
99
+ """Check if the kernel can be implemented with the given parameters.
100
+
101
+ :param dtype: data type
102
+ :type dtype: cutlass.Numeric
103
+ :param head_dim: head dimension
104
+ :type head_dim: int
105
+ :param m_block_size: m block size
106
+ :type m_block_size: int
107
+ :param n_block_size: n block size
108
+ :type n_block_size: int
109
+ :param num_threads: number of threads
110
+ :type num_threads: int
111
+ :param is_causal: is causal
112
+ :type is_causal: bool
113
+
114
+ :return: True if the kernel can be implemented, False otherwise
115
+ :rtype: bool
116
+ """
117
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
118
+ return False
119
+ if head_dim % 8 != 0:
120
+ return False
121
+ if head_dim_v % 8 != 0:
122
+ return False
123
+ if n_block_size % 16 != 0:
124
+ return False
125
+ if num_threads % 32 != 0:
126
+ return False
127
+ # Check if block size setting is out of shared memory capacity
128
+ # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size
129
+ smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
130
+ smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
131
+ smem_usage_K = n_block_size * head_dim * 2
132
+ smem_usage_V = n_block_size * head_dim_v * 2
133
+ smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
134
+ smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
135
+ smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80")
136
+ if smem_usage > smem_capacity:
137
+ return False
138
+ return True
139
+
140
    def _check_type(
        self,
        mQ_type: Type[cutlass.Numeric],
        mK_type: Type[cutlass.Numeric],
        mV_type: Type[cutlass.Numeric],
        mdO_type: Type[cutlass.Numeric],
        mLSE_type: Type[cutlass.Numeric],
        mdPsum_type: Type[cutlass.Numeric],
        mdQaccum_type: Type[cutlass.Numeric],
        mdK_type: Type[cutlass.Numeric],
        mdV_type: Type[cutlass.Numeric],
        mCuSeqlensQ_type: Type[cutlass.Numeric] | None,
        mCuSeqlensK_type: Type[cutlass.Numeric] | None,
        mSeqUsedQ_type: Type[cutlass.Numeric] | None,
        mSeqUsedK_type: Type[cutlass.Numeric] | None,
    ):
        """Validate element types of all operand tensors; raises TypeError on mismatch.

        All checks are wrapped in `cutlass.const_expr` so they run at
        compile/trace time rather than on device.
        """
        if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
            raise TypeError("All tensors must have the same data type")
        if cutlass.const_expr(self.qhead_per_kvhead == 1):
            if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)):
                raise TypeError("mdK and mdV tensors must have the same data type as mQ")
        else:
            # With GQA, dK/dV are fp32 accumulation buffers rather than outputs in mQ's dtype.
            if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)):
                raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
        if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if cutlass.const_expr(not mLSE_type in [cutlass.Float32]):
            raise TypeError("LSE tensor must be Float32")
        if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]):
            raise TypeError("dQaccum tensor must be Float32")
        # Varlen metadata tensors are optional; when present they must be Int32.
        if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]):
            raise TypeError("cuSeqlensQ tensor must be Int32")
        if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]):
            raise TypeError("cuSeqlensK tensor must be Int32")
        if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]):
            raise TypeError("SeqUsedQ tensor must be Int32")
        if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]):
            raise TypeError("SeqUsedK tensor must be Int32")
        assert mQ_type == self.dtype
181
+
182
    def _setup_attributes(self):
        """Derive smem layouts and gmem tiled-copy atoms from the configuration.

        Must be called after ``self.varlen_q`` is set (done in ``__call__``),
        since the LSE/dPsum copy atoms depend on it.
        """
        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout: Q/K/V
        # ///////////////////////////////////////////////////////////////////////////////
        sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded)
        self.sQ_layout = cute.tile_to_shape(
            sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2),
        )
        # K uses the same layout atom as Q (same head_dim_padded), just single-stage.
        sK_layout_atom = sQ_layout_atom
        self.sK_layout = cute.tile_to_shape(
            sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),
        )
        sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded)
        self.sV_layout = cute.tile_to_shape(
            sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),
        )
        # dO shares V's atom (head_dim_v) and is multi-stage like Q.
        sdO_layout_atom = sV_layout_atom
        self.sdO_layout = cute.tile_to_shape(
            sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2),
        )
        # TODO: do we set swizzle to be 3 here explicitly?
        sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size)
        self.sPdS_layout = cute.tile_to_shape(
            sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),
        )
        # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
        # it's still a valid smem address.
        self.sLSE_layout = cute.make_layout(
            (self.m_block_size, self.num_stages_Q),
            stride=(1, cute.round_up(self.m_block_size, 64)),
        )
        # Broadcast views of sLSE over the n dimension (stride 0) for feeding the MMA.
        sLSEMma_layout = cute.make_layout(
            (self.m_block_size, self.n_block_size, self.num_stages_Q),
            stride=(1, 0, cute.round_up(self.m_block_size, 64)),
        )
        sLSEMma_layout_transposed = cute.make_layout(
            (self.n_block_size, self.m_block_size, self.num_stages_Q),
            stride=(0, 1, cute.round_up(self.m_block_size, 64)),
        )
        self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed

        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies: 128-bit vectorized accesses.
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype.width
        # atom_async_copy: async (cp.async) copy atom for QKV load
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self.dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # atom_universal_copy: universal copy atom for O store
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits,
        )
        # tQK_layout: thread layout for QK load
        tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems
        assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1"
        tQK_layout = cute.make_ordered_layout(
            (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0),
        )
        # Do we need to check if we overshot kBlockM when we load Q?
        self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0
        # Do we need to check if we overshot kBlockN when we load K?
        self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0
        tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems
        assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1"
        tVdO_layout = cute.make_ordered_layout(
            (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0),
        )
        # Do we need to check if we overshot kBlockN when we load V?
        self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0
        self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0

        # Value layouts for copies
        vQKVdO_layout = cute.make_layout((1, async_copy_elems))

        # gmem_tiled_copy_QK: tiled copy for QK load
        self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout)
        self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout)
        self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout)
        self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout)
        # NOTE(review): this assignment is immediately overwritten in both
        # branches below; it appears redundant.
        async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width

        # I think we wouldn't require this with smarter padding
        if cutlass.const_expr(not self.varlen_q):
            # Fixed seqlen: vectorized 128-bit async loads for LSE/dPsum.
            async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
            atom_async_copy_accum = cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
                cutlass.Float32,
                num_bits_per_copy=universal_copy_bits,
            )
        else:
            # Varlen: fall back to scalar copies (offsets may not be 128-bit aligned).
            async_copy_elems_accum = 1
            atom_async_copy_accum = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(),
                cutlass.Float32,
                num_bits_per_copy=cutlass.Float32.width,
            )
        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
            atom_async_copy_accum,
            cute.make_layout(self.num_threads),
            cute.make_layout(async_copy_elems_accum),
        )
        self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width
            ),
            cute.make_layout(self.num_threads),
            cute.make_layout(1)
        )
        if cutlass.const_expr(self.qhead_per_kvhead > 1):
            # GQA: dK/dV are fp32 accumulation buffers, written with the scalar fp32 copy.
            self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum
            self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum
298
+
299
    def _get_tiled_mma(self):
        """Build the three tiled MMAs: S/dP, dK/dV, and dQ.

        Each uses the 16x8x16 fp16/bf16 MMA atom with fp32 accumulation; the
        atom layout for each is transposed when the corresponding swapAB flag
        is set.
        """
        num_mma_warps = self.num_threads // 32
        AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1)
        tiled_mma_sdp = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
            AtomLayoutSdP,
            permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16),
        )
        AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1)
        tiled_mma_dkv = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
            AtomLayoutdKV,
            permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16),
        )
        AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
        tiled_mma_dq = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),
            AtomLayoutdQ,
            permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16),
        )
        return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq
320
+
321
    def _get_shared_storage_cls(self):
        """Build the cute.struct describing the kernel's shared-memory layout.

        Returns the variant where Q and V have separate smem tiles, or the
        variant where they share one tile when ``self.share_QV_smem`` is set.
        """
        sQ_struct, sK_struct, sV_struct, sdO_struct = [
            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]
            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout)
        ]
        # Shared Q/V tile must be large enough for whichever of the two is bigger.
        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
        # LSE and dPsum use the same layout (fp32, 128-byte aligned).
        sLSE_struct, sdPsum_struct = [
            cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128]
            for layout in (self.sLSE_layout, self.sLSE_layout)
        ]
        sP_struct, sdS_struct = [
            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128]
            for layout in (self.sPdS_layout, self.sPdS_layout)
        ]

        @cute.struct
        class SharedStorageSeparateQV:
            sK: sK_struct
            sV: sV_struct
            sQ: sQ_struct
            sdO: sdO_struct
            sLSE: sLSE_struct
            sdPsum: sdPsum_struct
            sP: sP_struct
            sdS: sdS_struct
        # TODO: the case where there's no sP

        @cute.struct
        class SharedStorageSharedQV:
            sK: sK_struct
            sV: sV_struct
            sQ: sQV_struct
            sdO: sdO_struct
            sLSE: sLSE_struct
            sdPsum: sdPsum_struct
            sP: sP_struct
            sdS: sdS_struct

        return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV
361
+
362
    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        softmax_scale: cutlass.Float32,
        stream: cuda.CUstream,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        softcap: Float32 | float | None = None,
        window_size_left: Int32 | int | None = None,
        window_size_right: Int32 | int | None = None,
        mdQ_semaphore: Optional[cute.Tensor] = None,
    ):
        """Validate inputs, build the tile schedule, and launch the backward kernel.

        NOTE(review): ``softcap``, ``window_size_left`` and ``window_size_right``
        are accepted but not used anywhere in this body — presumably reserved
        for future support; confirm against callers.
        """
        assert mdQ_semaphore is None, "semaphore not supported yet"
        # Get the data type and check if it is fp16 or bf16
        self._check_type(*(t.element_type if t is not None else None
                           for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
        ]
        # varlen_q must be set before _setup_attributes: the LSE copy atoms depend on it.
        self.varlen_q = (mCuSeqlensQ is not None)
        self._setup_attributes()
        SharedStorage = self._get_shared_storage_cls()
        tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma()

        # Varlen tensors are (total_tokens, head, ...), fixed-len are (batch, seq, head, ...).
        num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2]

        if cutlass.const_expr(mCuSeqlensK is not None):
            TileScheduler = SingleTileVarlenScheduler
            num_batch = mCuSeqlensK.shape[0] - 1
        else:
            TileScheduler = SingleTileScheduler
            num_batch = mK.shape[0]

        # Uses seqlen k, etc. since main bwd kernel's blocks are over n
        tile_sched_args = TileSchedulerArguments(
            num_block=cute.ceil_div(mK.shape[1], self.n_block_size),
            num_head=num_head,
            num_batch=num_batch,
            num_splits=1,
            seqlen_k=0,
            headdim=mK.shape[2],
            headdim_v=mV.shape[2],
            total_q=mK.shape[0],
            tile_shape_mn=(self.n_block_size, self.m_block_size),
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
            mCuSeqlensQ=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedK,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        # Pre-scale for exp2-based softmax: scale * log2(e).
        softmax_scale_log2 = softmax_scale * math.log2(math.e)
        self.kernel(
            mQ,
            mK,
            mV,
            mdO,
            mLSE,
            mdPsum,
            mdQaccum,
            mdK,
            mdV,
            mCuSeqlensQ,
            mCuSeqlensK,
            mSeqUsedQ,
            mSeqUsedK,
            softmax_scale,
            softmax_scale_log2,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sdO_layout,
            self.sPdS_layout,
            self.sLSE_layout,
            self.sLSEMma_layout,
            self.gmem_tiled_copy_QK,
            self.gmem_tiled_copy_VdO,
            self.gmem_tiled_copy_dK,
            self.gmem_tiled_copy_dV,
            self.gmem_tiled_copy_LSE,
            self.gmem_tiled_copy_dQaccum,
            tiled_mma_sdp,
            tiled_mma_dkv,
            tiled_mma_dq,
            SharedStorage,
            tile_sched_params,
            TileScheduler,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=SharedStorage.size_in_bytes(),
            stream=stream,
        )
467
+
468
+ @cute.kernel
469
+ def kernel(
470
+ self,
471
+ mQ: cute.Tensor,
472
+ mK: cute.Tensor,
473
+ mV: cute.Tensor,
474
+ mdO: cute.Tensor,
475
+ mLSE: cute.Tensor,
476
+ mdPsum: cute.Tensor,
477
+ mdQaccum: cute.Tensor,
478
+ mdK: cute.Tensor,
479
+ mdV: cute.Tensor,
480
+ mCuSeqlensQ: Optional[cute.Tensor],
481
+ mCuSeqlensK: Optional[cute.Tensor],
482
+ mSeqUsedQ: Optional[cute.Tensor],
483
+ mSeqUsedK: Optional[cute.Tensor],
484
+ softmax_scale: cutlass.Float32,
485
+ softmax_scale_log2: cutlass.Float32,
486
+ sQ_layout: cute.ComposedLayout,
487
+ sK_layout: cute.ComposedLayout,
488
+ sV_layout: cute.ComposedLayout,
489
+ sdO_layout: cute.ComposedLayout,
490
+ sPdS_layout: cute.ComposedLayout,
491
+ sLSE_layout: cute.Layout,
492
+ sLSEMma_layout: cute.Layout,
493
+ gmem_tiled_copy_QK: cute.TiledCopy,
494
+ gmem_tiled_copy_VdO: cute.TiledCopy,
495
+ gmem_tiled_copy_dK: cute.TiledCopy,
496
+ gmem_tiled_copy_dV: cute.TiledCopy,
497
+ gmem_tiled_copy_LSE: cute.TiledCopy,
498
+ gmem_tiled_copy_dQaccum: cute.TiledCopy,
499
+ tiled_mma_sdp: cute.TiledMma,
500
+ tiled_mma_dkv: cute.TiledMma,
501
+ tiled_mma_dq: cute.TiledMma,
502
+ SharedStorage: cutlass.Constexpr,
503
+ tile_sched_params: ParamsBase,
504
+ TileScheduler: cutlass.Constexpr[Callable],
505
+ ):
506
+ # Thread index, block index
507
+ tidx, _, _ = cute.arch.thread_idx()
508
+
509
+ tile_scheduler = TileScheduler.create(tile_sched_params)
510
+ work_tile = tile_scheduler.initial_work_tile_info()
511
+
512
+ n_block, head_idx, batch_idx, _ = work_tile.tile_idx
513
+
514
+ if work_tile.is_valid_tile:
515
+ seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
516
+
517
+ m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
518
+ m_block_min = 0
519
+ if cutlass.const_expr(self.is_causal):
520
+ m_block_min = max(
521
+ (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size,
522
+ m_block_min,
523
+ )
524
+ # TODO: return early if m_block_max == 0
525
+
526
+ # ///////////////////////////////////////////////////////////////////////////////
527
+ # Get the appropriate tiles for this thread block.
528
+ # ///////////////////////////////////////////////////////////////////////////////
529
+ blkQ_shape = (self.m_block_size, self.head_dim_padded)
530
+ blkK_shape = (self.n_block_size, self.head_dim_padded)
531
+ blkV_shape = (self.n_block_size, self.head_dim_v_padded)
532
+ blkdO_shape = (self.m_block_size, self.head_dim_v_padded)
533
+
534
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
535
+ mQ_cur = mQ[batch_idx, None, head_idx, None]
536
+ mLSE_cur = mLSE[batch_idx, head_idx, None]
537
+ mdO_cur = mdO[batch_idx, None, head_idx, None]
538
+ mdPsum_cur = mdPsum[batch_idx, head_idx, None]
539
+ mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
540
+ else:
541
+ padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
542
+ mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
543
+ mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
544
+ mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
545
+ mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
546
+ mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])
547
+ head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx
548
+
549
+ if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
550
+ mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]
551
+ else:
552
+ mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]
553
+
554
+ # (m_block_size, head_dim, m_block)
555
+ gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))
556
+ # (n_block_size, head_dim)
557
+ gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))
558
+ # (n_block_size, head_dim_v)
559
+ gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))
560
+ # (m_block_size, head_dim_v, m_block)
561
+ gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))
562
+ gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))
563
+ gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))
564
+ gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))
565
+
566
+ # ///////////////////////////////////////////////////////////////////////////////
567
+ # Get shared memory buffer
568
+ # ///////////////////////////////////////////////////////////////////////////////
569
+ smem = cutlass.utils.SmemAllocator()
570
+ storage = smem.allocate(SharedStorage)
571
+ sQ = storage.sQ.get_tensor(sQ_layout)
572
+ sK = storage.sK.get_tensor(sK_layout)
573
+ if cutlass.const_expr(not self.share_QV_smem):
574
+ sV = storage.sV.get_tensor(sV_layout)
575
+ else:
576
+ sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
577
+ sdO = storage.sdO.get_tensor(sdO_layout)
578
+ sP = storage.sP.get_tensor(sPdS_layout)
579
+ sdS = storage.sdS.get_tensor(sPdS_layout)
580
+ sLSE = storage.sLSE.get_tensor(sLSE_layout)
581
+ sdPsum = storage.sdPsum.get_tensor(sLSE_layout)
582
+ sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)
583
+ sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)
584
+
585
+ # Transpose view of tensors for tiled mma
586
+ sQt, sdOt, sKt, sPt, sdSt = [layout_utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]
587
+
588
+ gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)
589
+ gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)
590
+ gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)
591
+ gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
592
+ # (CPY_Atom, CPY_M, CPY_K, m_block)
593
+ tQgQ = gmem_thr_copy_QK.partition_S(gQ)
594
+ tQsQ = gmem_thr_copy_QK.partition_D(sQ)
595
+ # (CPY_Atom, CPY_N, CPY_K)
596
+ tKgK = gmem_thr_copy_QK.partition_S(gK)
597
+ tKsK = gmem_thr_copy_QK.partition_D(sK)
598
+ # (CPY_Atom, CPY_N, CPY_K)
599
+ tVgV = gmem_thr_copy_VdO.partition_S(gV)
600
+ tVsV = gmem_thr_copy_VdO.partition_D(sV)
601
+ # (CPY_Atom, CPY_M, CPY_K, m_block)
602
+ tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)
603
+ tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)
604
+ tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)
605
+ tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)
606
+ tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)
607
+ tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)
608
+ tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
609
+
610
+ # ///////////////////////////////////////////////////////////////////////////////
611
+ # Tile MMA compute thread partitions and allocate accumulators
612
+ # ///////////////////////////////////////////////////////////////////////////////
613
+ thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)
614
+ thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)
615
+ thr_mma_dq = tiled_mma_dq.get_slice(tidx)
616
+ acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))
617
+ acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))
618
+ acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)
619
+ acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)
620
+ acc_dK.fill(0.0)
621
+ acc_dV.fill(0.0)
622
+
623
+ tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
624
+ tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)
625
+ tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
626
+ tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)
627
+ tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)
628
+ tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
629
+ tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)
630
+ tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
631
+ tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)
632
+ tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)
633
+
634
+ LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
635
+ tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
636
+ tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
637
+
638
+ # ///////////////////////////////////////////////////////////////////////////////
639
+ # Smem copy atom tiling
640
+ # ///////////////////////////////////////////////////////////////////////////////
641
+ smem_copy_atom = cute.make_copy_atom(
642
+ warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,
643
+ )
644
+ smem_copy_atom_transposed = cute.make_copy_atom(
645
+ warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,
646
+ )
647
+ smem_thr_copy_QdO = utils.make_tiled_copy_A(
648
+ smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
649
+ ).get_slice(tidx)
650
+ smem_thr_copy_KV = utils.make_tiled_copy_B(
651
+ smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
652
+ ).get_slice(tidx)
653
+ # TODO: should this be smem_copy_atom_transposed?
654
+ smem_thr_copy_PdSt = utils.make_tiled_copy_A(
655
+ smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
656
+ ).get_slice(tidx)
657
+ smem_thr_copy_QdOt = utils.make_tiled_copy_B(
658
+ smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
659
+ ).get_slice(tidx)
660
+ smem_thr_copy_dS = utils.make_tiled_copy_A(
661
+ smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB
662
+ ).get_slice(tidx)
663
+ smem_thr_copy_Kt = utils.make_tiled_copy_B(
664
+ smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB
665
+ ).get_slice(tidx)
666
+ # TODO: what's the number of bits? What if SdP_swapAB
667
+ r2s_thr_copy_PdS = cute.make_tiled_copy_C(
668
+ cute.make_copy_atom(
669
+ cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
670
+ ),
671
+ tiled_mma_sdp,
672
+ ).get_slice(tidx)
673
+
674
+ tSsQ = smem_thr_copy_QdO.partition_S(sQ)
675
+ tdPsdO = smem_thr_copy_QdO.partition_S(sdO)
676
+ tSsK = smem_thr_copy_KV.partition_S(sK)
677
+ tdPsV = smem_thr_copy_KV.partition_S(sV)
678
+ tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)
679
+ tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)
680
+ tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)
681
+ tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)
682
+ tdQsdS = smem_thr_copy_dS.partition_S(sdS)
683
+ tdQsKt = smem_thr_copy_Kt.partition_S(sKt)
684
+ tPsP = r2s_thr_copy_PdS.partition_D(sP)
685
+ tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)
686
+
687
+ # ///////////////////////////////////////////////////////////////////////////////
688
+ # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
689
+ # of tile_shape
690
+ # ///////////////////////////////////////////////////////////////////////////////
691
+ # Construct identity layout for KV
692
+ cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
693
+ tQcQ = gmem_thr_copy_QK.partition_S(cQ)
694
+ t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)
695
+ if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
696
+ tdOcdO = tQcQ
697
+ t0dOcdO = t0QcQ
698
+ else:
699
+ cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
700
+ tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)
701
+ t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)
702
+ cLSE = cute.make_identity_tensor((self.m_block_size,))
703
+ tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)
704
+
705
+ # Allocate predicate tensors for m and n, here we only allocate the tile of k, and
706
+ # use "if" on the mn dimension.
707
+ # This is to reduce register pressure and gets 2-3% performance gain.
708
+
709
+ d_head = mQ.shape[cute.rank(mQ) - 1]
710
+ d_head_v = mdO.shape[cute.rank(mdO) - 1]
711
+
712
+ tQpQ = utils.predicate_k(tQcQ, limit=d_head)
713
+ if cutlass.const_expr(self.same_hdim_kv):
714
+ tdOpdO = tQpQ
715
+ else:
716
+ tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)
717
+
718
+ # group parameters for compute_one_m_block
719
+ mma_params = SimpleNamespace(
720
+ thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,
721
+ tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,
722
+ tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,
723
+ tdQrdS=tdQrdS, tdQrK=tdQrK,
724
+ acc_dK=acc_dK, acc_dV=acc_dV,
725
+ )
726
+ smem_copy_params = SimpleNamespace(
727
+ smem_thr_copy_QdO=smem_thr_copy_QdO,
728
+ smem_thr_copy_KV=smem_thr_copy_KV,
729
+ smem_thr_copy_PdSt=smem_thr_copy_PdSt,
730
+ smem_thr_copy_QdOt=smem_thr_copy_QdOt,
731
+ smem_thr_copy_dS=smem_thr_copy_dS,
732
+ smem_thr_copy_Kt=smem_thr_copy_Kt,
733
+ r2s_thr_copy_PdS=r2s_thr_copy_PdS,
734
+ tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,
735
+ tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,
736
+ tPsP=tPsP, tdSsdS=tdSsdS,
737
+ tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,
738
+ tdQsdS=tdQsdS, tdQsKt=tdQsKt,
739
+ )
740
+ gmem_copy_params = SimpleNamespace(
741
+ gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum
742
+ )
743
+ load_Q_LSE = partial(
744
+ self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,
745
+ tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,
746
+ tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q
747
+ )
748
+ load_dO_dPsum = partial(
749
+ self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,
750
+ tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,
751
+ tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q
752
+ )
753
+ compute_one_m_block = partial(
754
+ self.compute_one_m_block, mma_params=mma_params,
755
+ smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
756
+ load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
757
+ m_block_max=m_block_max,
758
+ softmax_scale_log2=softmax_scale_log2,
759
+ )
760
+
761
+ # ///////////////////////////////////////////////////////////////////////////////
762
+ # Prologue
763
+ # ///////////////////////////////////////////////////////////////////////////////
764
+ # Start async loads of the last mn-tile, where we take care of the mn residue
765
+ self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k,
766
+ headdim=d_head_v)
767
+ if cutlass.const_expr(self.V_in_regs):
768
+ cute.arch.cp_async_commit_group()
769
+ self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k,
770
+ headdim=d_head)
771
+ cute.arch.cp_async_commit_group()
772
+
773
+ if cutlass.const_expr(self.V_in_regs):
774
+ cute.arch.cp_async_wait_group(1)
775
+ cute.arch.barrier()
776
+ tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)
777
+ cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)
778
+ # Sync to avoid loading Q to smem_q, which overlaps with smem_v
779
+ cute.arch.barrier()
780
+
781
+ m_block = m_block_min
782
+ assert self.num_stages_Q >= self.num_stages_dO
783
+ for stage in cutlass.range_constexpr(self.num_stages_Q):
784
+ if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):
785
+ if stage == 0 or m_block + stage < m_block_max:
786
+ load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)
787
+ cute.arch.cp_async_commit_group()
788
+ if cutlass.const_expr(stage < self.num_stages_dO):
789
+ if stage == 0 or m_block + stage < m_block_max:
790
+ load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)
791
+ cute.arch.cp_async_commit_group()
792
+
793
+ # ///////////////////////////////////////////////////////////////////////////////
794
+ # Mainloop
795
+ # ///////////////////////////////////////////////////////////////////////////////
796
+ # Start processing of the first n-block.
797
+ mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k)
798
+ mask_fn = partial(
799
+ mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
800
+ mask_seqlen=True, mask_causal=self.is_causal
801
+ )
802
+ smem_pipe_read_q = cutlass.Int32(0)
803
+ smem_pipe_read_do = cutlass.Int32(0)
804
+ smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)
805
+ smem_pipe_write_do = cutlass.Int32(0)
806
+ for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):
807
+ compute_one_m_block(
808
+ m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,
809
+ mask_fn=mask_fn,
810
+ )
811
+ smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)
812
+ smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)
813
+ smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)
814
+ smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)
815
+
816
+ # ///////////////////////////////////////////////////////////////////////////////
817
+ # Epilogue
818
+ # ///////////////////////////////////////////////////////////////////////////////
819
+ # If GQA, we scale dK in the postprocessing kernel instead
820
+ if cutlass.const_expr(self.qhead_per_kvhead == 1):
821
+ acc_dK.store(acc_dK.load() * softmax_scale)
822
+ # reuse sK and sV data iterator
823
+ sdK = cute.make_tensor(sK.iterator, sK_layout)
824
+ sdV = cute.make_tensor(sV.iterator, sV_layout)
825
+ self.epilogue(
826
+ acc_dK, acc_dV, mdK, mdV, sdK, sdV,
827
+ gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,
828
+ tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v
829
+ )
830
+
831
+ @cute.jit
832
+ def compute_one_m_block(
833
+ self,
834
+ m_block: cutlass.Int32,
835
+ smem_pipe_read_q: cutlass.Int32,
836
+ smem_pipe_read_do: cutlass.Int32,
837
+ smem_pipe_write_q: cutlass.Int32,
838
+ smem_pipe_write_do: cutlass.Int32,
839
+ mma_params: SimpleNamespace,
840
+ smem_copy_params: SimpleNamespace,
841
+ gmem_copy_params: SimpleNamespace,
842
+ load_Q_LSE: Callable,
843
+ load_dO_dPsum: Callable,
844
+ m_block_max: cutlass.Int32,
845
+ softmax_scale_log2: cutlass.Float32,
846
+ mask_fn: Optional[Callable] = None,
847
+ ):
848
+ def load_Q_next():
849
+ m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)
850
+ if m_block_next < m_block_max:
851
+ load_Q_LSE(m_block_next, smem_pipe_write_q)
852
+ cute.arch.cp_async_commit_group()
853
+
854
+ def load_dO_next():
855
+ if m_block + self.num_stages_dO < m_block_max:
856
+ load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)
857
+ cute.arch.cp_async_commit_group()
858
+
859
+ # MMA S
860
+ acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(
861
+ (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size)
862
+ )
863
+ acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
864
+ acc_S.fill(0.0)
865
+ cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)
866
+ cute.arch.barrier()
867
+ sm80_utils.gemm(
868
+ mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,
869
+ smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
870
+ smem_copy_params.tSsK,
871
+ smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
872
+ swap_AB=self.SdP_swapAB,
873
+ )
874
+ tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
875
+ cute.autovec_copy(
876
+ smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
877
+ )
878
+ if cutlass.const_expr(mask_fn is not None):
879
+ mask_fn(acc_S, m_block=m_block)
880
+ acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
881
+ bidx = 0
882
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
883
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
884
+ assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
885
+ for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
886
+ acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))
887
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
888
+
889
+ # MMA dP
890
+ acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
891
+ acc_dP.fill(0.0)
892
+ cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)
893
+ cute.arch.barrier()
894
+ sm80_utils.gemm(
895
+ mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,
896
+ smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
897
+ smem_copy_params.tdPsV,
898
+ smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
899
+ hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,
900
+ swap_AB=self.SdP_swapAB,
901
+ )
902
+ tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])
903
+ cute.autovec_copy(
904
+ smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
905
+ )
906
+ acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)
907
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
908
+ assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
909
+ for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
910
+ acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
911
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
912
+ rP = cute.make_fragment_like(acc_S, self.dtype)
913
+ rP.store(acc_S.load().to(self.dtype))
914
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
915
+ tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP) # ((Atom,AtomNum), MMA_N, MMA_N)
916
+ cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)
917
+ rdS = cute.make_fragment_like(acc_dP, self.dtype)
918
+ rdS.store(acc_dP.load().to(self.dtype))
919
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
920
+ cute.arch.barrier() # Make sure P is written
921
+ # For hdim 64, It's faster to write to smem_dS first before the dV gemm
922
+ if cutlass.const_expr(not self.Mma_dKV_is_RS):
923
+ tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
924
+ cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
925
+ if cutlass.const_expr(self.Mma_dKV_is_RS):
926
+ tdVrP = layout_utils.reshape_acc_to_frgA(rP)
927
+ else:
928
+ tdVrP = mma_params.tdVrP
929
+
930
+ # MMA dK
931
+ sm80_utils.gemm(
932
+ mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,
933
+ smem_copy_params.tdVsPt,
934
+ smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
935
+ smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
936
+ A_in_regs=self.Mma_dKV_is_RS,
937
+ swap_AB=self.dKV_swapAB,
938
+ )
939
+ # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV)
940
+ cute.arch.barrier() # Make sure dS is written
941
+
942
+ # MMA dQ
943
+ def dQ_mma(hook_fn):
944
+ acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(
945
+ (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size)
946
+ )
947
+ acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)
948
+ acc_dQ.fill(0.0)
949
+ sm80_utils.gemm(
950
+ mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,
951
+ smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,
952
+ smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,
953
+ swap_AB=self.dQ_swapAB,
954
+ hook_fn=hook_fn
955
+ )
956
+ # ((1, 1), num_elements)
957
+ acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)
958
+ tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]
959
+ assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)
960
+ for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):
961
+ utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))
962
+ # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1])
963
+ # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ)
964
+
965
+ # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
966
+ if cutlass.const_expr(self.num_stages_Q > 1):
967
+ dQ_mma(load_dO_next)
968
+
969
+ # MMA dK
970
+ if cutlass.const_expr(self.Mma_dKV_is_RS):
971
+ tdVrP = layout_utils.reshape_acc_to_frgA(rdS)
972
+ else:
973
+ tdKrdS = mma_params.tdKrdS
974
+ sm80_utils.gemm(
975
+ mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,
976
+ smem_copy_params.tdKsdSt,
977
+ smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
978
+ smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
979
+ A_in_regs=self.Mma_dKV_is_RS,
980
+ swap_AB=self.dKV_swapAB,
981
+ hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,
982
+ )
983
+ # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK)
984
+ if cutlass.const_expr(self.num_stages_Q == 1):
985
+ cute.arch.barrier()
986
+ dQ_mma(load_Q_next)
987
+
988
    @cute.jit
    def epilogue(
        self,
        acc_dK: cute.Tensor,
        acc_dV: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        sdK: cute.Tensor,
        sdV: cute.Tensor,
        gmem_tiled_copy_dK: cute.TiledCopy,
        gmem_tiled_copy_dV: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        n_block: cutlass.Int32,
        num_head: cutlass.Int32,
        batch_size: cutlass.Int32,
        seqlen: SeqlenInfoQK,
        d_head: cutlass.Int32,
        d_head_v: cutlass.Int32
    ):
        """Write the accumulated dK/dV of this n-block back to global memory.

        Two paths:
        - qhead_per_kvhead == 1 (MHA): convert the fp32 accumulators to the
          output dtype, stage them through smem (sdK/sdV, which alias sK/sV),
          reload with wide vectorized reads, and write to gmem with seqlen
          and head-dim predication.
        - qhead_per_kvhead > 1 (GQA): atomically add the fp32 accumulators
          straight into gmem dK/dV accumulation buffers (no smem staging).

        NOTE(review): the ``batch_size`` parameter is used as the batch
        *index* below; the name appears to be a misnomer.
        """
        # Convert fp32 accumulators to the output dtype in registers.
        rdV = cute.make_fragment_like(acc_dV, self.dtype)
        rdV.store(acc_dV.load().to(self.dtype))
        rdK = cute.make_fragment_like(acc_dK, self.dtype)
        rdK.store(acc_dK.load().to(self.dtype))
        gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)
        gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)

        batch_idx = batch_size
        # Map the query head to its KV head unless heads are packed (GQA packing).
        head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

        if cutlass.const_expr(self.qhead_per_kvhead == 1):
            # Make sure all threads have finished reading K and V, otherwise we get racy dQ
            # because smem_q could be changed.
            cute.arch.barrier()
            # smem copy atom for dKV
            smem_copy_atom_dKV = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
            )
            smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)
            taccdVrdV = smem_thr_copy_dKV.retile(rdV)
            taccdKrdK = smem_thr_copy_dKV.retile(rdK)
            taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)
            taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)
            # copy acc O from rmem to smem with the smem copy atom
            cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
            cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)

            # Slice the global dK/dV tensors down to (seqlen_k, head_dim) for
            # this batch/head; varlen layouts are addressed via domain_offset.
            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
                mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]
            else:
                mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]

            blkdK_shape = (self.n_block_size, self.head_dim_padded)
            blkdV_shape = (self.n_block_size, self.head_dim_v_padded)
            gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))
            gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))
            tdKsdK = gmem_thr_copy_dK.partition_S(sdK)
            tdKgdK = gmem_thr_copy_dK.partition_D(gdK)
            tdVsdV = gmem_thr_copy_dV.partition_S(sdV)
            tdVgdV = gmem_thr_copy_dV.partition_D(gdV)
            tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)
            tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)
            # sync so all smem stores above are visible before reloading.
            cute.arch.barrier()
            # load acc dK and dV from smem to rmem for wider vectorization
            # Need to check OOB when reading from smem if kBlockN isn't evenly tiled
            # TODO
            cute.autovec_copy(tdKsdK, tdKrdK)
            cute.autovec_copy(tdVsdV, tdVrdV)

            # Identity tensors give per-thread (row, col) coordinates for
            # seqlen / head-dim predication on the gmem stores.
            cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
            tdKcdK = gmem_thr_copy_dK.partition_S(cdK)
            t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)
            if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
                # Same padded head dims: dV can reuse dK's coordinates.
                tdVcdV = tdKcdK
                t0dVcdV = t0dKcdK
            else:
                cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
                tdVcdV = gmem_thr_copy_dV.partition_S(cdV)
                t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)
            tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)
            if cutlass.const_expr(self.same_hdim_kv):
                tdVpdV = tdKpdK
            else:
                tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)
            # copy acc dK and acc_dV from rmem to gmem; the row guard uses
            # thread-0 coordinates (compile-time known) minus this thread's
            # offset, as elsewhere in this file.
            for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):
                if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:
                    cute.copy(
                        gmem_tiled_copy_dK,
                        tdKrdK[None, rest_m, None],
                        tdKgdK[None, rest_m, None],
                        pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
            for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):
                if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:
                    cute.copy(
                        gmem_tiled_copy_dV,
                        tdVrdV[None, rest_m, None],
                        tdVgdV[None, rest_m, None],
                        pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,
                    )

        else:  # qhead_per_kvhead > 1, do atomic add
            # For Sm90, we need to sync to avoid racy writes to smem_q
            # For Sm80, we don't need to sync since we're not touching smem
            # NOTE(review): head_idx_kv is recomputed here; it is identical to
            # the value computed above this if/else.
            head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
                mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]
            else:
                # Varlen accumulation buffers are flat per head; offsets are in
                # elements, hence the multiplication by the padded head dims.
                padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size
                mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])
                mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None])

            gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))
            gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))
            tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)
            tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)
            acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)
            acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)
            assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)
            assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)
            # fp32 atomic adds into the shared accumulation buffers; multiple
            # query heads mapping to one KV head may update the same locations.
            for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):
                utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))
            for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):
                utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))
1117
+ @cute.jit
1118
+ def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):
1119
+ return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0
1120
+
1121
    @cute.jit
    def load_K(
        self,
        gmem_thr_copy: cute.TiledCopy,
        tKgK: cute.Tensor,
        tKsK: cute.Tensor,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
        headdim: cutlass.Int32,
    ):
        """Issue predicated cp.async copies of one K n-block from gmem to smem.

        Rows beyond ``seqlen`` and columns beyond ``headdim`` are masked out
        via per-element predicates so out-of-bounds gmem is never touched.

        :param gmem_thr_copy: this thread's slice of the gmem->smem tiled copy
        :param tKgK: thread-partitioned gmem source tile
        :param tKsK: thread-partitioned smem destination tile
        :param block: n-block index (used to offset the seqlen bound)
        :param seqlen: total key sequence length
        :param headdim: actual (unpadded) head dimension
        """
        # Identity tensor gives each thread its (row, col) coordinates within the tile.
        cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
        tKcK = gmem_thr_copy.partition_S(cK)
        t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)
        # Column (head-dim) predicate, shared across all n iterations.
        tKpK = utils.predicate_k(tKcK, limit=headdim)
        for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
            if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:
                # Instead of using tKcK, we using t0KcK and subtract the offset from the limit
                # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
                predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]
                predicate = cute.make_fragment_like(tKpK[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        # Combine head-dim and seqlen predicates per element.
                        predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
                cute.copy(
                    gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,
                )
        # We need to clear the sK smem tiles since we'll use sKt for mma_dq
1150
    @cute.jit
    def load_V(
        self,
        gmem_thr_copy: cute.TiledCopy,
        tVgV: cute.Tensor,
        tVsV: cute.Tensor,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
        headdim: cutlass.Int32,
    ):
        """Issue predicated cp.async copies of one V n-block from gmem to smem.

        Mirrors :meth:`load_K` but uses the (possibly different) padded V head
        dimension. Rows beyond ``seqlen`` and columns beyond ``headdim`` are
        masked out via per-element predicates.

        :param gmem_thr_copy: this thread's slice of the gmem->smem tiled copy
        :param tVgV: thread-partitioned gmem source tile
        :param tVsV: thread-partitioned smem destination tile
        :param block: n-block index (used to offset the seqlen bound)
        :param seqlen: total key/value sequence length
        :param headdim: actual (unpadded) V head dimension
        """
        # Identity tensor gives each thread its (row, col) coordinates within the tile.
        cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
        tVcV = gmem_thr_copy.partition_S(cV)
        t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)
        # Column (head-dim) predicate, shared across all n iterations.
        tVpV = utils.predicate_k(tVcV, limit=headdim)
        for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
            if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
                # Instead of using tVcV, we using t0VcV and subtract the offset from the limit
                # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time.
                predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]
                predicate = cute.make_fragment_like(tVpV[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        # Combine head-dim and seqlen predicates per element.
                        predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
                cute.copy(
                    gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,
                )
1178
+ @cute.jit
1179
+ def load_Q_LSE(
1180
+ self,
1181
+ gmem_tiled_copy_Q: cute.TiledCopy,
1182
+ gmem_tiled_copy_LSE: cute.TiledCopy,
1183
+ tQgQ: cute.Tensor,
1184
+ tQsQ: cute.Tensor,
1185
+ tQcQ: cute.Tensor,
1186
+ t0QcQ: cute.Tensor,
1187
+ tQpQ: cute.Tensor,
1188
+ tLSEgLSE: cute.Tensor,
1189
+ tLSEsLSE: cute.Tensor,
1190
+ tLSEcLSE: cute.Tensor,
1191
+ block: cutlass.Int32,
1192
+ smem_pipe_write_q: cutlass.Int32,
1193
+ seqlen: cutlass.Int32,
1194
+ ):
1195
+ for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
1196
+ # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
1197
+ if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:
1198
+ # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit
1199
+ # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
1200
+ predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]
1201
+ predicate = cute.make_fragment_like(tQpQ[None, 0, None])
1202
+ for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
1203
+ for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
1204
+ predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
1205
+ cute.copy(
1206
+ gmem_tiled_copy_Q,
1207
+ tQgQ[None, m, None, block],
1208
+ tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q) > 1 else 0],
1209
+ pred=predicate,
1210
+ )
1211
+ # We need to clear the sQ smem tiles since we'll use sQt for mma_dK
1212
+ # We made sure LSE length is padded so we read `kBlockM` elements so that all
1213
+ # elements in sLSE are filled. Without this we might have uninitialized sLSE values.
1214
+ for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):
1215
+ if tLSEcLSE[0, m][0] < self.m_block_size:
1216
+ cute.copy(
1217
+ gmem_tiled_copy_LSE,
1218
+ tLSEgLSE[None, m, block],
1219
+ tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
1220
+ )
1221
+
1222
    @cute.jit
    def load_dO_dPsum(
        self,
        gmem_tiled_copy_dO: cute.TiledCopy,
        gmem_tiled_copy_dPsum: cute.TiledCopy,
        tdOgdO: cute.Tensor,
        tdOsdO: cute.Tensor,
        tdOcdO: cute.Tensor,
        t0dOcdO: cute.Tensor,
        tdOpdO: cute.Tensor,
        tdPsumgdPsum: cute.Tensor,
        tdPsumsdPsum: cute.Tensor,
        tdPsumcdPsum: cute.Tensor,
        block: cutlass.Int32,
        smem_pipe_write_q: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Issue predicated cp.async copies of one dO m-block and its dPsum
        rows from gmem into the dO smem pipeline.

        Mirrors :meth:`load_Q_LSE` for the dO/dPsum pair.

        NOTE(review): despite its name, ``smem_pipe_write_q`` indexes the dO
        pipeline here (bounded by num_stages_dO).
        NOTE(review): the predicate uses ``self.check_hdim_oob``; since dO has
        the V head dim, ``check_hdim_v_oob`` may be intended — confirm.
        """
        for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):
            # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
            if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:
                # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit
                # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time.
                predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]
                predicate = cute.make_fragment_like(tdOpdO[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        # Combine head-dim and seqlen predicates per element.
                        predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
                cute.copy(
                    gmem_tiled_copy_dO,
                    tdOgdO[None, m, None, block],
                    tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
                    pred=predicate,
                )
        # dPsum rows: padded so we can read a full `kBlockM` worth of elements,
        # avoiding uninitialized smem values.
        for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):
            if tdPsumcdPsum[0, m][0] < self.m_block_size:
                cute.copy(
                    gmem_tiled_copy_dPsum,
                    tdPsumgdPsum[None, m, block],
                    tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
                )
build/torch-cuda/flash_bwd_postprocess.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
3
+ # from Cutlass C++ to Cute-DSL.
4
+ import math
5
+ from typing import Callable, Optional, Type, Literal
6
+
7
+ import cuda.bindings.driver as cuda
8
+
9
+ import cutlass
10
+ import cutlass.cute as cute
11
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
12
+ import cutlass.utils.blackwell_helpers as sm100_utils_basic
13
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
14
+ from cutlass import Float32, const_expr
15
+ from cutlass.utils import LayoutEnum
16
+
17
+ from .quack import copy_utils
18
+ from .quack import layout_utils
19
+ from .quack import sm90_utils
20
+
21
+ from . import utils
22
+ from .cute_dsl_utils import assume_tensor_aligned
23
+ from . import ampere_helpers as sm80_utils
24
+ from .seqlen_info import SeqlenInfoQK
25
+ import cutlass.cute.nvgpu.tcgen05 as tcgen05
26
+ from .quack.cute_dsl_utils import ParamsBase
27
+ from .tile_scheduler import (
28
+ SingleTileScheduler,
29
+ SingleTileVarlenScheduler,
30
+ TileSchedulerArguments,
31
+ )
32
+
33
+
34
+ class FlashAttentionBackwardPostprocess:
35
+ def __init__(
36
+ self,
37
+ dtype: Type[cutlass.Numeric],
38
+ head_dim: int,
39
+ arch: Literal[80, 90, 100],
40
+ tile_m: int = 128,
41
+ num_threads: int = 256,
42
+ AtomLayoutMdQ: int = 1,
43
+ dQ_swapAB: bool = False,
44
+ use_2cta_instrs: bool = False,
45
+ cluster_size: int = 1, # for varlen offsets
46
+ ):
47
+ """
48
+ :param head_dim: head dimension
49
+ :type head_dim: int
50
+ :param tile_m: m block size
51
+ :type tile_m: int
52
+ """
53
+ self.dtype = dtype
54
+ self.tile_m = tile_m
55
+ assert arch // 10 in [8, 9, 10, 11], (
56
+ "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x) are supported"
57
+ )
58
+ self.arch = arch
59
+ # padding head_dim to a multiple of 32 as k_block_size
60
+ hdim_multiple_of = 32
61
+ self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
62
+ self.check_hdim_oob = head_dim != self.tile_hdim
63
+ self.num_threads = num_threads
64
+ self.AtomLayoutMdQ = AtomLayoutMdQ
65
+ self.dQ_swapAB = dQ_swapAB
66
+ self.use_2cta_instrs = use_2cta_instrs and arch == 100 and head_dim != 64
67
+ self.cluster_size = cluster_size
68
+
69
+ @staticmethod
70
+ def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
71
+ """Check if the kernel can be implemented with the given parameters.
72
+
73
+ :param dtype: data type
74
+ :type dtype: cutlass.Numeric
75
+ :param head_dim: head dimension
76
+ :type head_dim: int
77
+ :param tile_m: m block size
78
+ :type tile_m: int
79
+
80
+ :return: True if the kernel can be implemented, False otherwise
81
+ :rtype: bool
82
+ """
83
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
84
+ return False
85
+ if head_dim % 8 != 0:
86
+ return False
87
+ if num_threads % 32 != 0:
88
+ return False
89
+ return True
90
+
91
+ def _get_tiled_mma(self):
92
+ if const_expr(self.arch == 80):
93
+ num_mma_warps = self.num_threads // 32
94
+ atom_layout_dQ = (
95
+ (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
96
+ if const_expr(not self.dQ_swapAB)
97
+ else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)
98
+ )
99
+ tiled_mma = cute.make_tiled_mma(
100
+ warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),
101
+ atom_layout_dQ,
102
+ permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
103
+ )
104
+ elif const_expr(self.arch == 90):
105
+ num_mma_warp_groups = self.num_threads // 128
106
+ atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ)
107
+ tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
108
+ tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
109
+ self.dtype,
110
+ self.dtype,
111
+ warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum
112
+ warpgroup.OperandMajorMode.K,
113
+ Float32,
114
+ atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1])
115
+ + (1,),
116
+ tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
117
+ )
118
+ else:
119
+ cta_group = tcgen05.CtaGroup.ONE
120
+ tiled_mma = sm100_utils_basic.make_trivial_tiled_mma(
121
+ self.dtype,
122
+ tcgen05.OperandMajorMode.MN, # dS_major_mode
123
+ tcgen05.OperandMajorMode.MN, # Kt_major_mode
124
+ Float32,
125
+ cta_group,
126
+ (self.tile_m, self.tile_hdim),
127
+ )
128
+ if const_expr(self.arch in [80, 90]):
129
+ assert self.num_threads == tiled_mma.size
130
+ return tiled_mma
131
+
132
+ def _setup_attributes(self):
133
+ # ///////////////////////////////////////////////////////////////////////////////
134
+ # GMEM Tiled copy:
135
+ # ///////////////////////////////////////////////////////////////////////////////
136
+ # Thread layouts for copies
137
+ universal_copy_bits = 128
138
+ async_copy_elems_accum = universal_copy_bits // Float32.width
139
+ atom_async_copy_accum = cute.make_copy_atom(
140
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
141
+ Float32,
142
+ num_bits_per_copy=universal_copy_bits,
143
+ )
144
+ # We don't do bound checking for the gmem -> smem load so we just assert here.
145
+ assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0
146
+ self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
147
+ atom_async_copy_accum,
148
+ cute.make_layout(self.num_threads),
149
+ cute.make_layout(async_copy_elems_accum),
150
+ )
151
+ num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4
152
+ if const_expr(self.arch == 80):
153
+ self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
154
+ Float32, self.num_threads, num_s2r_copy_elems
155
+ )
156
+ self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
157
+ elif const_expr(self.arch == 90):
158
+ num_threads_per_warp_group = 128
159
+ num_mma_warp_groups = self.num_threads // 128
160
+ self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
161
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
162
+ cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout
163
+ cute.make_layout(128 // Float32.width), # val_layout
164
+ )
165
+ self.sdQaccum_layout = cute.make_layout(
166
+ (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups)
167
+ )
168
+ else:
169
+ self.dQ_reduce_ncol = 32
170
+ dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
171
+ assert self.num_threads == 128 # TODO: currently hard-coded
172
+ self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
173
+ Float32, self.num_threads, num_s2r_copy_elems
174
+ )
175
+ self.sdQaccum_layout = cute.make_layout(
176
+ (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage)
177
+ )
178
+
179
+ num_copy_elems = 128 // self.dtype.width
180
+ threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems
181
+ self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(
182
+ self.dtype, threads_per_row, self.num_threads, num_copy_elems
183
+ )
184
+ # ///////////////////////////////////////////////////////////////////////////////
185
+ # Shared memory layout: dQ
186
+ # ///////////////////////////////////////////////////////////////////////////////
187
+ # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
188
+ # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
189
+ # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
190
+ mma_shape_n = self.tiled_mma.get_tile_size(1)
191
+ if const_expr(self.arch == 80):
192
+ sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
193
+ self.sdQ_layout = cute.tile_to_shape(
194
+ sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
195
+ )
196
+ elif const_expr(self.arch == 90):
197
+ self.sdQ_layout = sm90_utils.make_smem_layout(
198
+ self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)
199
+ )
200
+ else:
201
+ # TODO: this is hard-coded for hdim 128
202
+ self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
203
+ self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1
204
+ )
205
+
206
+ @cute.jit
207
+ def __call__(
208
+ self,
209
+ mdQaccum: cute.Tensor,
210
+ mdQ: cute.Tensor,
211
+ scale: cutlass.Float32,
212
+ mCuSeqlensQ: Optional[cute.Tensor],
213
+ mSeqUsedQ: Optional[cute.Tensor],
214
+ stream: cuda.CUstream,
215
+ ):
216
+ # Get the data type and check if it is fp16 or bf16
217
+ if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
218
+ raise TypeError("Only Float16 or BFloat16 is supported")
219
+ if const_expr(mdQaccum is not None):
220
+ if const_expr(mdQaccum.element_type not in [cutlass.Float32]):
221
+ raise TypeError("dQaccum tensor must be Float32")
222
+
223
+ mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)]
224
+
225
+ self.tiled_mma = self._get_tiled_mma()
226
+ self._setup_attributes()
227
+
228
+ smem_size = max(
229
+ cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout),
230
+ cute.size_in_bytes(self.dtype, self.sdQ_layout),
231
+ )
232
+
233
+ if const_expr(mCuSeqlensQ is not None):
234
+ TileScheduler = SingleTileVarlenScheduler
235
+ num_head = mdQ.shape[1]
236
+ num_batch = mCuSeqlensQ.shape[0] - 1
237
+ num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
238
+ else:
239
+ TileScheduler = SingleTileScheduler
240
+ num_head = mdQ.shape[2]
241
+ num_batch = mdQ.shape[0]
242
+ num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)
243
+
244
+ tile_sched_args = TileSchedulerArguments(
245
+ num_block=num_block,
246
+ num_head=num_head,
247
+ num_batch=num_batch,
248
+ num_splits=1,
249
+ seqlen_k=0,
250
+ headdim=mdQ.shape[2],
251
+ headdim_v=0,
252
+ total_q=mdQ.shape[0],
253
+ tile_shape_mn=(self.tile_m, 1),
254
+ mCuSeqlensQ=mCuSeqlensQ,
255
+ mSeqUsedQ=mSeqUsedQ,
256
+ )
257
+
258
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
259
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
260
+
261
+ # grid_dim: (m_block, num_head, batch_size)
262
+ self.kernel(
263
+ mdQaccum,
264
+ mdQ,
265
+ mCuSeqlensQ,
266
+ mSeqUsedQ,
267
+ scale,
268
+ self.tiled_mma,
269
+ self.dQ_swapAB,
270
+ self.sdQaccum_layout,
271
+ self.sdQ_layout,
272
+ self.g2s_tiled_copy_dQaccum,
273
+ self.s2r_tiled_copy_dQaccum,
274
+ self.gmem_tiled_copy_dQ,
275
+ tile_sched_params,
276
+ TileScheduler,
277
+ ).launch(
278
+ grid=grid_dim,
279
+ block=[self.num_threads, 1, 1],
280
+ smem=smem_size,
281
+ stream=stream,
282
+ )
283
+
284
+ @cute.kernel
285
+ def kernel(
286
+ self,
287
+ mdQaccum: cute.Tensor,
288
+ mdQ: cute.Tensor,
289
+ mCuSeqlensQ: Optional[cute.Tensor],
290
+ mSeqUsedQ: Optional[cute.Tensor],
291
+ scale: cutlass.Float32,
292
+ tiled_mma: cute.TiledMma,
293
+ dQ_swapAB: cutlass.Constexpr,
294
+ sdQaccum_layout: cute.Layout,
295
+ sdQ_layout: cute.ComposedLayout,
296
+ g2s_tiled_copy_dQaccum: cute.TiledCopy,
297
+ s2r_tiled_copy_dQaccum: cute.TiledCopy,
298
+ gmem_tiled_copy_dQ: cute.TiledCopy,
299
+ tile_sched_params: ParamsBase,
300
+ TileScheduler: cutlass.Constexpr[Callable],
301
+ ):
302
+ # ///////////////////////////////////////////////////////////////////////////////
303
+ # Get shared memory buffer
304
+ # ///////////////////////////////////////////////////////////////////////////////
305
+ smem = cutlass.utils.SmemAllocator()
306
+ sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
307
+ sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
308
+ if const_expr(self.arch in [80, 90]):
309
+ sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
310
+ else:
311
+ # extra stage dimension
312
+ sdQ = cute.make_tensor(
313
+ cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype),
314
+ sdQ_layout.outer,
315
+ )[None, None, 0]
316
+ sdQt = layout_utils.transpose_view(sdQ)
317
+
318
+ # Thread index, block index
319
+ tidx, _, _ = cute.arch.thread_idx()
320
+
321
+ tile_scheduler = TileScheduler.create(tile_sched_params)
322
+ work_tile = tile_scheduler.initial_work_tile_info()
323
+
324
+ m_block, head_idx, batch_idx, _ = work_tile.tile_idx
325
+
326
+ if work_tile.is_valid_tile:
327
+ # ///////////////////////////////////////////////////////////////////////////////
328
+ # Get the appropriate tiles for this thread block.
329
+ # ///////////////////////////////////////////////////////////////////////////////
330
+
331
+ seqlen = SeqlenInfoQK.create(
332
+ batch_idx,
333
+ mdQ.shape[1],
334
+ 0,
335
+ mCuSeqlensQ=mCuSeqlensQ,
336
+ mCuSeqlensK=None,
337
+ mSeqUsedQ=mSeqUsedQ,
338
+ mSeqUsedK=None,
339
+ tile_m=self.tile_m * self.cluster_size,
340
+ )
341
+ if const_expr(not seqlen.has_cu_seqlens_q):
342
+ mdQ_cur = mdQ[batch_idx, None, head_idx, None]
343
+ mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
344
+ head_dim = mdQ.shape[3]
345
+ else:
346
+ if cutlass.const_expr(self.arch >= 90):
347
+ padded_offset_q = seqlen.padded_offset_q
348
+ else:
349
+ padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
350
+ mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
351
+ mdQaccum_cur = cute.domain_offset(
352
+ (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
353
+ )
354
+ head_dim = mdQ.shape[2]
355
+
356
+ # HACK: Compiler doesn't seem to recognize that padding
357
+ # by padded_offset_q * self.tile_hdim keeps alignment
358
+ # since statically divisible by 4
359
+
360
+ mdQaccum_cur_ptr = cute.make_ptr(
361
+ dtype=mdQaccum_cur.element_type,
362
+ value=mdQaccum_cur.iterator.toint(),
363
+ mem_space=mdQaccum_cur.iterator.memspace,
364
+ assumed_align=mdQaccum.iterator.alignment,
365
+ )
366
+ mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)
367
+
368
+ gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
369
+ gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
370
+
371
+ seqlen_q = seqlen.seqlen_q
372
+ seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
373
+
374
+ if const_expr(self.arch == 100 and self.use_2cta_instrs):
375
+ # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
376
+ num_reduce_threads = self.num_threads
377
+ thr_mma_dsk = tiled_mma.get_slice(tidx)
378
+ dQacc_shape = thr_mma_dsk.partition_shape_C((self.tile_m, self.tile_hdim))
379
+ tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)
380
+ tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)
381
+
382
+ tmem_load_atom = cute.make_copy_atom(
383
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
384
+ )
385
+ tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
386
+ thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)
387
+
388
+ cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
389
+ tdQcdQ = thr_mma_dsk.partition_C(cdQ)
390
+ tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)
391
+ tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)
392
+
393
+ tiled_copy_accum = s2r_tiled_copy_dQaccum
394
+ g2s_thr_copy = tiled_copy_accum.get_slice(tidx)
395
+
396
+ # S -> R
397
+ tdQrdQ_fp32 = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)
398
+ tdQrdQ_s2r = cute.make_tensor(tdQrdQ_fp32.iterator, tdQrdQ_fp32.shape)
399
+
400
+ smem_copy_atom = sm100_utils_basic.get_smem_store_op(
401
+ LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld
402
+ )
403
+ r2s_tiled_copy = cute.make_tiled_copy(
404
+ smem_copy_atom,
405
+ layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,
406
+ tiler_mn=tiled_tmem_ld.tiler_mn,
407
+ )
408
+ tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))
409
+ tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)
410
+
411
+ num_stages = cute.size(tdQrdQ_fp32, mode=[1])
412
+ stage_stride = self.dQ_reduce_ncol
413
+ row_groups = 2
414
+ assert num_stages % row_groups == 0
415
+ assert num_reduce_threads % row_groups == 0
416
+ stage_groups = num_stages // row_groups
417
+ threads_per_row_group = num_reduce_threads // row_groups
418
+ stage_loads = tuple((row_group, row_group) for row_group in range(row_groups))
419
+ stage_iters = tuple(
420
+ (row_group, row_group * threads_per_row_group)
421
+ for row_group in range(row_groups)
422
+ )
423
+ s2r_lane = tidx % threads_per_row_group
424
+ s2r_buf = tidx // threads_per_row_group
425
+
426
+ gdQaccum_layout_g2s = cute.make_layout(
427
+ shape=(self.tile_m * self.dQ_reduce_ncol, 1), stride=(1, 0)
428
+ )
429
+ sdQaccum_g2s = g2s_thr_copy.partition_D(sdQaccum)
430
+
431
+ # G -> S
432
+ for stage_group in cutlass.range_constexpr(stage_groups):
433
+ for stage_offset, smem_buf in stage_loads:
434
+ stage_idx = stage_group + stage_offset * stage_groups
435
+ gdQaccum_stage = cute.local_tile(
436
+ gdQaccum,
437
+ (self.tile_m * self.dQ_reduce_ncol,),
438
+ (stage_idx,),
439
+ )
440
+ gdQaccum_stage_g2s = cute.make_tensor(
441
+ gdQaccum_stage.iterator,
442
+ gdQaccum_layout_g2s,
443
+ )
444
+ tdQgdQ = g2s_thr_copy.partition_S(gdQaccum_stage_g2s)
445
+ cute.copy(
446
+ g2s_thr_copy,
447
+ tdQgdQ[None, None, 0],
448
+ sdQaccum_g2s[None, None, smem_buf],
449
+ )
450
+
451
+ cute.arch.fence_view_async_shared()
452
+ cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)
453
+
454
+ # S -> R
455
+ for stage_offset, lane_offset in stage_iters:
456
+ stage_idx = stage_group + stage_offset * stage_groups
457
+ s2r_src_tidx = s2r_lane + lane_offset
458
+ s2r_thr_copy = tiled_copy_accum.get_slice(s2r_src_tidx)
459
+ sdQaccum_src = s2r_thr_copy.partition_S(sdQaccum)[None, None, s2r_buf]
460
+
461
+ tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage_idx, None, None]
462
+ tdQrdQ_r2s_cpy = cute.make_tensor(
463
+ tdQrdQ_s2r_cpy.iterator, cute.make_layout(sdQaccum_src.shape)
464
+ )
465
+ cute.copy(s2r_thr_copy, sdQaccum_src, tdQrdQ_r2s_cpy)
466
+ cute.arch.fence_view_async_shared()
467
+ cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)
468
+
469
+ # R -> S
470
+ stage_lo = stage_idx % stage_stride
471
+ stage_hi = stage_idx // stage_stride
472
+ tdQrdQ_r2s_cpy = cute.make_tensor(
473
+ cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),
474
+ tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].shape,
475
+ )
476
+ dQ_vec = tdQrdQ_r2s_cpy.load() * scale
477
+ tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].store(
478
+ dQ_vec.to(self.dtype)
479
+ )
480
+
481
+ # R -> S
482
+ cute.copy(
483
+ r2s_tiled_copy,
484
+ tdQrdQ_r2s[None, None, None, 0],
485
+ tdQsdQ_r2s[None, None, None, 0],
486
+ )
487
+ cute.arch.fence_view_async_shared()
488
+ cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)
489
+ else:
490
+ # Step 1: load dQaccum from gmem to smem
491
+ g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx)
492
+ tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum)
493
+ tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat)
494
+ cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s)
495
+ cute.arch.cp_async_commit_group()
496
+ cute.arch.cp_async_wait_group(0)
497
+ cute.arch.barrier()
498
+
499
+ # Step 2: load dQ from smem to rmem
500
+ s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx)
501
+ tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
502
+ tile_shape = (self.tile_m, self.tile_hdim)
503
+ acc = None
504
+ tiled_copy_t2r = None
505
+ if const_expr(self.arch in [80, 90]):
506
+ acc_shape = tiled_mma.partition_shape_C(
507
+ tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
508
+ )
509
+ acc = cute.make_fragment(acc_shape, cutlass.Float32)
510
+ assert cute.size(acc) == cute.size(tdQsdQaccum)
511
+ else:
512
+ thr_mma = tiled_mma.get_slice(0) # 1-CTA
513
+ dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim))
514
+ tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape)
515
+ tdQcdQ = thr_mma.partition_C(
516
+ cute.make_identity_tensor((self.tile_m, self.tile_hdim))
517
+ )
518
+ tmem_load_atom = cute.make_copy_atom(
519
+ tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)),
520
+ Float32,
521
+ )
522
+ tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)
523
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
524
+ tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
525
+ acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
526
+ tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape))
527
+ cute.autovec_copy(tdQsdQaccum, tdQrdQaccum)
528
+ # Convert tdQrdQaccum from fp32 to fp16/bf16
529
+ rdQ = cute.make_fragment_like(acc, self.dtype)
530
+ rdQ.store((acc.load() * scale).to(self.dtype))
531
+
532
+ # Step 3: Copy dQ from register to smem
533
+ cute.arch.barrier() # make sure all threads have finished loading dQaccum
534
+ if const_expr(self.arch in [80, 90]):
535
+ copy_atom_r2s_dQ = utils.get_smem_store_atom(
536
+ self.arch, self.dtype, transpose=self.dQ_swapAB
537
+ )
538
+ tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma)
539
+ else:
540
+ # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op(
541
+ # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r,
542
+ # )
543
+ # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r)
544
+ thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads
545
+ val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width))
546
+ copy_atom_r2s_dQ = cute.make_copy_atom(
547
+ cute.nvgpu.CopyUniversalOp(),
548
+ self.dtype,
549
+ num_bits_per_copy=128,
550
+ )
551
+ tiled_copy_r2s_dQ = cute.make_tiled_copy_tv(
552
+ copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ
553
+ )
554
+ thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
555
+ cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
556
+ if const_expr(self.arch in [80, 90]):
557
+ taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
558
+ else:
559
+ taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
560
+ taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape)
561
+ taccdQsdQ = thr_copy_r2s_dQ.partition_D(
562
+ sdQ if const_expr(not self.dQ_swapAB) else sdQt
563
+ )
564
+ cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ)
565
+
566
+ # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
567
+ cute.arch.barrier() # make sure all smem stores are done
568
+ gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx)
569
+ tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ)
570
+ tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ)
571
+ tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype)
572
+ # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled
573
+ cute.autovec_copy(tdQsdQ, tdQrdQ)
574
+
575
+ # Step 5: Copy dQ from register to gmem
576
+ tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ)
577
+ tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim)
578
+ for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True):
579
+ if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m:
580
+ cute.copy(
581
+ gmem_tiled_copy_dQ,
582
+ tdQrdQ[None, rest_m, None],
583
+ tdQgdQ[None, rest_m, None],
584
+ pred=tdQpdQ[None, rest_m, None],
585
+ )
build/torch-cuda/flash_bwd_preprocess.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
3
+ # from Cutlass C++ to Cute-DSL.
4
+ import math
5
+ import operator
6
+ from typing import Callable, Type, Optional, Literal
7
+
8
+ import cuda.bindings.driver as cuda
9
+
10
+ import cutlass
11
+ import cutlass.cute as cute
12
+ from cutlass import Float32
13
+
14
+ from .quack import copy_utils
15
+
16
+ from . import utils
17
+ from .cute_dsl_utils import assume_tensor_aligned
18
+ from .seqlen_info import SeqlenInfoQK
19
+ from .quack.cute_dsl_utils import ParamsBase
20
+ from .tile_scheduler import (
21
+ SingleTileScheduler,
22
+ SingleTileVarlenScheduler,
23
+ TileSchedulerArguments,
24
+ )
25
+
26
+
27
+ class FlashAttentionBackwardPreprocess:
28
+ def __init__(
29
+ self,
30
+ dtype: Type[cutlass.Numeric],
31
+ head_dim: int,
32
+ head_dim_v: int,
33
+ arch: Literal[80, 90, 100],
34
+ m_block_size: int = 128,
35
+ num_threads: int = 128,
36
+ ):
37
+ """
38
+ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
39
+ should be a multiple of 8.
40
+
41
+ :param head_dim: head dimension
42
+ :type head_dim: int
43
+ :param m_block_size: m block size
44
+ :type m_block_size: int
45
+ :param num_threads: number of threads
46
+ :type num_threads: int
47
+ """
48
+ self.dtype = dtype
49
+ self.m_block_size = m_block_size
50
+ self.arch = arch
51
+ # padding head_dim to a multiple of 32 as k_block_size
52
+ hdim_multiple_of = 32
53
+ self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
54
+ self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
55
+ self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded
56
+ self.num_threads = num_threads
57
+
58
+ @staticmethod
59
+ def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
60
+ """Check if the kernel can be implemented with the given parameters.
61
+
62
+ :param dtype: data type
63
+ :type dtype: cutlass.Numeric
64
+ :param head_dim: head dimension
65
+ :type head_dim: int
66
+ :param m_block_size: m block size
67
+ :type m_block_size: int
68
+ :param num_threads: number of threads
69
+ :type num_threads: int
70
+
71
+ :return: True if the kernel can be implemented, False otherwise
72
+ :rtype: bool
73
+ """
74
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
75
+ return False
76
+ if head_dim % 8 != 0:
77
+ return False
78
+ if num_threads % 32 != 0:
79
+ return False
80
+ if num_threads < m_block_size: # For multiplying lse with log2
81
+ return False
82
+ return True
83
+
84
+ def _setup_attributes(self):
85
+ # ///////////////////////////////////////////////////////////////////////////////
86
+ # GMEM Tiled copy:
87
+ # ///////////////////////////////////////////////////////////////////////////////
88
+ # Thread layouts for copies
89
+ # We want kBlockKGmem to be a power of 2 so that when we do the summing,
90
+ # it's just between threads in the same warp
91
+ gmem_k_block_size = (
92
+ 128
93
+ if self.head_dim_v_padded % 128 == 0
94
+ else (
95
+ 64
96
+ if self.head_dim_v_padded % 64 == 0
97
+ else (32 if self.head_dim_v_padded % 32 == 0 else 16)
98
+ )
99
+ )
100
+ num_copy_elems = 128 // self.dtype.width
101
+ threads_per_row = gmem_k_block_size // num_copy_elems
102
+ self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(
103
+ self.dtype, threads_per_row, self.num_threads, num_copy_elems
104
+ )
105
+ universal_copy_bits = 128
106
+ num_copy_elems_dQaccum = universal_copy_bits // Float32.width
107
+ assert (
108
+ self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum
109
+ ) % self.num_threads == 0
110
+ self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
111
+ Float32, self.num_threads, num_copy_elems_dQaccum
112
+ )
113
+
114
+ @cute.jit
115
+ def __call__(
116
+ self,
117
+ mO: cute.Tensor,
118
+ mdO: cute.Tensor,
119
+ mdPsum: cute.Tensor,
120
+ mLSE: Optional[cute.Tensor],
121
+ mLSElog2: Optional[cute.Tensor],
122
+ mdQaccum: Optional[cute.Tensor],
123
+ mCuSeqlensQ: Optional[cute.Tensor],
124
+ mSeqUsedQ: Optional[cute.Tensor],
125
+ stream: cuda.CUstream,
126
+ ):
127
+ # Get the data type and check if it is fp16 or bf16
128
+ if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
129
+ raise TypeError("All tensors must have the same data type")
130
+ if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
131
+ raise TypeError("Only Float16 or BFloat16 is supported")
132
+ if cutlass.const_expr(mdPsum.element_type not in [Float32]):
133
+ raise TypeError("dPsum tensor must be Float32")
134
+ if cutlass.const_expr(mdQaccum is not None):
135
+ if cutlass.const_expr(mdQaccum.element_type not in [Float32]):
136
+ raise TypeError("dQaccum tensor must be Float32")
137
+ if cutlass.const_expr(mLSE is not None):
138
+ assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
139
+ if cutlass.const_expr(mLSE.element_type not in [Float32]):
140
+ raise TypeError("LSE tensor must be Float32")
141
+ if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
142
+ raise TypeError("LSElog2 tensor must be Float32")
143
+
144
+ mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)]
145
+
146
+ self._setup_attributes()
147
+
148
+ if cutlass.const_expr(mCuSeqlensQ is not None):
149
+ TileScheduler = SingleTileVarlenScheduler
150
+ num_head = mO.shape[1]
151
+ num_batch = mCuSeqlensQ.shape[0] - 1
152
+ else:
153
+ TileScheduler = SingleTileScheduler
154
+ num_head = mO.shape[2]
155
+ num_batch = mO.shape[0]
156
+
157
+ tile_sched_args = TileSchedulerArguments(
158
+ num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
159
+ num_head=num_head,
160
+ num_batch=num_batch,
161
+ num_splits=1,
162
+ seqlen_k=0,
163
+ headdim=0,
164
+ headdim_v=mO.shape[2],
165
+ total_q=mO.shape[0],
166
+ tile_shape_mn=(self.m_block_size, 1),
167
+ mCuSeqlensQ=mCuSeqlensQ,
168
+ mSeqUsedQ=mSeqUsedQ,
169
+ )
170
+
171
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
172
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
173
+
174
+ self.kernel(
175
+ mO,
176
+ mdO,
177
+ mdPsum,
178
+ mLSE,
179
+ mLSElog2,
180
+ mdQaccum,
181
+ mCuSeqlensQ,
182
+ mSeqUsedQ,
183
+ self.gmem_tiled_copy_O,
184
+ self.gmem_tiled_copy_dQaccum,
185
+ tile_sched_params,
186
+ TileScheduler,
187
+ ).launch(
188
+ grid=grid_dim,
189
+ block=[self.num_threads, 1, 1],
190
+ stream=stream,
191
+ )
192
+
193
    @cute.kernel
    def kernel(
        self,
        mO: cute.Tensor,
        mdO: cute.Tensor,
        mdPsum: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mLSElog2: Optional[cute.Tensor],
        mdQaccum: Optional[cute.Tensor],
        mCuSeqlensQ: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_dQaccum: cute.TiledCopy,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
    ):
        """Backward-pass preprocess kernel.

        For each (m_block, head, batch) tile this kernel:
          * computes dPsum = rowsum(O * dO) and writes it to ``mdPsum``;
          * zero-initializes the ``mdQaccum`` accumulator (if provided);
          * rescales LSE by log2(e) into ``mLSElog2`` (if ``mLSE`` provided).

        ``mO``/``mdO`` are indexed as 4-D (batch, seq, head, hdim) in the
        fixed-length case and 3-D (total_seq, head, hdim) in the varlen
        (cu_seqlens) case — see the two indexing branches below.
        """
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()

        tile_scheduler = TileScheduler.create(tile_sched_params)
        work_tile = tile_scheduler.initial_work_tile_info()
        m_block, head_idx, batch_idx, _ = work_tile.tile_idx

        if work_tile.is_valid_tile:
            # ///////////////////////////////////////////////////////////////////////////////
            # Get the appropriate tiles for this thread block.
            # ///////////////////////////////////////////////////////////////////////////////
            seqlen = SeqlenInfoQK.create(
                batch_idx,
                mO.shape[1],
                0,
                mCuSeqlensQ=mCuSeqlensQ,
                mCuSeqlensK=None,
                mSeqUsedQ=mSeqUsedQ,
                mSeqUsedK=None,
            )

            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                # Fixed-length batch: slice out this (batch, head) plane.
                mO_cur = mO[batch_idx, None, head_idx, None]
                mdO_cur = mdO[batch_idx, None, head_idx, None]
                mdPsum_cur = mdPsum[batch_idx, head_idx, None]
                headdim_v = mO.shape[3]
            else:
                # Varlen: tensors are packed along the sequence dim; offset by
                # this batch's starting position from cu_seqlens.
                mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
                mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])

                # dPsum/LSElog2/dQaccum are stored with per-batch padding to a
                # multiple of m_block_size; round the packed offset accordingly.
                padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
                if cutlass.const_expr(self.arch >= 90):
                    padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
                mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
                headdim_v = mO.shape[2]

            blkOdO_shape = (self.m_block_size, self.head_dim_v_padded)
            # (m_block_size, head_dim_v)
            gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
            gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))

            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            # (CPY_Atom, CPY_M, CPY_K)
            tOgO = gmem_thr_copy_O.partition_S(gO)
            tOgdO = gmem_thr_copy_O.partition_S(gdO)

            # ///////////////////////////////////////////////////////////////////////////////
            # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
            # of tile_shape
            # ///////////////////////////////////////////////////////////////////////////////
            # Construct identity layout for KV
            cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
            tOcO = gmem_thr_copy_O.partition_S(cO)
            t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
            tOpO = utils.predicate_k(tOcO, limit=headdim_v)
            tOpdO = utils.predicate_k(tOcO, limit=headdim_v)

            seqlen_q = seqlen.seqlen_q
            seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)

            if cutlass.const_expr(mLSE is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mLSE_cur = mLSE[batch_idx, head_idx, None]
                else:
                    mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])

                gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
                # Threads mapping past the end of the sequence keep +inf.
                # NOTE(review): the sentinel here is +inf but the write below
                # zeroes on lse != -Float32.inf — confirm which infinity marks
                # an invalid/padded row.
                lse = Float32.inf
                if tidx < seqlen_q - m_block * self.m_block_size:
                    lse = gLSE[tidx]

            tOrO = cute.make_fragment_like(tOgO)
            tOrdO = cute.make_fragment_like(tOgdO)
            assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
            assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
            assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
            for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
                # Instead of using tOcO, we using t0OcO and subtract the offset from the limit
                # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
                if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
                    cute.copy(
                        gmem_thr_copy_O,
                        tOgO[None, m, None],
                        tOrO[None, m, None],
                        pred=tOpO[None, m, None]
                        if cutlass.const_expr(self.check_hdim_v_oob)
                        else None,
                    )
                    cute.copy(
                        gmem_thr_copy_O,
                        tOgdO[None, m, None],
                        tOrdO[None, m, None],
                        pred=tOpdO[None, m, None]
                        if cutlass.const_expr(self.check_hdim_v_oob)
                        else None,
                    )
            # Sum across the "k" dimension
            dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
                cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
            )
            # Finish the row reduction across the threads that share a row.
            threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
            assert cute.arch.WARP_SIZE % threads_per_row == 0
            dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row)
            dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32)
            dP_sum.store(dpsum)

            # Write dPsum from rmem -> gmem
            gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,))
            # Only the thread corresponding to column 0 writes out the dPsum to gmem
            if tOcO[0, 0, 0][1] == 0:
                for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
                    row = tOcO[0, m, 0][0]
                    gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0

            # Clear dQaccum
            if cutlass.const_expr(mdQaccum is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
                else:
                    mdQaccum_cur = cute.domain_offset(
                        (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
                    )

                    # HACK: Compiler doesn't seem to recognize that padding
                    # by padded_offset_q * self.head_dim_padded keeps alignment
                    # since statically divisible by 4

                    mdQaccum_cur_ptr = cute.make_ptr(
                        dtype=mdQaccum_cur.element_type,
                        value=mdQaccum_cur.iterator.toint(),
                        mem_space=mdQaccum_cur.iterator.memspace,
                        assumed_align=mdQaccum.iterator.alignment,
                    )
                    mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)

                blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
                gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
                gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
                tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
                # Zero-fill the accumulator tile so the backward pass can
                # atomically/additively accumulate into it.
                zero = cute.make_fragment_like(tdQgdQaccum)
                zero.fill(0.0)
                cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)

            if cutlass.const_expr(mLSE is not None):
                if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                    mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
                else:
                    mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])

                gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
                # Pre-scale LSE by log2(e) so the backward kernel can use exp2.
                LOG2_E = math.log2(math.e)
                if tidx < seqlen_q_rounded - m_block * self.m_block_size:
                    gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
build/torch-cuda/flash_bwd_sm100.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/flash_bwd_sm90.py ADDED
@@ -0,0 +1,1591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable, Optional, Type
3
+ from functools import partial
4
+
5
+ import cuda.bindings.driver as cuda
6
+
7
+ import cutlass
8
+ import cutlass.cute as cute
9
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
10
+ from cutlass.cute.nvgpu import cpasync, warpgroup
11
+ from cutlass.cute import FastDivmodDivisor
12
+ from cutlass import Float32, Int32, Boolean, const_expr
13
+ from cutlass.utils import LayoutEnum
14
+
15
+ from .quack import copy_utils
16
+ from .quack import layout_utils
17
+ from .quack import sm90_utils
18
+ from .quack.sm90_utils import gemm_zero_init, gemm_w_idx
19
+
20
+ from .cute_dsl_utils import assume_tensor_aligned
21
+ from . import utils
22
+ from .mask import AttentionMask
23
+ from .seqlen_info import SeqlenInfoQK
24
+ from .block_info import BlockInfo
25
+ from . import pipeline
26
+ from .quack.cute_dsl_utils import ParamsBase
27
+ from .tile_scheduler import TileSchedulerArguments, SingleTileScheduler
28
+ from .named_barrier import NamedBarrierBwd
29
+ from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
30
+ from .block_sparsity import BlockSparseTensors
31
+ from .block_sparse_utils import (
32
+ get_total_q_block_count_bwd,
33
+ produce_block_sparse_q_loads_bwd_sm90,
34
+ consume_block_sparse_mma_bwd_sm90,
35
+ dQaccum_store_block_sparse_bwd_sm90,
36
+ )
37
+
38
+
39
+ class FlashAttentionBackwardSm90:
40
+ arch = 90
41
+
42
    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        head_dim_v: Optional[int] = None,
        qhead_per_kvhead: int = 1,
        is_causal: bool = False,
        tile_m: int = 64,
        tile_n: int = 128,
        Q_stage: int = 2,
        dO_stage: int = 2,
        PdS_stage: int = 2,
        SdP_swapAB: bool = False,
        dKV_swapAB: bool = False,
        dQ_swapAB: bool = False,
        AtomLayoutMSdP: int = 1,
        AtomLayoutNdKV: int = 2,
        AtomLayoutMdQ: int = 1,
        num_threads: int = 384,
        V_in_regs: bool = False,
        score_mod: cutlass.Constexpr | None = None,
        score_mod_bwd: cutlass.Constexpr | None = None,
        mask_mod: cutlass.Constexpr | None = None,
        has_aux_tensors: cutlass.Constexpr = False,
        subtile_factor: cutlass.Constexpr[int] = 1,
    ):
        """Configure the Sm90 (Hopper) FlashAttention backward kernel.

        Args:
            dtype: Element type of Q/K/V/dO (Float16 or BFloat16).
            head_dim: QK head dimension; padded up to a multiple of 16.
            head_dim_v: V head dimension; defaults to ``head_dim``.
            qhead_per_kvhead: GQA ratio; 1 means standard MHA.
            is_causal: Apply a causal mask.
            tile_m / tile_n: CTA tile sizes along seqlen_q / seqlen_k.
            Q_stage / dO_stage / PdS_stage: smem pipeline depths.
            SdP_swapAB / dKV_swapAB / dQ_swapAB: swap the A/B operands of the
                corresponding GEMMs (transposed MMA layouts).
            AtomLayoutMSdP / AtomLayoutNdKV / AtomLayoutMdQ: warp-group atom
                tilings for the three GEMM groups.
            num_threads: Total CTA threads (producer warp group + MMA groups).
            V_in_regs: Keep V operand in registers.
            score_mod / score_mod_bwd / mask_mod: optional compile-time
                score/mask modification callbacks (flex-attention style).
            has_aux_tensors: Whether aux tensors are passed to score mods.
            subtile_factor: Compile-time subtiling factor.
        """
        self.dtype = dtype
        # padding head_dim to a multiple of 16 as k_block_size
        hdim_multiple_of = 16
        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        head_dim_v = head_dim_v if head_dim_v is not None else head_dim
        self.same_hdim_kv = head_dim == head_dim_v
        self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
        # Can save registers (and hence be faster) if we don't have to check hdim predication
        self.check_hdim_oob = head_dim != self.tile_hdim
        self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
        self.qhead_per_kvhead = qhead_per_kvhead
        self.is_causal = is_causal
        # Local (sliding-window) attention is not wired up in this class yet.
        self.is_local = False
        self.tile_m = tile_m
        self.tile_n = tile_n
        self.num_threads = num_threads
        self.Q_stage = Q_stage
        self.dO_stage = dO_stage
        self.PdS_stage = PdS_stage
        # dO/PdS stages must either be unstaged (1) or match the Q pipeline.
        assert self.dO_stage in [1, self.Q_stage]
        assert self.PdS_stage in [1, self.Q_stage]
        self.SdP_swapAB = SdP_swapAB
        self.dKV_swapAB = dKV_swapAB
        self.dQ_swapAB = dQ_swapAB
        self.AtomLayoutMSdP = AtomLayoutMSdP
        self.AtomLayoutNdKV = AtomLayoutNdKV
        self.AtomLayoutMdQ = AtomLayoutMdQ
        # One warp group (128 threads) is the TMA producer; the rest do MMA.
        self.num_mma_warp_groups = (self.num_threads // 128) - 1
        # dK/dV MMA can read its A operand straight from registers only for
        # this specific layout combination.
        self.mma_dkv_is_rs = (
            AtomLayoutMSdP == 1
            and AtomLayoutNdKV == self.num_mma_warp_groups
            and SdP_swapAB
            and not dKV_swapAB
        )
        self.V_in_regs = V_in_regs
        if qhead_per_kvhead > 1:
            assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
            assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups"
        # These are tuned for speed
        # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
        # them and then shuffle to get the value whenever we need? This can reduce register
        # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
        # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
        # TODO: impl these for hdim 64
        self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
        self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64

        self.buffer_align_bytes = 1024

        self.score_mod = score_mod
        self.score_mod_bwd = score_mod_bwd
        self.mask_mod = mask_mod
        self.has_aux_tensors = has_aux_tensors
        self.subtile_factor = subtile_factor
        # Aux tensors force scalar (per-element) score-mod application.
        if cutlass.const_expr(has_aux_tensors):
            self.vec_size: cutlass.Constexpr = 1
        else:
            self.vec_size: cutlass.Constexpr = 4
        self.qk_acc_dtype = Float32
127
+
128
+ @staticmethod
129
+ def can_implement(
130
+ dtype,
131
+ head_dim,
132
+ head_dim_v,
133
+ tile_m,
134
+ tile_n,
135
+ Q_stage,
136
+ num_threads,
137
+ V_in_regs=False,
138
+ ) -> bool:
139
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
140
+ return False
141
+ if head_dim % 8 != 0:
142
+ return False
143
+ if head_dim_v % 8 != 0:
144
+ return False
145
+ if tile_n % 16 != 0:
146
+ return False
147
+ if num_threads % 32 != 0:
148
+ return False
149
+ if (tile_m * 2) % num_threads != 0:
150
+ return False
151
+ return True
152
+
153
    def _check_type(
        self,
        mQ_type: Type[cutlass.Numeric],
        mK_type: Type[cutlass.Numeric],
        mV_type: Type[cutlass.Numeric],
        mdO_type: Type[cutlass.Numeric],
        mLSE_type: Type[cutlass.Numeric],
        mdPsum_type: Type[cutlass.Numeric],
        mdQaccum_type: Type[cutlass.Numeric],
        mdK_type: Type[cutlass.Numeric],
        mdV_type: Type[cutlass.Numeric],
    ):
        """Validate the element types of all kernel operands.

        Raises TypeError (at trace/compile time, via const_expr) when the
        inputs are not fp16/bf16, the fp32 statistics/accumulator tensors are
        not Float32, or the dK/dV outputs don't match the expected dtype
        (input dtype for MHA, Float32 accumulators for GQA).
        """
        # Get the data type and check if it is fp16 or bf16
        if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
            raise TypeError("All tensors must have the same data type")
        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if const_expr(mLSE_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")
        if const_expr(mdPsum_type not in [Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if const_expr(mdQaccum_type not in [Float32]):
            raise TypeError("dQaccum tensor must be Float32")
        if const_expr(self.qhead_per_kvhead == 1):
            if const_expr(not (mdK_type == mdV_type == mQ_type)):
                raise TypeError("mdK and mdV tensors must have the same data type as mQ")
        else:
            # GQA: dK/dV are fp32 accumulators shared across query heads.
            if const_expr(not (mdK_type == mdV_type == Float32)):
                raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
        assert mQ_type == self.dtype
183
+
184
    def _setup_attributes(self):
        """Build the shared-memory layouts and the dQaccum R->S tiled copy.

        Must run after ``num_threads_per_warp_group`` / ``num_mma_warp_groups``
        are set (they are, in ``__call__``) since the copy layouts depend on
        them.
        """
        # Swizzled smem layouts for Q, K, V, dO and the P/dS scratch tiles.
        self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [
            sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage)
            for shape, stage in [
                ((self.tile_m, self.tile_hdim), self.Q_stage),
                ((self.tile_n, self.tile_hdim), None),
                ((self.tile_n, self.tile_hdimv), None),
                ((self.tile_m, self.tile_hdimv), self.dO_stage),
                ((self.tile_m, self.tile_n), self.PdS_stage),
            ]
        ]
        # Flat fp32 dQ accumulator, split evenly across the MMA warp groups.
        self.sdQaccum_layout = cute.make_layout(
            (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups)
        )
        # dQaccum R->S
        self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            # 128-bit universal copies of Float32 elements.
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
            # thr_layout
            cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
            cute.make_layout(128 // Float32.width),  # val_layout
        )
        # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
        # TODO: assert that sVaccum and sKaccum don't overflow smem
207
+
208
    def _get_tiled_mma(self):
        """Construct the four warp-group tiled MMAs used by the backward pass.

        Returns:
            (tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ) — the
            MMAs for S/dP, dK, dV and dQ respectively. Each honors its
            corresponding ``*_swapAB`` flag by reversing the atom layout and
            tiler (i.e. computing the transposed product).
        """
        # S = Q @ K.T, dP = dO @ V.T
        atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP)
        tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
        tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K,
            warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1])
            + (1,),
            tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1],
        )
        # dV = P.T @ dO, dK = dS.T @ Q
        atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV)
        tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
        tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
        tiled_mma_dK, tiled_mma_dV = [
            sm90_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                self.dtype,
                # A operand comes from registers (K-major) when mma_dkv_is_rs.
                warpgroup.OperandMajorMode.MN
                if not self.mma_dkv_is_rs
                else warpgroup.OperandMajorMode.K,
                warpgroup.OperandMajorMode.MN,
                Float32,
                atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1])
                + (1,),
                tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1],
                a_source=warpgroup.OperandSource.RMEM
                if self.mma_dkv_is_rs
                else warpgroup.OperandSource.SMEM,
            )
            for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
        ]
        # dQ = dS @ K
        atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ)
        tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
        tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
            warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,),
            tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
        )
        return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
257
+
258
    def _get_shared_storage_cls(self):
        """Define the CTA shared-memory layout as a cute.struct.

        The struct groups the TMA mbarriers, the LSE/dPsum staging buffers,
        the Q/K/V/dO operand tiles, the P/dS scratch tiles and the dQ
        accumulator, with alignments suitable for TMA/MMA access.
        """
        # 1024B-aligned operand buffers.
        sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [
            cute.struct.Align[cute.struct.MemRange[t, cute.cosize(layout)], self.buffer_align_bytes]
            for (layout, t) in [
                (self.sQ_layout, self.dtype),
                (self.sK_layout, self.dtype),
                (self.sV_layout, self.dtype),
                (self.sdO_layout, self.dtype),
                (self.sdQaccum_layout, Float32),
            ]
        ]

        cosize_sdS = cute.cosize(self.sPdS_layout)
        # sP is elided entirely when dK/dV MMA reads P from registers.
        cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0
        # LSE/dPsum rows are padded up to 64 per stage.
        sLSE_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128
        ]
        sdPsum_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128
        ]

        @cute.struct
        class SharedStorageQKV:
            # Full/empty mbarrier pairs for the Q and dO TMA pipelines.
            mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2]
            mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2]
            sLSE: sLSE_struct
            sdPsum: sdPsum_struct
            sQ: sQ_struct
            sV: sV_struct
            sK: sK_struct
            sdO: sdO_struct
            sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
            sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024]
            sdQaccum: sdQaccum_struct

        return SharedStorageQKV
294
+
295
    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        softmax_scale: Float32,
        stream: cuda.CUstream,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        softcap: Float32 | float | None = None,
        window_size_left: Int32 | int | None = None,
        window_size_right: Int32 | int | None = None,
        mdQ_semaphore: Optional[cute.Tensor] = None,
        mdK_semaphore: Optional[cute.Tensor] = None,
        mdV_semaphore: Optional[cute.Tensor] = None,
        aux_tensors: Optional[list] = None,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
    ):
        """Host-side entry: validate, set up layouts/TMA, and launch the kernel.

        Input tensors are expected in (batch, seq, head, hdim) order and are
        transposed below to the kernel's (seq, hdim, head, batch) order.
        ``mdK``/``mdV`` are fp16/bf16 outputs for MHA, or fp32 accumulators
        with a flattened (b, n, s*h) layout for GQA. LSE/dPsum/dQaccum are
        fp32 statistics laid out (b, n, s). Varlen (cu_seqlens/seq_used) and
        determinism semaphores are not supported by this Sm90 path.
        """
        assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
            "determinism not supported yet for Sm90"
        )

        # Trace-time dtype validation of all operands.
        self._check_type(
            *(
                t.element_type if t is not None else None
                for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
            )
        )

        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
        ]

        layout_transpose = [1, 3, 2, 0]  # (b, s, n, h) --> (s, h, n, b)
        mQ, mK, mV, mdO = [layout_utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)]
        if const_expr(self.qhead_per_kvhead == 1):
            mdK, mdV = [layout_utils.select(t, layout_transpose) for t in (mdK, mdV)]
        else:
            accum_transpose = [2, 1, 0]  # (b, n, s*h) -> (s*h, n, b)
            mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
        LSE_dPsum_dQaccum_transpose = [2, 1, 0]  # (b, n, s) -> (s, n, b)
        mLSE, mdPsum, mdQaccum = [
            layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
        ]

        tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()

        # One extra warp group (128 threads) acts as the TMA producer.
        self.num_mma_threads = tiled_mma_SdP.size
        assert self.num_mma_threads + 128 == self.num_threads

        self.num_threads_per_warp_group = 128
        self.num_producer_threads = 32

        # Register rebalancing between producer and MMA warp groups.
        self.num_mma_regs = 240
        self.num_producer_regs = 24
        # self.num_mma_regs = 232
        # self.num_producer_regs = 40

        self._setup_attributes()
        SharedStorage = self._get_shared_storage_cls()

        # Per-operand TMA transaction byte counts (single stage of each tile).
        self.tma_copy_bytes = {
            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
            for name, mX, layout in [
                ("Q", mQ, self.sQ_layout),
                ("K", mK, self.sK_layout),
                ("V", mV, self.sV_layout),
                ("dO", mdO, self.sdO_layout),
            ]
        }
        self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dQ"] = (
            self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups
        )
        self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
        self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8

        # G2S TMA atoms for the inputs.
        tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mQ,
            cute.select(self.sQ_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdim),
        )
        tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mK,
            cute.select(self.sK_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdim),
        )
        tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mV,
            cute.select(self.sV_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdimv),
        )
        tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mdO,
            cute.select(self.sdO_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdimv),
        )
        if const_expr(self.qhead_per_kvhead == 1):
            # S2G TMA atoms for the dK/dV outputs (MHA path only; GQA stores
            # through fp32 accumulators instead).
            tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdK,
                cute.select(self.sK_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdim),
            )
            tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdV,
                cute.select(self.sV_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdimv),
            )
        else:
            tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None

        # One CTA per n_block (seqlen_k tile) x head x batch.
        TileScheduler = SingleTileScheduler
        tile_sched_args = TileSchedulerArguments(
            cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
            cute.size(mQ.shape[2]),
            cute.size(mQ.shape[3]),
            1,  # num_splits
            cute.size(mK.shape[0]),
            mQ.shape[1],
            mV.shape[1],
            total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
            tile_shape_mn=(self.tile_m, self.tile_n),
            mCuSeqlensQ=None,
            mSeqUsedQ=None,
            qhead_per_kvhead_packgqa=1,
            element_size=self.dtype.width // 8,
            is_persistent=False,
            lpt=False,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        # Pre-fold log2(e) into the scale so the kernel can use exp2.
        # With a score_mod the raw scale is applied inside the mod instead.
        LOG2_E = math.log2(math.e)
        if const_expr(self.score_mod is None):
            softmax_scale_log2 = softmax_scale * LOG2_E
        else:
            softmax_scale_log2 = LOG2_E

        # Fast division/modulo helpers for index math inside score mods.
        fastdiv_mods = None
        if const_expr(aux_tensors is not None):
            seqlen_q = cute.size(mQ.shape[0])
            seqlen_k = cute.size(mK.shape[0])
            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

        qhead_per_kvhead_divmod = None
        if const_expr(self.qhead_per_kvhead > 1):
            qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead)

        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

        self.kernel(
            tma_tensor_Q,
            tma_tensor_K,
            tma_tensor_V,
            tma_tensor_dO,
            tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK,
            tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            tma_atom_dO,
            tma_atom_dK,
            tma_atom_dV,
            mLSE,
            mdPsum,
            mdQaccum,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sPdS_layout,
            self.sdO_layout,
            self.sdQaccum_layout,
            self.r2s_tiled_copy_dQaccum,
            tiled_mma_SdP,
            tiled_mma_dK,
            tiled_mma_dV,
            tiled_mma_dQ,
            softmax_scale_log2,
            softmax_scale,
            tile_sched_params,
            TileScheduler,
            SharedStorage,
            aux_tensors,
            fastdiv_mods,
            blocksparse_tensors,
            qhead_per_kvhead_divmod,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
            min_blocks_per_mp=1,
        )
506
+
507
+ @cute.kernel
508
+ def kernel(
509
+ self,
510
+ mQ: cute.Tensor,
511
+ mK: cute.Tensor,
512
+ mV: cute.Tensor,
513
+ mdO: cute.Tensor,
514
+ mdK: cute.Tensor,
515
+ mdV: cute.Tensor,
516
+ tma_atom_Q: cute.CopyAtom,
517
+ tma_atom_K: cute.CopyAtom,
518
+ tma_atom_V: cute.CopyAtom,
519
+ tma_atom_dO: cute.CopyAtom,
520
+ tma_atom_dK: cute.CopyAtom,
521
+ tma_atom_dV: cute.CopyAtom,
522
+ mLSE: cute.Tensor,
523
+ mdPsum: cute.Tensor,
524
+ mdQaccum: cute.Tensor,
525
+ sQ_layout: cute.ComposedLayout,
526
+ sK_layout: cute.ComposedLayout,
527
+ sV_layout: cute.ComposedLayout,
528
+ sPdS_layout: cute.ComposedLayout,
529
+ sdO_layout: cute.ComposedLayout,
530
+ sdQaccum_layout: cute.Layout,
531
+ r2s_tiled_copy_dQaccum: cute.TiledCopy,
532
+ tiled_mma_SdP: cute.TiledMma,
533
+ tiled_mma_dK: cute.TiledMma,
534
+ tiled_mma_dV: cute.TiledMma,
535
+ tiled_mma_dQ: cute.TiledMma,
536
+ softmax_scale_log2,
537
+ softmax_scale,
538
+ tile_sched_params: ParamsBase,
539
+ TileScheduler: cutlass.Constexpr[Callable],
540
+ SharedStorage: cutlass.Constexpr[Callable],
541
+ aux_tensors: Optional[list] = None,
542
+ fastdiv_mods=(None, None),
543
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
544
+ qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
545
+ ):
546
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
547
+
548
+ # prefetch TMA descriptors
549
+ if warp_idx == 0:
550
+ cpasync.prefetch_descriptor(tma_atom_Q)
551
+ cpasync.prefetch_descriptor(tma_atom_K)
552
+ cpasync.prefetch_descriptor(tma_atom_V)
553
+ cpasync.prefetch_descriptor(tma_atom_dO)
554
+
555
+ smem = cutlass.utils.SmemAllocator()
556
+ storage = smem.allocate(SharedStorage)
557
+
558
+ pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread)
559
+ pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
560
+ cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
561
+ )
562
+ pipeline_Q = pipeline.PipelineTmaAsync.create(
563
+ barrier_storage=storage.mbar_ptr_Q.data_ptr(),
564
+ num_stages=self.Q_stage,
565
+ producer_group=pipeline_producer_group,
566
+ consumer_group=pipeline_consumer_group,
567
+ tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"],
568
+ defer_sync=True,
569
+ )
570
+ pipeline_dO = pipeline.PipelineTmaAsync.create(
571
+ barrier_storage=storage.mbar_ptr_dO.data_ptr(),
572
+ num_stages=self.dO_stage,
573
+ producer_group=pipeline_producer_group,
574
+ consumer_group=pipeline_consumer_group,
575
+ tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"],
576
+ defer_sync=False,
577
+ )
578
+
579
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
580
+ sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner)
581
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
582
+ sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
583
+ sP = None
584
+ if const_expr(not self.mma_dkv_is_rs):
585
+ sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
586
+ sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
587
+ sLSE = storage.sLSE.get_tensor(
588
+ cute.make_layout(
589
+ (self.tile_m, self.Q_stage),
590
+ stride=(1, cute.round_up(self.tile_m, 64)),
591
+ )
592
+ )
593
+ sdPsum = storage.sdPsum.get_tensor(
594
+ cute.make_layout(
595
+ (self.tile_m, self.dO_stage),
596
+ stride=(1, cute.round_up(self.tile_m, 64)),
597
+ )
598
+ )
599
+ sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)
600
+
601
+ block_info = BlockInfo(
602
+ self.tile_m,
603
+ self.tile_n,
604
+ self.is_causal,
605
+ self.is_local,
606
+ False, # is_split_kv
607
+ None,
608
+ None,
609
+ qhead_per_kvhead_packgqa=1,
610
+ )
611
+ SeqlenInfoCls = partial(
612
+ SeqlenInfoQK.create,
613
+ seqlen_q_static=mQ.shape[0],
614
+ seqlen_k_static=mK.shape[0],
615
+ mCuSeqlensQ=None,
616
+ mCuSeqlensK=None,
617
+ mSeqUsedQ=None,
618
+ mSeqUsedK=None,
619
+ )
620
+ AttentionMaskCls = partial(
621
+ AttentionMask,
622
+ self.tile_m,
623
+ self.tile_n,
624
+ window_size_left=None,
625
+ window_size_right=None,
626
+ swap_AB=self.SdP_swapAB,
627
+ )
628
+ TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
629
+
630
+ if warp_idx < 4:
631
+ cute.arch.setmaxregister_decrease(self.num_producer_regs)
632
+ if warp_idx == 0:
633
+ self.load(
634
+ mQ,
635
+ mK,
636
+ mV,
637
+ mdO,
638
+ mLSE,
639
+ mdPsum,
640
+ sQ,
641
+ sK,
642
+ sV,
643
+ sdO,
644
+ sLSE,
645
+ sdPsum,
646
+ tma_atom_Q,
647
+ tma_atom_K,
648
+ tma_atom_V,
649
+ tma_atom_dO,
650
+ pipeline_Q,
651
+ pipeline_dO,
652
+ block_info,
653
+ SeqlenInfoCls,
654
+ TileSchedulerCls,
655
+ blocksparse_tensors,
656
+ qhead_per_kvhead_divmod,
657
+ )
658
+ if warp_idx == 1:
659
+ self.dQaccum_store(
660
+ mdQaccum,
661
+ sdQaccum,
662
+ block_info,
663
+ TileSchedulerCls,
664
+ SeqlenInfoCls,
665
+ blocksparse_tensors,
666
+ )
667
+ else:
668
+ cute.arch.setmaxregister_increase(self.num_mma_regs)
669
+ tidx, _, _ = cute.arch.thread_idx()
670
+ tidx = tidx - 128
671
+ self.mma(
672
+ tiled_mma_SdP,
673
+ tiled_mma_dK,
674
+ tiled_mma_dV,
675
+ tiled_mma_dQ,
676
+ mdK,
677
+ mdV,
678
+ mdQaccum,
679
+ sQ,
680
+ sK,
681
+ sV,
682
+ sdO,
683
+ sP,
684
+ sdS,
685
+ sLSE,
686
+ sdPsum,
687
+ sdQaccum,
688
+ pipeline_Q,
689
+ pipeline_dO,
690
+ tidx,
691
+ tma_atom_dK,
692
+ tma_atom_dV,
693
+ r2s_tiled_copy_dQaccum,
694
+ softmax_scale_log2,
695
+ softmax_scale,
696
+ block_info,
697
+ SeqlenInfoCls,
698
+ AttentionMaskCls,
699
+ TileSchedulerCls,
700
+ aux_tensors,
701
+ fastdiv_mods,
702
+ blocksparse_tensors,
703
+ qhead_per_kvhead_divmod,
704
+ )
705
+
706
    @cute.jit
    def load(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        """Producer side of the backward kernel: issue async loads into shared memory.

        For each work tile (n_block, head, batch) handed out by the tile scheduler,
        loads K and V once (piggybacked on the first Q/dO pipeline stage via
        extra_tx_count), then streams Q, LSE, dO and dPsum per m_block through the
        multi-stage Q and dO pipelines. LSE rides the Q pipeline and dPsum rides the
        dO pipeline, so each barrier's tx_count covers both transfers.

        Only warp 0 of the producer warp group (warp_idx_in_wg == 0) does any work;
        the remaining producer warps fall through.
        """
        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

        if warp_idx_in_wg == 0:
            producer_state_Q = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
            )
            producer_state_dO = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
            )
            tile_scheduler = TileSchedulerCls()
            work_tile = tile_scheduler.initial_work_tile_info()
            while work_tile.is_valid_tile:
                n_block, head_idx, batch_idx, _ = work_tile.tile_idx
                seqlen = SeqlenInfoCls(batch_idx)
                # GQA: map the query head onto its KV head via the fast-divmod divisor.
                head_idx_kv = (
                    head_idx
                    if const_expr(self.qhead_per_kvhead == 1)
                    else head_idx // qhead_per_kvhead_divmod
                )
                mK_cur = mK[None, None, head_idx_kv, batch_idx]
                gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
                mV_cur = mV[None, None, head_idx_kv, batch_idx]
                gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))

                # Q/dO/LSE/dPsum are tiled with a free m_block coordinate (None) so the
                # copy functions below can be indexed per m_block.
                mQ_cur = mQ[None, None, head_idx, batch_idx]
                gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
                mdO_cur = mdO[None, None, head_idx, batch_idx]
                gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
                mLSE_cur = mLSE[None, head_idx, batch_idx]
                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
                mdPsum_cur = mdPsum[None, head_idx, batch_idx]
                gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))

                # K/V are single-stage (one tile per work item); Q/dO are pipelined.
                load_K, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True
                )
                load_V, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True
                )
                load_Q, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_Q, 0, cute.make_layout(1), gQ, sQ
                )
                load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
                load_dO, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_dO, 0, cute.make_layout(1), gdO, sdO
                )
                load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
                # LSE/dPsum use bulk cp.async but share the TMA pipelines' barriers.
                load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE)
                load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q)
                load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum)
                load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)

                m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)

                if const_expr(not self.use_block_sparsity):
                    total_m_block_cnt = m_block_max - m_block_min
                    # Without local attention the m-range is never empty, so the
                    # check can be folded away at compile time.
                    process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
                else:
                    total_m_block_cnt = get_total_q_block_count_bwd(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                    )
                    process_tile = total_m_block_cnt > Int32(0)

                if process_tile:
                    if const_expr(not self.use_block_sparsity):
                        first_m_block = m_block_min
                        # First Q stage also carries K: the acquire accounts for the
                        # extra K bytes so the barrier flips only once K+Q+LSE land.
                        pipeline_Q.producer_acquire(
                            producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
                        )
                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
                        load_Q(first_m_block, producer_state=producer_state_Q)
                        load_LSE(first_m_block, producer_state=producer_state_Q)
                        # When Q and dO pipelines have the same depth they share one
                        # state; otherwise dO tracks its own stage counter.
                        producer_state_dO_cur = (
                            producer_state_dO
                            if const_expr(self.Q_stage != self.dO_stage)
                            else producer_state_Q
                        )
                        # Likewise the first dO stage carries V.
                        pipeline_dO.producer_acquire(
                            producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
                        )
                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
                        load_dO(first_m_block, producer_state=producer_state_dO_cur)
                        load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
                        producer_state_Q.advance()
                        producer_state_dO.advance()

                        # Remaining m_blocks: plain pipelined Q/LSE and dO/dPsum loads.
                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
                            pipeline_Q.producer_acquire(producer_state_Q)
                            load_Q(m_block, producer_state=producer_state_Q)
                            load_LSE(m_block, producer_state=producer_state_Q)
                            producer_state_dO_cur = (
                                producer_state_dO
                                if const_expr(self.Q_stage != self.dO_stage)
                                else producer_state_Q
                            )
                            pipeline_dO.producer_acquire(producer_state_dO_cur)
                            load_dO(m_block, producer_state=producer_state_dO_cur)
                            load_dPsum(m_block, producer_state=producer_state_dO_cur)
                            producer_state_Q.advance()
                            producer_state_dO.advance()
                    else:
                        # Block-sparse path: helper walks only the active q-blocks and
                        # returns the advanced producer states.
                        producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90(
                            blocksparse_tensors,
                            batch_idx,
                            head_idx,
                            n_block,
                            producer_state_Q,
                            producer_state_dO,
                            pipeline_Q,
                            pipeline_dO,
                            load_K,
                            load_V,
                            load_Q,
                            load_dO,
                            load_LSE,
                            load_dPsum,
                            self.tma_copy_bytes["K"],
                            self.tma_copy_bytes["V"],
                            Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage),
                            subtile_factor=self.subtile_factor,
                            m_block_max=m_block_max,
                        )

                tile_scheduler.prefetch_next_work()
                tile_scheduler.advance_to_next_work()
                work_tile = tile_scheduler.get_current_work()
865
+ @cute.jit
866
+ def apply_score_mod(
867
+ self,
868
+ acc_S: cute.Tensor,
869
+ thr_mma_SdP: cute.core.ThrMma,
870
+ batch_idx,
871
+ head_idx,
872
+ m_block,
873
+ n_block,
874
+ softmax_scale,
875
+ seqlen_info: SeqlenInfoQK,
876
+ aux_tensors=None,
877
+ fastdiv_mods=(None, None),
878
+ ):
879
+ # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
880
+ cS = cute.make_identity_tensor(
881
+ (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
882
+ )
883
+ cS = cute.domain_offset(
884
+ (n_block * self.tile_n, m_block * self.tile_m)
885
+ if self.SdP_swapAB
886
+ else (m_block * self.tile_m, n_block * self.tile_n),
887
+ cS,
888
+ )
889
+ tScS = thr_mma_SdP.partition_C(cS)
890
+
891
+ apply_score_mod_inner(
892
+ acc_S,
893
+ tScS,
894
+ self.score_mod,
895
+ batch_idx,
896
+ head_idx,
897
+ softmax_scale,
898
+ self.vec_size,
899
+ self.qk_acc_dtype,
900
+ aux_tensors,
901
+ fastdiv_mods,
902
+ seqlen_info,
903
+ constant_q_idx=None,
904
+ qhead_per_kvhead=self.qhead_per_kvhead,
905
+ transpose_indices=self.SdP_swapAB,
906
+ )
907
+
908
+ @cute.jit
909
+ def apply_score_mod_bwd(
910
+ self,
911
+ grad_tensor: cute.Tensor,
912
+ score_tensor: cute.Tensor,
913
+ thr_mma_SdP: cute.core.ThrMma,
914
+ batch_idx,
915
+ head_idx,
916
+ m_block,
917
+ n_block,
918
+ softmax_scale,
919
+ seqlen_info: SeqlenInfoQK,
920
+ aux_tensors=None,
921
+ fastdiv_mods=(None, None),
922
+ ):
923
+ cS = cute.make_identity_tensor(
924
+ (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
925
+ )
926
+ cS = cute.domain_offset(
927
+ (n_block * self.tile_n, m_block * self.tile_m)
928
+ if self.SdP_swapAB
929
+ else (m_block * self.tile_m, n_block * self.tile_n),
930
+ cS,
931
+ )
932
+ tScS = thr_mma_SdP.partition_C(cS)
933
+
934
+ apply_score_mod_bwd_inner(
935
+ grad_tensor,
936
+ score_tensor,
937
+ tScS,
938
+ self.score_mod_bwd,
939
+ batch_idx,
940
+ head_idx,
941
+ softmax_scale,
942
+ self.vec_size,
943
+ self.qk_acc_dtype,
944
+ aux_tensors,
945
+ fastdiv_mods,
946
+ seqlen_info,
947
+ constant_q_idx=None,
948
+ qhead_per_kvhead=self.qhead_per_kvhead,
949
+ transpose_indices=self.SdP_swapAB,
950
+ )
951
+
952
    @cute.jit
    def mma(
        self,
        tiled_mma_SdP: cute.TiledMma,
        tiled_mma_dK: cute.TiledMma,
        tiled_mma_dV: cute.TiledMma,
        tiled_mma_dQ: cute.TiledMma,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        mdQaccum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sP: Optional[cute.Tensor],
        sdS: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        sdQaccum: cute.Tensor,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        tidx: Int32,
        tma_atom_dK: cute.CopyAtom,
        tma_atom_dV: cute.CopyAtom,
        r2s_tiled_copy_dQaccum: cute.TiledCopy,
        softmax_scale_log2: Float32,
        softmax_scale: Float32,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        AttentionMaskCls: Callable,
        TileSchedulerCls: Callable,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        """Consumer warp-group main loop of the SM90 backward pass.

        Sets up fragment partitions and partial GEMM closures for the five matmuls
        (S = Q @ K^T, dP = dO @ V^T, dV += P^T @ dO, dK += dS^T @ Q, dQ = dS @ K),
        then iterates over work tiles: for each (n_block, head, batch) it runs
        mma_one_m_block over the m_block range (dense or block-sparse), applies the
        softmax_scale to dK (non-GQA path), and writes dK/dV via epilogue_dKV.
        """
        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
        warp_group_thread_layout = cute.make_layout(
            self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
        )
        # Per-thread slice for accumulator/pointwise work; per-warp-group slices for WGMMA.
        thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
        wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx))
        # S = Q @ K.T
        shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
        _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
            wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB
        )
        mma_qk_fn = partial(
            gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB
        )
        # dP = dO @ V.T
        shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv)
        _, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC(
            wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB
        )
        mma_dov_fn = partial(
            gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB
        )
        # dV += P.T @ dO  (transposed smem views feed the WGMMA operands)
        sPt = layout_utils.transpose_view(sP) if sP is not None else None
        sdOt = layout_utils.transpose_view(sdO)
        shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m)
        acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC(
            wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB
        )
        # mma_dkv_is_rs: A operand comes from registers (rs) instead of smem (ss).
        if const_expr(not self.mma_dkv_is_rs):
            mma_pdo_fn = partial(
                gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
            )
        else:
            mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
        # dK += dS.T @ Q
        sdSt = layout_utils.transpose_view(sdS)
        sQt = layout_utils.transpose_view(sQ)
        shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m)
        acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC(
            wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB
        )
        if const_expr(not self.mma_dkv_is_rs):
            mma_dsq_fn = partial(
                gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
            )
        else:
            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
        # dQ = dS @ K
        sKt = layout_utils.transpose_view(sK)
        shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
        _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
            wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
        )
        mma_dsk_fn = partial(
            gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB
        )

        # Smem copy atom tiling: register->smem stores for P and dS.
        copy_P_r2s = None
        if const_expr(sP is not None):
            sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
            copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
                tiled_mma_SdP, sP_cpy, tidx, self.arch, transpose=self.SdP_swapAB
            )
        sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
        copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
            tiled_mma_SdP, sdS_cpy, tidx, self.arch, transpose=self.SdP_swapAB
        )

        # Per-thread views of the LSE / dPsum vectors, broadcast across the tile.
        tLSEsLSE = layout_utils.mma_partition_C_vec(
            sLSE, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
        )
        tLSEsdPsum = layout_utils.mma_partition_C_vec(
            sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
        )

        smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
        tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)

        # Barrier serializing P/dS smem reuse between the two consumer warp groups.
        PdS_barrier = cutlass.pipeline.NamedBarrier(
            barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
        )
        score_mod_fn = partial(
            self.apply_score_mod,
            thr_mma_SdP=thr_mma_SdP,
            softmax_scale=softmax_scale,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )
        score_mod_bwd_fn = partial(
            self.apply_score_mod_bwd,
            thr_mma_SdP=thr_mma_SdP,
            softmax_scale=softmax_scale,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )

        # Bind everything that is invariant across m_blocks; the per-tile loop only
        # supplies m_block, pipeline states and the per-tile mask/score closures.
        mma_one_m_block_all = partial(
            self.mma_one_m_block,
            warp_group_idx=warp_group_idx,
            mma_qk_fn=mma_qk_fn,
            mma_dov_fn=mma_dov_fn,
            mma_pdo_fn=mma_pdo_fn,
            mma_dsq_fn=mma_dsq_fn,
            mma_dsk_fn=mma_dsk_fn,
            copy_P_r2s=copy_P_r2s,
            copy_dS_r2s=copy_dS_r2s,
            pipeline_Q=pipeline_Q,
            pipeline_dO=pipeline_dO,
            tLSEsLSE=tLSEsLSE,
            tLSEsdPsum=tLSEsdPsum,
            tdQsdQaccum=tdQsdQaccum,
            softmax_scale_log2=softmax_scale_log2,
            PdS_barrier=PdS_barrier,
            # acc_dV=acc_dV,
            # acc_dK=acc_dK,
        )

        consumer_state_Q = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
        )
        consumer_state_dO = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
        )
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mask = AttentionMaskCls(seqlen)
            score_mod_fn_cur = partial(
                score_mod_fn,
                batch_idx=batch_idx,
                head_idx=head_idx,
                n_block=n_block,
                seqlen_info=seqlen,
            )
            score_mod_bwd_fn_cur = partial(
                score_mod_bwd_fn,
                batch_idx=batch_idx,
                head_idx=head_idx,
                n_block=n_block,
                seqlen_info=seqlen,
            )
            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)

            # Mirrors the producer's process_tile decision so consumer and producer
            # agree on which tiles consume pipeline stages.
            if const_expr(not self.use_block_sparsity):
                process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
            else:
                total_m_block_cnt = get_total_q_block_count_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = total_m_block_cnt > Int32(0)

            if process_tile:
                if const_expr(not self.use_block_sparsity):
                    mask_fn = partial(
                        mask.apply_mask,
                        batch_idx=batch_idx,
                        head_idx=head_idx,
                        n_block=n_block,
                        thr_mma=thr_mma_SdP,
                        mask_seqlen=True,
                        mask_causal=self.is_causal,
                        mask_local=self.is_local,
                        mask_mod=self.mask_mod,
                        aux_tensors=aux_tensors,
                        fastdiv_mods=fastdiv_mods,
                    )
                    # First iteration zero-initializes dK/dV; later ones accumulate.
                    dKV_accumulate = False
                    for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
                        consumer_state_Q, consumer_state_dO = mma_one_m_block_all(
                            m_block,
                            consumer_state_Q,
                            consumer_state_dO,
                            mask_fn=mask_fn,
                            score_mod_fn=score_mod_fn_cur,
                            score_mod_bwd_fn=score_mod_bwd_fn_cur,
                            dKV_accumulate=dKV_accumulate,
                        )
                        dKV_accumulate = True
                else:
                    consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        consumer_state_Q,
                        consumer_state_dO,
                        mma_one_m_block_all,
                        mask,
                        self.mask_mod,
                        is_causal=self.is_causal,
                        is_local=self.is_local,
                        thr_mma_SdP=thr_mma_SdP,
                        score_mod_fn=score_mod_fn_cur,
                        score_mod_bwd_fn=score_mod_bwd_fn_cur,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                        aux_tensors=aux_tensors,
                        fastdiv_mods=fastdiv_mods,
                    )

                # Non-GQA path folds softmax_scale into dK here; the GQA path defers
                # scaling (dK is written as an f32 accumulation in epilogue_dKV).
                if const_expr(self.qhead_per_kvhead == 1):
                    acc_dK.store(acc_dK.load() * softmax_scale)
                self.epilogue_dKV(
                    acc_dV,
                    mdV,
                    sV,
                    acc_dK,
                    mdK,
                    sK,
                    seqlen,
                    tma_atom_dK,
                    tma_atom_dV,
                    tiled_mma_dK,
                    tiled_mma_dV,
                    tidx,
                    n_block,
                    head_idx,
                    batch_idx,
                    qhead_per_kvhead_divmod,
                )
            else:
                # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros.
                if const_expr(self.use_block_sparsity):
                    acc_dK.fill(0.0)
                    acc_dV.fill(0.0)
                    self.epilogue_dKV(
                        acc_dV,
                        mdV,
                        sV,
                        acc_dK,
                        mdK,
                        sK,
                        seqlen,
                        tma_atom_dK,
                        tma_atom_dV,
                        tiled_mma_dK,
                        tiled_mma_dV,
                        tidx,
                        n_block,
                        head_idx,
                        batch_idx,
                        qhead_per_kvhead_divmod,
                    )

            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

        # Drain any outstanding bulk-async stores issued by the store warp (warp 4).
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        if warp_idx == 4:
            cute.arch.cp_async_bulk_wait_group(0, read=True)
1251
    @cute.jit
    def mma_one_m_block(
        self,
        m_block: Int32,
        consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
        consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
        warp_group_idx: Int32,
        mma_qk_fn: Callable,
        mma_dov_fn: Callable,
        mma_pdo_fn: Callable,
        mma_dsq_fn: Callable,
        mma_dsk_fn: Callable,
        copy_P_r2s: Optional[Callable],
        copy_dS_r2s: Callable,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        tLSEsLSE: cute.Tensor,
        tLSEsdPsum: cute.Tensor,
        tdQsdQaccum: cute.Tensor,
        softmax_scale_log2: Float32,
        PdS_barrier: cutlass.pipeline.NamedBarrier,
        mask_fn: Optional[Callable] = None,
        score_mod_fn: Optional[Callable] = None,
        score_mod_bwd_fn: Optional[Callable] = None,
        dKV_accumulate: Boolean = True,
    ):
        """Process a single m_block: the five GEMMs plus the softmax pointwise work.

        Sequence: S = Q @ K^T and dP = dO @ V^T, P = exp2(S*scale_log2 - LSE),
        dS = P * (dP - dPsum), then dV += P^T @ dO, dQ = dS @ K (handed off to the
        dQaccum store warp via shared memory and named barriers), and dK += dS^T @ Q.
        Returns the advanced (consumer_state_Q, consumer_state_dO).
        """
        # NOTE(review): the producer in `load` selects the dedicated dO state when
        # Q_stage != dO_stage, but here the dedicated state is used when the stages
        # are EQUAL (in which case both states hold identical values anyway).
        # Looks inverted relative to the producer — confirm intent for Q_stage != dO_stage.
        consumer_state_dO_cur = (
            consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
        )
        smem_idx_Q = consumer_state_Q.index
        smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
        smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0
        # (1) [GEMM 1] S = Q @ K^T
        pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
        acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
        tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
        # (2) [GEMM 2] dP = dO @ V.T
        pipeline_dO.consumer_wait(
            consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur)
        )
        # wg_wait=1 leaves GEMM 2 in flight while we do the pointwise work on S.
        acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)

        # score_mod_bwd needs the pre-modification scores; snapshot them first.
        if const_expr(self.score_mod_bwd is not None):
            acc_S_pre = cute.make_fragment_like(acc_S)
            cute.autovec_copy(acc_S, acc_S_pre)

        if const_expr(self.score_mod is not None):
            score_mod_fn(acc_S, m_block=m_block)

        # (3) [Pointwise 1] P = exp(S - LSE)
        if cutlass.const_expr(mask_fn is not None):
            mask_fn(acc_S, m_block=m_block)
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
        for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
            for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
                # exp2 with the log2-scaled softmax scale avoids a per-element mul+exp.
                acc_S_mn[r, c] = cute.math.exp2(
                    acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True
                )
        tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])

        # Convert P from f32 -> f16
        tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype)
        # R2S for P
        if const_expr(not self.mma_dkv_is_rs):
            # sync to ensure P has already been used in the previous iteration before overwriting
            if const_expr(self.PdS_stage == 1):
                PdS_barrier.arrive_and_wait()
            copy_P_r2s(tdVrP, dst_idx=smem_idx_PdS)

        # (4) [Pointwise 2] dS = P*(dP-dPsum)
        # Wait for GEMM 2 (dP) to finish before reading acc_dP.
        warpgroup.wait_group(0)
        acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
        for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
            for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
                acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])

        if const_expr(self.score_mod_bwd is not None):
            score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)

        # Convert dS from f32 -> f16
        tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)

        # If there's double buffering on dS, we don't need to sync here.
        # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
        # But because both WGs have to sync at the end of the loop and double buffering,
        # this race condition is not possible.
        # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and
        # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
        if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
            cute.arch.fence_view_async_shared()
            PdS_barrier.arrive_and_wait()

        # R2S for dS
        copy_dS_r2s(tdKrdS, dst_idx=smem_idx_PdS)

        # (5) [GEMM 3] dV += P.T @ dO
        if const_expr(not self.mma_dkv_is_rs):
            mma_pdo_fn(
                A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1
            )
        else:
            mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)

        # smem fence to make sure sdS is written before it's read by WGMMA
        cute.arch.fence_view_async_shared()
        PdS_barrier.arrive_and_wait()
        # (6) [GEMM 4] dQ = dS @ K
        acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
        # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
        pipeline_dO.consumer_release(consumer_state_dO_cur)  # release dO as dV mma is done

        # (7) [GEMM 5] dK += dS.T @ Q
        if const_expr(not self.mma_dkv_is_rs):
            mma_dsq_fn(
                A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
            )
        else:
            mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
        # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ)

        # Hand dQ to the dQaccum store warp: wait for its smem slot to be free
        # (dQEmpty), copy registers -> smem, then signal dQFull. Barriers are
        # per-warp-group (WG0/WG1 each pair with the store warp).
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
        cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
        cute.arch.fence_view_async_shared()
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )

        # Wait for GEMM 5 before releasing Q (dK mma reads sQ).
        warpgroup.wait_group(0)
        # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK)
        pipeline_Q.consumer_release(consumer_state_Q)
        # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block)

        consumer_state_Q.advance()
        consumer_state_dO.advance()
        return consumer_state_Q, consumer_state_dO
1392
    @cute.jit
    def epilogue_dKV(
        self,
        acc_dV: cute.Tensor,
        mdV: cute.Tensor,
        sV: cute.Tensor,
        acc_dK: cute.Tensor,
        mdK: cute.Tensor,
        sK: cute.Tensor,
        seqlen: SeqlenInfoQK,
        tma_atom_dK: cute.CopyAtom,
        tma_atom_dV: cute.CopyAtom,
        tiled_mma_dK: cute.TiledMma,
        tiled_mma_dV: cute.TiledMma,
        tidx: Int32,
        n_block: Int32,
        head_idx: Int32,
        batch_idx: Int32,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        """Write the dK/dV accumulators for one KV tile back to global memory.

        Non-GQA path (qhead_per_kvhead == 1): stage each accumulator through the
        (now-free) sV / sK smem buffers and issue TMA stores, warp 4 driving the
        bulk-async traffic. GQA path: recast sV's smem as an f32 staging buffer and
        use cp.async bulk *reduce-add* so multiple query heads accumulate into the
        shared per-KV-head dK/dV accumulators.
        """
        epi_barrier = cutlass.pipeline.NamedBarrier(
            barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads
        )
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        if const_expr(self.qhead_per_kvhead == 1):
            mdV_cur = mdV[None, None, head_idx, batch_idx]
            mdK_cur = mdK[None, None, head_idx, batch_idx]
            gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
            gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
            # Reuse sK/sV as staging for the outgoing dK/dV TMA stores.
            store_dK, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True
            )
            store_dV, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True
            )
            sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
            sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
            copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
                tiled_mma_dV, sdV, tidx, self.arch, transpose=self.dKV_swapAB
            )
            copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
                tiled_mma_dK, sdK, tidx, self.arch, transpose=self.dKV_swapAB
            )
            # Phase 1: dV. Wait until at most one prior bulk-store group is in
            # flight before overwriting sV, then stage and store.
            cute.arch.cp_async_bulk_wait_group(1, read=True)
            epi_barrier.arrive_and_wait()
            copy_dV_r2s(acc_dV, dst_idx=None)
            cute.arch.fence_view_async_shared()
            epi_barrier.arrive_and_wait()
            # Only warp 4 issues the TMA store / commit.
            if warp_idx == 4:
                store_dV()
                cute.arch.cp_async_bulk_commit_group()
            # Phase 2: dK, same structure as phase 1.
            cute.arch.cp_async_bulk_wait_group(1, read=True)
            epi_barrier.arrive_and_wait()
            copy_dK_r2s(acc_dK, dst_idx=None)
            cute.arch.fence_view_async_shared()
            epi_barrier.arrive_and_wait()
            if warp_idx == 4:
                store_dK()
                cute.arch.cp_async_bulk_commit_group()
        else:
            # GQA: dK/dV are flattened f32 accumulators, split across MMA warp groups.
            sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_mma_warp_groups
            sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_mma_warp_groups
            sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_mma_warp_groups))
            sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_mma_warp_groups))
            head_idx_kv = head_idx // qhead_per_kvhead_divmod
            mdKaccum_cur = mdK[None, head_idx_kv, batch_idx]
            gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
            gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
            mdVaccum_cur = mdV[None, head_idx_kv, batch_idx]
            gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
            gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
            # These two overlap each other (both alias sV's smem, recast to f32;
            # the two phases below are serialized by waits + barriers).
            sVaccum_ptr = cute.recast_ptr(sV.iterator, dtype=Float32)
            sdKaccum = cute.make_tensor(sVaccum_ptr, sdKaccum_layout)
            sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
            # 128-bit vectorized register->smem copy, one warp group per column.
            tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
                cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
                cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
                cute.make_layout(128 // Float32.width),
            )
            thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
            tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum)
            tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum)

            # Phase 1: dK. Drain all outstanding bulk traffic before reusing the
            # aliased smem, stage the accumulator, then reduce-add into gmem.
            cute.arch.cp_async_bulk_wait_group(0, read=True)
            epi_barrier.arrive_and_wait()
            tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape)
            cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum)
            cute.arch.fence_view_async_shared()
            epi_barrier.arrive_and_wait()
            if warp_idx == 4:
                with cute.arch.elect_one():
                    for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                        copy_utils.cpasync_reduce_bulk_add_f32(
                            sdKaccum[None, wg_idx].iterator,
                            gdKaccum[None, wg_idx].iterator,
                            self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups,
                        )
                cute.arch.cp_async_bulk_commit_group()

            # Phase 2: dV, same structure (must wait for the dK reduce-add since
            # sdVaccum aliases the same smem).
            cute.arch.cp_async_bulk_wait_group(0, read=True)
            epi_barrier.arrive_and_wait()
            tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape)
            cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum)
            cute.arch.fence_view_async_shared()
            epi_barrier.arrive_and_wait()
            if warp_idx == 4:
                with cute.arch.elect_one():
                    for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                        copy_utils.cpasync_reduce_bulk_add_f32(
                            sdVaccum[None, wg_idx].iterator,
                            gdVaccum[None, wg_idx].iterator,
                            self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups,
                        )
                cute.arch.cp_async_bulk_commit_group()
1509
    @cute.jit
    def dQaccum_store(
        self,
        mdQaccum: cute.Tensor,
        sdQaccum: cute.Tensor,
        block_info: BlockInfo,
        TileSchedulerCls: cutlass.Constexpr[Callable],
        SeqlenInfoCls: cutlass.Constexpr[Callable],
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
    ):
        """Epilogue loop that reduces the dQ accumulator from shared memory into
        global memory via bulk async reduce-add copies.

        Walks the tile scheduler's work items; for each (n_block, head, batch) it
        iterates the valid m-blocks and, per MMA warp group, synchronizes with the
        producer warp groups through named barriers before issuing a
        cp.async.bulk reduce-add from the smem staging buffer into gmem.

        :param mdQaccum: global dQ accumulator, indexed as [elem, head, batch]
        :param sdQaccum: shared-memory dQ staging buffer, one slice per MMA warp group
        :param block_info: supplies the m-block range for a given seqlen / n_block
        :param TileSchedulerCls: constexpr factory producing the tile scheduler
        :param SeqlenInfoCls: constexpr factory producing per-batch seqlen info
        :param blocksparse_tensors: block-sparsity metadata; required when
            ``self.use_block_sparsity`` is set
        """
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
            gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
            # (M * K / WG, WG, _): each m-block's flat dQ tile is split evenly
            # across the MMA warp groups.
            gdQaccum = cute.flat_divide(
                gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,)
            )
            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
            if const_expr(not self.use_block_sparsity):
                # Local attention can produce an empty m-block range; skip then.
                process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
                loop_count = m_block_max - m_block_min
            else:
                total_block_cnt = get_total_q_block_count_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = total_block_cnt > Int32(0)

            if process_tile:
                if const_expr(not self.use_block_sparsity):
                    for iter_idx in cutlass.range(loop_count, unroll=1):
                        m_block = m_block_min + iter_idx
                        m_block_safe = m_block

                        # Drain the oldest outstanding bulk-copy group for each warp
                        # group's smem slice, then signal (dQEmpty barrier) that the
                        # slice may be refilled by the producer warp group.
                        for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                            cute.arch.cp_async_bulk_wait_group(
                                self.num_mma_warp_groups - 1 - warp_group_idx, read=True
                            )
                            cute.arch.barrier_arrive(
                                barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
                                number_of_threads=self.num_threads_per_warp_group
                                + cute.arch.WARP_SIZE,
                            )

                        # Wait (dQFull barrier) for each warp group to publish its dQ
                        # slice; a single elected thread then issues the bulk
                        # reduce-add from smem into gmem and commits the copy group.
                        for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
                            cute.arch.barrier(
                                barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
                                number_of_threads=self.num_threads_per_warp_group
                                + cute.arch.WARP_SIZE,
                            )
                            with cute.arch.elect_one():
                                copy_utils.cpasync_reduce_bulk_add_f32(
                                    sdQaccum[None, warp_group_idx].iterator,
                                    gdQaccum[None, warp_group_idx, m_block_safe].iterator,
                                    self.tma_copy_bytes["dQ"],
                                )
                            cute.arch.cp_async_bulk_commit_group()
                else:
                    # Block-sparse path: the helper visits only the active q-blocks.
                    dQaccum_store_block_sparse_bwd_sm90(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        sdQaccum,
                        gdQaccum,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                        num_mma_warp_groups=self.num_mma_warp_groups,
                        num_threads_per_warp_group=self.num_threads_per_warp_group,
                        tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
                    )
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

        # Drain every outstanding bulk copy before the epilogue exits.
        cute.arch.cp_async_bulk_wait_group(0, read=True)
build/torch-cuda/flash_fwd.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/flash_fwd_combine.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h
3
+ # from Cutlass C++ to Cute-DSL.
4
+ import math
5
+ from typing import Type, Optional
6
+ from functools import partial
7
+
8
+ import cuda.bindings.driver as cuda
9
+
10
+ import cutlass
11
+ import cutlass.cute as cute
12
+ from cutlass.cute.nvgpu import cpasync
13
+ from cutlass import Float32, Int32, const_expr
14
+
15
+ from . import utils
16
+ from .cute_dsl_utils import assume_tensor_aligned
17
+ from .seqlen_info import SeqlenInfo
18
+ from cutlass.cute import FastDivmodDivisor
19
+
20
+
21
+ class FlashAttentionForwardCombine:
22
+ def __init__(
23
+ self,
24
+ dtype: Type[cutlass.Numeric],
25
+ dtype_partial: Type[cutlass.Numeric],
26
+ head_dim: int,
27
+ m_block_size: int = 8,
28
+ k_block_size: int = 64,
29
+ log_max_splits: int = 4,
30
+ num_threads: int = 256,
31
+ stages: int = 4,
32
+ ):
33
+ """
34
+ Forward combine kernel for split attention computation.
35
+
36
+ :param dtype: output data type
37
+ :param dtype_partial: partial accumulation data type
38
+ :param head_dim: head dimension
39
+ :param m_block_size: m block size
40
+ :param k_block_size: k block size
41
+ :param log_max_splits: log2 of maximum splits
42
+ :param num_threads: number of threads
43
+ :param varlen: whether using variable length sequences
44
+ :param stages: number of pipeline stages
45
+ """
46
+ self.dtype = dtype
47
+ self.dtype_partial = dtype_partial
48
+ self.head_dim = head_dim
49
+ self.m_block_size = m_block_size
50
+ self.k_block_size = k_block_size
51
+ self.max_splits = 1 << log_max_splits
52
+ self.num_threads = num_threads
53
+ self.is_even_k = head_dim % k_block_size == 0
54
+ self.stages = stages
55
+
56
+ @staticmethod
57
+ def can_implement(
58
+ dtype,
59
+ dtype_partial,
60
+ head_dim,
61
+ m_block_size,
62
+ k_block_size,
63
+ log_max_splits,
64
+ num_threads,
65
+ ) -> bool:
66
+ """Check if the kernel can be implemented with the given parameters."""
67
+ if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
68
+ return False
69
+ if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:
70
+ return False
71
+ if head_dim % 8 != 0:
72
+ return False
73
+ if num_threads % 32 != 0:
74
+ return False
75
+ if m_block_size % 8 != 0:
76
+ return False
77
+ max_splits = 1 << log_max_splits
78
+ if max_splits > 256:
79
+ return False
80
+ if (m_block_size * max_splits) % num_threads != 0:
81
+ return False
82
+ return True
83
+
84
    def _setup_attributes(self):
        """Build the tiled-copy atoms and shared-memory layouts used by the kernel.

        Populates: ``gmem_tiled_copy_O_partial`` (async G2S load of partial O),
        ``gmem_tiled_copy_O`` (universal store of final O),
        ``gmem_tiled_copy_LSE`` (async G2S load of partial LSE),
        ``s2r_tiled_copy_LSE`` (smem-to-register LSE transpose copy),
        ``smem_threads_per_col_lse``, ``smem_layout_lse`` and ``smem_layout_o``.
        """
        # GMEM copy setup for O partial: 128-bit vectorized async copies.
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype_partial.width
        assert self.k_block_size % async_copy_elems == 0

        # Widest gmem row tile (in elements) that evenly divides k_block_size.
        k_block_gmem = (
            128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
        )
        gmem_threads_per_row = k_block_gmem // async_copy_elems
        assert self.num_threads % gmem_threads_per_row == 0

        # Async copy atom for O partial load
        atom_async_copy_partial = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self.dtype_partial,
            num_bits_per_copy=universal_copy_bits,
        )
        tOpartial_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        vOpartial_layout = cute.make_layout((1, async_copy_elems))  # async_copy_elems vals per load
        self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(
            atom_async_copy_partial, tOpartial_layout, vOpartial_layout
        )

        # GMEM copy setup for final O (use universal copy for store)
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.dtype,
            num_bits_per_copy=async_copy_elems * self.dtype.width,
        )
        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
            atom_universal_copy,
            tOpartial_layout,
            vOpartial_layout,  # same thread/value tiling as the partial load
        )

        # LSE copy setup with async copy (alignment = 1)
        lse_copy_bits = Float32.width  # 1 element per copy, width is in bits
        # Widest row tile that evenly divides m_block_size (128/64/32/16/8).
        m_block_smem = (
            128
            if self.m_block_size % 128 == 0
            else (
                64
                if self.m_block_size % 64 == 0
                else (
                    32
                    if self.m_block_size % 32 == 0
                    else (16 if self.m_block_size % 16 == 0 else 8)
                )
            )
        )
        gmem_threads_per_row_lse = m_block_smem
        assert self.num_threads % gmem_threads_per_row_lse == 0

        # Async copy atom for LSE load
        atom_async_copy_lse = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
            Float32,
            num_bits_per_copy=lse_copy_bits,
        )
        tLSE_layout = cute.make_ordered_layout(
            (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),
            order=(1, 0),
        )
        vLSE_layout = cute.make_layout(1)
        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(
            atom_async_copy_lse, tLSE_layout, vLSE_layout
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory
        # ///////////////////////////////////////////////////////////////////////////////

        # Shared memory to register copy for LSE
        self.smem_threads_per_col_lse = self.num_threads // m_block_smem
        assert 32 % self.smem_threads_per_col_lse == 0  # Must divide warp size

        s2r_layout_atom_lse = cute.make_ordered_layout(
            (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),
            order=(0, 1),
        )
        self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),
            s2r_layout_atom_lse,
            cute.make_layout(1),
        )

        # LSE shared memory layout with swizzling to avoid bank conflicts
        # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
        if const_expr(m_block_smem == 8):
            smem_lse_swizzle = cute.make_swizzle(5, 0, 5)
        elif const_expr(m_block_smem == 16):
            smem_lse_swizzle = cute.make_swizzle(4, 0, 4)
        else:
            smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
        smem_layout_atom_lse = cute.make_composed_layout(
            smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
        )
        self.smem_layout_lse = cute.tile_to_shape(
            smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
        )

        # O partial shared memory layout (simple layout for pipeline stages)
        self.smem_layout_o = cute.make_ordered_layout(
            (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
        )
193
+
194
    @cute.jit
    def __call__(
        self,
        mO_partial: cute.Tensor,
        mLSE_partial: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor] = None,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
        num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
        semaphore_to_reset: Optional[cute.Tensor] = None,
        stream: cuda.CUstream = None,
    ):
        """Validate the inputs, permute layouts into kernel order, and launch
        the combine kernel on ``stream``.

        Tensors arrive in user order (batch/seqlen/heads/headdim or varlen
        total_q order) and are re-viewed — without data movement — into the
        (seqlen, d, ...) order the kernel indexes.

        :param mO_partial: per-split partial outputs
        :param mLSE_partial: per-split partial log-sum-exp values
        :param mO: final output tensor (written by the kernel)
        :param mLSE: optional final LSE tensor
        :param cu_seqlens: optional cumulative sequence lengths (varlen mode)
        :param seqused: optional per-batch used sequence lengths
        :param num_splits_dynamic_ptr: optional per-batch dynamic split counts
        :param semaphore_to_reset: optional semaphore zeroed by the last CTA
        :param stream: CUDA stream for the launch
        :raises TypeError: on dtype mismatches
        :raises ValueError: on rank mismatches
        """
        # Type checking
        if const_expr(not (mO_partial.element_type == self.dtype_partial)):
            raise TypeError("O partial tensor must match dtype_partial")
        if const_expr(not (mO.element_type == self.dtype)):
            raise TypeError("O tensor must match dtype")
        if const_expr(mLSE_partial.element_type not in [Float32]):
            raise TypeError("LSE partial tensor must be Float32")
        if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")

        # Shape validation - input tensors are in user format, need to be converted to kernel format
        if const_expr(len(mO_partial.shape) not in [4, 5]):
            raise ValueError(
                "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
            )
        if const_expr(len(mLSE_partial.shape) not in [3, 4]):
            raise ValueError(
                "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
            )
        if const_expr(len(mO.shape) not in [3, 4]):
            raise ValueError(
                "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
            )
        if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
            raise ValueError(
                "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
            )

        mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]
        # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
        # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
        O_partial_layout_transpose = (
            [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
        )
        mO_partial = cute.make_tensor(
            mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
        )
        # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
        O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
        mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
        # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
        # or (num_splits, total_q, h) -> (total_q, num_splits, h)
        LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
        mLSE_partial = cute.make_tensor(
            mLSE_partial.iterator,
            cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
        )
        # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
        LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
        mLSE = (
            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
            if mLSE is not None
            else None
        )

        # Determine if we have variable length sequences
        varlen = const_expr(cu_seqlens is not None or seqused is not None)

        self._setup_attributes()

        @cute.struct
        class SharedStorage:
            sLSE: cute.struct.Align[
                cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
            ]
            sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
            sO: cute.struct.Align[
                cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
            ]

        smem_size = SharedStorage.size_in_bytes()

        # Grid dimensions: (ceil_div(seqlen * nheads, m_block), ceil_div(head_dim, k_block), batch)
        seqlen = mO_partial.shape[0]
        num_head = mO_partial.shape[3]
        batch_size = (
            mO_partial.shape[4]
            if const_expr(cu_seqlens is None)
            else Int32(cu_seqlens.shape[0] - 1)
        )

        # Create FastDivmodDivisor objects for efficient division
        seqlen_divmod = FastDivmodDivisor(seqlen)
        head_divmod = FastDivmodDivisor(num_head)

        grid_dim = (
            cute.ceil_div(seqlen * num_head, self.m_block_size),
            cute.ceil_div(self.head_dim, self.k_block_size),
            batch_size,
        )

        self.kernel(
            mO_partial,
            mLSE_partial,
            mO,
            mLSE,
            cu_seqlens,
            seqused,
            num_splits_dynamic_ptr,
            semaphore_to_reset,
            SharedStorage,
            self.smem_layout_lse,
            self.smem_layout_o,
            self.gmem_tiled_copy_O_partial,
            self.gmem_tiled_copy_O,
            self.gmem_tiled_copy_LSE,
            self.s2r_tiled_copy_LSE,
            seqlen_divmod,
            head_divmod,
            varlen,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
            stream=stream,
        )
323
+
324
    @cute.kernel
    def kernel(
        self,
        mO_partial: cute.Tensor,
        mLSE_partial: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        cu_seqlens: Optional[cute.Tensor],
        seqused: Optional[cute.Tensor],
        num_splits_dynamic_ptr: Optional[cute.Tensor],
        semaphore_to_reset: Optional[cute.Tensor],
        SharedStorage: cutlass.Constexpr,
        smem_layout_lse: cute.Layout | cute.ComposedLayout,
        smem_layout_o: cute.Layout,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_LSE: cute.TiledCopy,
        s2r_tiled_copy_LSE: cute.TiledCopy,
        seqlen_divmod: FastDivmodDivisor,
        head_divmod: FastDivmodDivisor,
        varlen: cutlass.Constexpr[bool],
    ):
        """Device kernel: combine per-split partial O / LSE into the final O (and LSE).

        Per CTA: loads the (max_splits, m_block) tile of partial LSEs into smem,
        computes the softmax-rescaling weight of each split, then streams the
        partial O tiles through a multi-stage smem pipeline, accumulating the
        weighted sum in registers before storing the final O tile to gmem.

        NOTE(review): ``head_divmod`` is accepted but not referenced in this body;
        row decomposition uses ``seqlen_divmod`` — presumably kept for interface
        symmetry, verify against callers before removing.
        """
        # Thread and block indices
        tidx, _, _ = cute.arch.thread_idx()
        m_block, k_block, batch_idx = cute.arch.block_idx()

        # ///////////////////////////////////////////////////////////////////////////////
        # Get shared memory buffer
        # ///////////////////////////////////////////////////////////////////////////////
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)
        sLSE = storage.sLSE.get_tensor(smem_layout_lse)
        sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,))
        sO = storage.sO.get_tensor(smem_layout_o)

        # Handle semaphore reset: only the very last CTA in the grid zeroes it.
        if const_expr(semaphore_to_reset is not None):
            if (
                tidx == 0
                and m_block == cute.arch.grid_dim()[0] - 1
                and k_block == cute.arch.grid_dim()[1] - 1
                and batch_idx == cute.arch.grid_dim()[2] - 1
            ):
                semaphore_to_reset[0] = 0

        # Get number of splits (dynamic per-batch count if provided, else static)
        num_splits = (
            num_splits_dynamic_ptr[batch_idx]
            if const_expr(num_splits_dynamic_ptr is not None)
            else mLSE_partial.shape[1]
        )
        # Handle variable length sequences using SeqlenInfo
        seqlen_info = SeqlenInfo.create(
            batch_idx=batch_idx,
            seqlen_static=mO_partial.shape[0],
            cu_seqlens=cu_seqlens,
            seqused=seqused,
        )
        seqlen, offset = seqlen_info.seqlen, seqlen_info.offset

        # Extract number of heads (head index will be determined dynamically)
        num_head = mO_partial.shape[3]
        max_idx = seqlen * num_head

        # Skip the CTA entirely when the dynamic split count is 1 (nothing to
        # combine) or, under varlen, when this m_block lies past the used rows.
        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
            const_expr(not varlen) or m_block * self.m_block_size < max_idx
        ):
            # ===============================
            # Step 1: Load LSE_partial from gmem to shared memory
            # ===============================

            if const_expr(cu_seqlens is None):
                mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
            else:
                mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
            mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))

            gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
            tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)

            # Create identity tensor for coordinate tracking
            cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size))
            tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)

            # Load LSE partial values; out-of-range splits are filled with -inf
            # so they contribute zero weight later.
            for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
                mi = tLSEcLSE[0, 0, m][1]  # Get m coordinate
                idx = m_block * self.m_block_size + mi
                if idx < max_idx:
                    # Calculate actual sequence position and head using FastDivmodDivisor
                    if const_expr(not varlen):
                        head_idx, m_idx = divmod(idx, seqlen_divmod)
                    else:
                        head_idx = idx // seqlen
                        m_idx = idx - head_idx * seqlen
                    mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]
                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                        si = tLSEcLSE[0, s, 0][0]  # Get split coordinate
                        if si < num_splits:
                            cute.copy(
                                gmem_thr_copy_LSE,
                                mLSE_partial_cur_copy[None, si],
                                tLSEsLSE[None, s, m],
                            )
                        else:
                            tLSEsLSE[None, s, m].fill(-Float32.inf)
            # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
            cute.arch.cp_async_commit_group()

            # ===============================
            # Step 2: Load O_partial for pipeline stages
            # ===============================

            gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
            cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size))
            tOcO = gmem_thr_copy_O_partial.partition_D(cO)
            tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
            if const_expr(cu_seqlens is None):
                mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
            else:
                mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)

            # Precompute these values to avoid recomputing them in the loop
            num_rows = const_expr(cute.size(tOcO, mode=[1]))
            tOmidx = cute.make_fragment(num_rows, cutlass.Int32)
            tOhidx = cute.make_fragment(num_rows, cutlass.Int32)
            tOrOptr = cute.make_fragment(num_rows, cutlass.Int64)
            for m in cutlass.range(num_rows, unroll_full=True):
                mi = tOcO[0, m, 0][0]  # m coordinate
                idx = m_block * self.m_block_size + mi
                if const_expr(not varlen):
                    tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
                else:
                    tOhidx[m] = idx // seqlen
                    tOmidx[m] = idx - tOhidx[m] * seqlen
                tOrOptr[m] = utils.elem_pointer(
                    mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
                ).toint()
                # tOhidx == -1 marks rows past the end; they are skipped everywhere below.
                if idx >= max_idx:
                    tOhidx[m] = -1

            # k-predication mask, only needed when head_dim is not a multiple of k_block.
            tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean)
            if const_expr(not self.is_even_k):
                for k in cutlass.range(cute.size(tOpO), unroll_full=True):
                    tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
            # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)

            load_O_partial = partial(
                self.load_O_partial,
                gmem_tiled_copy_O_partial,
                tOrOptr,
                tOsO_partial,
                tOhidx,
                tOpO,
                tOcO,
                mO_partial_cur.layout,
            )

            # Load first few stages of O_partial (prologue of the smem pipeline)
            for stage in cutlass.range(self.stages - 1, unroll_full=True):
                if stage < num_splits:
                    load_O_partial(stage, stage)
                cute.arch.cp_async_commit_group()

            # ===============================
            # Step 3: Load and transpose LSE from smem to registers
            # ===============================

            # Wait for LSE and initial O partial stages to complete
            cute.arch.cp_async_wait_group(self.stages - 1)
            cute.arch.sync_threads()
            # if cute.arch.thread_idx()[0] == 0:
            #     # cute.print_tensor(sLSE)
            #     for i in range(64):
            #         cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0])
            # cute.arch.sync_threads()

            s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
            ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
            ts2rrLSE = cute.make_fragment_like(ts2rsLSE)
            cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)

            # ===============================
            # Step 4: Compute final LSE along split dimension
            # ===============================

            lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32)
            ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
            # We compute the max valid split for each row to short-circuit the computation later
            max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32)
            assert cute.size(ts2rrLSE, mode=[0]) == 1
            # Compute max, scales, and final LSE for each row
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                # Find max LSE value across splits
                threads_per_col = const_expr(self.smem_threads_per_col_lse)
                lse_max = cute.arch.warp_reduction_max(
                    ts2rrLSE[None, None, m]
                    .load()
                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
                    threads_in_group=threads_per_col,
                )
                # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)
                # Find max valid split index
                max_valid_idx = -1
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    if ts2rrLSE[0, s, m] != -Float32.inf:
                        max_valid_idx = ts2rcLSE[0, s, 0][0]  # Get split coordinate
                # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
                max_valid_split[m] = cute.arch.warp_reduction_max(
                    max_valid_idx, threads_in_group=threads_per_col
                )
                # Compute exp scales and sum
                lse_max_cur = (
                    0.0 if lse_max == -Float32.inf else lse_max
                )  # In case all local LSEs are -inf
                LOG2_E = math.log2(math.e)
                lse_sum_cur = 0.0
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
                    # exp(lse - lse_max), computed via exp2 for speed.
                    scale = cute.math.exp2(
                        ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
                    )
                    lse_sum_cur += scale
                    ts2rrLSE[0, s, m] = scale  # Store scale for later use
                lse_sum_cur = cute.arch.warp_reduction_sum(
                    lse_sum_cur, threads_in_group=threads_per_col
                )
                lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
                # Normalize scales; the x != x clause guards against NaN sums.
                inv_sum = (
                    0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
                )
                ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
            # Store the scales exp(lse - lse_logsum) back to smem
            cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)

            # Store max valid split to smem
            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
                    mi = ts2rcLSE[0, 0, m][1]
                    if mi < self.m_block_size:
                        sMaxValidSplit[mi] = max_valid_split[m]

            # ===============================
            # Step 5: Store final LSE to gmem
            # ===============================

            if const_expr(mLSE is not None):
                if const_expr(cu_seqlens is None):
                    mLSE_cur = mLSE[None, None, batch_idx]
                else:
                    mLSE_cur = cute.domain_offset((offset, 0), mLSE)
                if k_block == 0:  # Only first k_block writes LSE when mLSE is provided
                    for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
                        if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes
                            mi = ts2rcLSE[0, 0, m][1]
                            idx = m_block * self.m_block_size + mi
                            if idx < max_idx:
                                if const_expr(not varlen):
                                    head_idx, m_idx = divmod(idx, seqlen_divmod)
                                else:
                                    head_idx = idx // seqlen
                                    m_idx = idx - head_idx * seqlen
                                mLSE_cur[m_idx, head_idx] = lse_sum[m]

            # ===============================
            # Step 6: Read O_partial and accumulate final O
            # ===============================

            cute.arch.sync_threads()

            # Get max valid split for this thread
            thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
            for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
                thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])

            tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0])
            tOrO = cute.make_fragment_like(tOrO_partial, Float32)
            tOrO.fill(0.0)

            stage_load = self.stages - 1
            stage_compute = 0

            # Main accumulation loop
            for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
                # Get scales for this split
                scale = cute.make_fragment(num_rows, Float32)
                for m in cutlass.range(num_rows, unroll_full=True):
                    scale[m] = sLSE[s, tOcO[0, m, 0][0]]  # Get scale from smem

                # Load next stage if needed (commit unconditionally so the
                # wait_group count below stays consistent)
                split_to_load = s + self.stages - 1
                if split_to_load <= thr_max_valid_split:
                    load_O_partial(split_to_load, stage_load)
                cute.arch.cp_async_commit_group()
                stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1

                # Wait for the current stage to be ready
                cute.arch.cp_async_wait_group(self.stages - 1)
                # We don't need __syncthreads() because each thread is just reading its own data from smem
                # Copy from smem to registers
                cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)
                stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1

                # Accumulate scaled partial results
                for m in cutlass.range(num_rows, unroll_full=True):
                    if tOhidx[m] >= 0 and scale[m] > 0.0:
                        tOrO[None, m, None].store(
                            tOrO[None, m, None].load()
                            + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
                        )

            # ===============================
            # Step 7: Write final O to gmem
            # ===============================

            rO = cute.make_fragment_like(tOrO, self.dtype)
            rO.store(tOrO.load().to(self.dtype))
            if const_expr(cu_seqlens is None):
                mO_cur = mO[None, None, None, batch_idx]
            else:
                mO_cur = cute.domain_offset((offset, 0, 0), mO)
            mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)
            elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))
            # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,))
            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
            # Write final results
            for m in cutlass.range(num_rows, unroll_full=True):
                if tOhidx[m] >= 0:
                    mO_cur_copy = cute.tiled_divide(
                        mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
                    )
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_store
                        if const_expr(self.is_even_k) or tOpO[k]:
                            cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])
660
+
661
    @cute.jit
    def load_O_partial(
        self,
        gmem_tiled_copy_O_partial: cute.TiledCopy,
        tOrOptr: cute.Tensor,
        tOsO_partial: cute.Tensor,
        tOhidx: cute.Tensor,
        tOpO: cute.Tensor,
        tOcO: cute.Tensor,
        mO_cur_partial_layout: cute.Layout,
        split: Int32,
        stage: Int32,
    ) -> None:
        """Async-load one split of partial O from gmem into one smem pipeline stage.

        :param gmem_tiled_copy_O_partial: tiled async G2S copy for partial O
        :param tOrOptr: per-row precomputed gmem base addresses (Int64)
        :param tOsO_partial: smem destination, last mode indexes the pipeline stage
        :param tOhidx: per-row head index; -1 marks rows past the valid range
        :param tOpO: per-k predication mask (used when head_dim % k_block != 0)
        :param tOcO: identity partition carrying (m, k) coordinates
        :param mO_cur_partial_layout: layout of the current partial-O view
        :param split: which split to load
        :param stage: which smem pipeline stage to fill
        """
        elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))
        tOsO_partial_cur = tOsO_partial[None, None, None, stage]
        for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):
            # Skip rows flagged invalid (tOhidx == -1).
            if tOhidx[m] >= 0:
                # Rebuild a tensor view from the precomputed per-row address;
                # the sliced layout keeps only the (d, split) modes.
                o_gmem_ptr = cute.make_ptr(
                    tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
                )
                mO_partial_cur = cute.make_tensor(
                    o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
                )
                mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
                for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                    k_idx = tOcO[0, 0, k][1] // elems_per_load
                    if const_expr(self.is_even_k) or tOpO[k]:
                        cute.copy(
                            gmem_tiled_copy_O_partial,
                            mO_partial_cur_copy[None, k_idx, split],
                            tOsO_partial_cur[None, m, k],
                        )
build/torch-cuda/flash_fwd_sm100.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch-cuda/interface.py ADDED
@@ -0,0 +1,1855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need to install nvidia-cutlass-dsl==4.2.0.
3
+
4
+ # Supported features:
5
+ # - BF16 & FP16 dtype
6
+ # - noncausal & causal attention
7
+ # - MHA, GQA, MQA
8
+ # - hdim 64, 96, 128.
9
+ # - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape)
10
+ # - varlen
11
+ # - sliding window
12
+ # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow)
13
+
14
+ # Features not supported yet:
15
+ # - split (i.e. FlashDecoding)
16
+ # - tuned block sizes
17
+ # - paged KV
18
+ # - append KV to existing KV cache
19
+ # - FP8
20
+ # - bwd pass optimized for Hopper/Blackwell
21
+
22
+ import os
23
+ import math
24
+ from functools import lru_cache
25
+ from typing import Optional, Tuple, Callable
26
+
27
+ import torch
28
+
29
+
30
+ import cuda.bindings.driver as cuda
31
+
32
+ import cutlass
33
+ import cutlass.cute as cute
34
+ from .cache_utils import get_jit_cache
35
+ from .testing import is_fake_mode
36
+
37
+
38
+ if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
39
+ from . import cute_dsl_ptxas # noqa: F401
40
+
41
+ # Patch to dump ptx and then use system ptxas to compile to cubin
42
+ cute_dsl_ptxas.patch()
43
+
44
+
45
+ from . import utils
46
+ from .cute_dsl_utils import (
47
+ to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
48
+ )
49
+ from .flash_fwd import FlashAttentionForwardSm90
50
+ from .flash_fwd_sm100 import FlashAttentionForwardSm100
51
+ from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
52
+ from .flash_bwd import FlashAttentionBackwardSm80
53
+ from .flash_bwd_sm90 import FlashAttentionBackwardSm90
54
+ from .flash_bwd_sm100 import FlashAttentionBackwardSm100
55
+ from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
56
+ from .flash_fwd_combine import FlashAttentionForwardCombine
57
+
58
+ from .block_sparsity import (
59
+ BlockSparseTensorsTorch,
60
+ to_cute_block_sparse_tensors,
61
+ normalize_block_sparse_config,
62
+ normalize_block_sparse_config_bwd,
63
+ )
64
+
65
+ @lru_cache(maxsize=None)
66
+ def _get_device_arch():
67
+ """Cached device arch check."""
68
+ major, minor = torch.cuda.get_device_capability()
69
+ return major * 10 + minor
70
+
71
+ def maybe_contiguous(x):
72
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
73
+
74
+
75
+ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
76
+ assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
77
+ assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
78
+ assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
79
+ assert t.is_cuda, f"{name} must be on CUDA"
80
+
81
+
82
+ torch2cute_dtype_map = {
83
+ torch.float16: cutlass.Float16,
84
+ torch.bfloat16: cutlass.BFloat16,
85
+ torch.float32: cutlass.Float32,
86
+ }
87
+
88
+
89
+ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
90
+ # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
91
+ if num_n_blocks <= 4:
92
+ return 1
93
+
94
+ # NOTE: We should revisit this heuristic after persistence is supported for split KV.
95
+ # Sometimes, it's ideal to over-schedule splits for better efficiency.
96
+ return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
97
+
98
+
99
+ def _flash_attn_fwd(
100
+ q: torch.Tensor,
101
+ k: torch.Tensor,
102
+ v: torch.Tensor,
103
+ cu_seqlens_q: Optional[torch.Tensor] = None,
104
+ cu_seqlens_k: Optional[torch.Tensor] = None,
105
+ seqused_q: Optional[torch.Tensor] = None,
106
+ seqused_k: Optional[torch.Tensor] = None,
107
+ max_seqlen_q: Optional[int] = None,
108
+ max_seqlen_k: Optional[int] = None,
109
+ page_table: Optional[torch.Tensor] = None,
110
+ softmax_scale: Optional[float] = None,
111
+ causal: bool = False,
112
+ softcap: Optional[float] = None,
113
+ window_size_left: Optional[int] = None,
114
+ window_size_right: Optional[int] = None,
115
+ learnable_sink: Optional[torch.Tensor] = None,
116
+ # m_block_size: int = 128,
117
+ # n_block_size: int = 64,
118
+ # num_threads: int = 128,
119
+ m_block_size: int = 128,
120
+ n_block_size: int = 128,
121
+ num_threads: int = 384,
122
+ num_splits: int = 1,
123
+ pack_gqa: Optional[bool] = None,
124
+ _arch: Optional[int] = None,
125
+ score_mod: Optional[Callable] = None,
126
+ mask_mod: Optional[Callable] = None,
127
+ block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
128
+ return_lse: bool = False,
129
+ out: Optional[torch.Tensor] = None,
130
+ lse: Optional[torch.Tensor] = None,
131
+ aux_tensors: Optional[list[torch.Tensor]] = None,
132
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
133
+ """Forward pass for FlashAttention.
134
+
135
+ Args:
136
+ ...
137
+ score_mod: A callable that takes the attention scores and applies a modification.
138
+ mask_mod: A callable that takes token position information and selectively masks
139
+ block_sparse_tensors: A tuple of tensors used for block sparsity.
140
+ return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
141
+ Note: the returned LSE currently does not support taking gradient.
142
+ out: Optional pre-allocated output tensor. If None, will be allocated internally.
143
+ lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
144
+ aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
145
+ """
146
+ q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
147
+ num_head, head_dim = q.shape[-2:]
148
+ if cu_seqlens_q is None:
149
+ batch_size, seqlen_q = q.shape[:2]
150
+ total_q = batch_size * seqlen_q
151
+ else:
152
+ batch_size = cu_seqlens_q.shape[0] - 1
153
+ seqlen_q = None
154
+ total_q = q.shape[0]
155
+ if page_table is not None:
156
+ assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
157
+ assert page_table.dtype == torch.int32, "page_table must be int32"
158
+ assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension"
159
+ max_num_pages_per_seq = page_table.shape[1]
160
+ assert page_table.shape == (batch_size, max_num_pages_per_seq)
161
+ num_pages, page_size = k.shape[:2]
162
+ seqlen_k = num_pages * page_size
163
+ else:
164
+ num_pages, page_size = None, None
165
+ seqlen_k = k.shape[-3]
166
+ num_head_kv = k.shape[-2]
167
+ head_dim_v = v.shape[-1]
168
+ if cu_seqlens_k is None:
169
+ if page_table is None:
170
+ assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
171
+ assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
172
+ else:
173
+ assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
174
+ assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
175
+ else:
176
+ assert k.shape == (seqlen_k, num_head_kv, head_dim)
177
+ assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
178
+ assert cu_seqlens_k.shape == (batch_size + 1,), (
179
+ "cu_seqlens_k must have shape (batch_size + 1,)"
180
+ )
181
+
182
+ if cu_seqlens_q is not None:
183
+ assert cu_seqlens_q.shape == (batch_size + 1,), (
184
+ "cu_seqlens_q must have shape (batch_size + 1,)"
185
+ )
186
+ assert seqused_q is None or seqused_q.shape == (batch_size,), (
187
+ "seqused_q must have shape (batch_size,)"
188
+ )
189
+ assert seqused_k is None or seqused_k.shape == (batch_size,), (
190
+ "seqused_k must have shape (batch_size,)"
191
+ )
192
+ assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
193
+ assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
194
+ for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
195
+ if t is not None:
196
+ assert t.dtype == torch.int32, (
197
+ "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
198
+ )
199
+ assert t.stride(0) == 1, (
200
+ "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
201
+ )
202
+ if learnable_sink is not None:
203
+ assert learnable_sink.shape == (num_head,)
204
+ assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
205
+
206
+ assert all(
207
+ t is None or t.is_cuda
208
+ for t in (
209
+ q,
210
+ k,
211
+ v,
212
+ cu_seqlens_q,
213
+ cu_seqlens_k,
214
+ seqused_q,
215
+ seqused_k,
216
+ page_table,
217
+ learnable_sink,
218
+ )
219
+ ), "inputs must be on CUDA device"
220
+ assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
221
+ assert head_dim <= 256, "head_dim must be less than or equal to 256"
222
+ alignment = 16 // q.element_size()
223
+ assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
224
+ assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
225
+ if softmax_scale is None:
226
+ softmax_scale = 1.0 / math.sqrt(head_dim)
227
+ if softcap == 0.0:
228
+ softcap = None
229
+ qhead_per_kvhead = num_head // num_head_kv
230
+ if pack_gqa is None:
231
+ pack_gqa = qhead_per_kvhead > 1
232
+
233
+ out_torch_dtype = q.dtype
234
+ device = q.device
235
+ q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
236
+ lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)
237
+ requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
238
+
239
+ if out is None:
240
+ out = torch.empty(
241
+ *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device
242
+ )
243
+ else:
244
+ _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)
245
+
246
+ if lse is None:
247
+ lse = (
248
+ torch.empty(lse_shape, dtype=torch.float32, device=device)
249
+ if requires_grad or return_lse
250
+ else None
251
+ )
252
+ elif lse is not None:
253
+ _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
254
+
255
+ dtype = torch2cute_dtype_map[q.dtype]
256
+ arch = _get_device_arch() if _arch is None else _arch
257
+
258
+ assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
259
+
260
+ use_block_sparsity = block_sparse_tensors is not None
261
+
262
+ if mask_mod is None:
263
+ if causal:
264
+ window_size_right = 0
265
+ if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
266
+ window_size_left = None
267
+ window_size_right = None
268
+ local = window_size_left is not None or window_size_right is not None
269
+ if window_size_left is not None or window_size_right is not None:
270
+ if window_size_left is None and window_size_right == 0:
271
+ causal, local = True, False
272
+ window_size_right = None
273
+ else:
274
+ causal, local = False, True
275
+ else:
276
+ causal, local = False, False
277
+
278
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
279
+
280
+ if arch // 10 == 9: # TODO: tune block size according to hdim.
281
+ if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
282
+ n_block_size = 192
283
+
284
+ if arch // 10 in [10, 11]:
285
+ if (
286
+ pack_gqa
287
+ and (128 % qhead_per_kvhead != 0)
288
+ ):
289
+ pack_gqa = False
290
+ # TODO: fix GQA + SplitKV + non-varlen
291
+ if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
292
+ pack_gqa = False
293
+
294
+ if max_seqlen_q is None:
295
+ max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
296
+ if max_seqlen_k is None:
297
+ max_seqlen_k = seqlen_k
298
+ seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
299
+ if arch // 10 == 10:
300
+ q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
301
+ else:
302
+ q_stage = 1
303
+
304
+ if num_splits < 1:
305
+ m_block_size_effective = q_stage * m_block_size
306
+ seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
307
+ num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
308
+ num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
309
+ total_mblocks = batch_size * num_head_kv * num_m_blocks
310
+ num_splits = num_splits_heuristic(
311
+ total_mblocks,
312
+ torch.cuda.get_device_properties(device).multi_processor_count,
313
+ num_n_blocks,
314
+ 128,
315
+ )
316
+
317
+ is_split_kv = num_splits > 1
318
+ if is_split_kv:
319
+ out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
320
+ lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
321
+
322
+ # hash score and mask mods for compile cache
323
+ score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
324
+ mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
325
+
326
+ if softcap is not None:
327
+ assert score_mod is None, "softcap and score_mod cannot be used together"
328
+ score_mod = utils.create_softcap_scoremod(softcap)
329
+
330
+ is_varlen = (
331
+ cu_seqlens_q is not None
332
+ or cu_seqlens_k is not None
333
+ or seqused_q is not None
334
+ or seqused_k is not None
335
+ )
336
+
337
+ if mask_mod is not None:
338
+ if is_varlen:
339
+ raise NotImplementedError(
340
+ "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
341
+ )
342
+
343
+ if use_block_sparsity:
344
+ if is_varlen:
345
+ raise NotImplementedError(
346
+ "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR."
347
+ )
348
+ # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
349
+ if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
350
+ pack_gqa = False
351
+ if is_split_kv:
352
+ raise NotImplementedError(
353
+ "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
354
+ )
355
+
356
+ # See get_broadcast_dims for why this is needed in compile key
357
+ block_sparse_broadcast_pattern = None
358
+ normalized_block_sparse_tensors = None
359
+ q_subtile_factor = None
360
+ if block_sparse_tensors is not None:
361
+ if seqlen_q is None:
362
+ raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
363
+ (
364
+ normalized_block_sparse_tensors,
365
+ block_sparse_broadcast_pattern,
366
+ q_subtile_factor,
367
+ ) = normalize_block_sparse_config(
368
+ block_sparse_tensors,
369
+ batch_size=batch_size,
370
+ num_head=num_head,
371
+ seqlen_q=seqlen_q,
372
+ seqlen_k=seqlen_k,
373
+ block_size=(m_block_size, n_block_size),
374
+ q_stage=q_stage,
375
+ )
376
+ if aux_tensors is not None:
377
+ aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
378
+ else:
379
+ aux_tensor_metadata = None
380
+
381
+ compile_key = (
382
+ dtype,
383
+ head_dim,
384
+ head_dim_v,
385
+ qhead_per_kvhead,
386
+ causal,
387
+ score_mod_hash,
388
+ mask_mod_hash,
389
+ use_block_sparsity,
390
+ block_sparse_broadcast_pattern,
391
+ aux_tensor_metadata,
392
+ lse is None,
393
+ cu_seqlens_q is None,
394
+ cu_seqlens_k is None,
395
+ seqused_q is None,
396
+ seqused_k is None,
397
+ page_table is not None,
398
+ window_size_left is not None,
399
+ window_size_right is not None,
400
+ learnable_sink is not None,
401
+ m_block_size,
402
+ n_block_size,
403
+ q_stage,
404
+ num_threads,
405
+ is_split_kv,
406
+ pack_gqa,
407
+ arch,
408
+ page_size not in [None, 128], # paged KV non-TMA
409
+ q_subtile_factor,
410
+ )
411
+ if compile_key not in _flash_attn_fwd.compile_cache:
412
+ (
413
+ cu_seqlens_q_tensor,
414
+ cu_seqlens_k_tensor,
415
+ seqused_q_tensor,
416
+ seqused_k_tensor,
417
+ learnable_sink_tensor,
418
+ ) = [
419
+ to_cute_tensor(t, assumed_align=4, leading_dim=0)
420
+ if t is not None
421
+ else None
422
+ for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
423
+ ]
424
+ page_table_tensor = (
425
+ to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
426
+ if page_table is not None
427
+ else None
428
+ )
429
+ q_tensor, k_tensor, v_tensor, o_tensor = [
430
+ to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
431
+ ]
432
+ if is_split_kv:
433
+ lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
434
+ elif lse is not None:
435
+ lse_tensor = to_cute_tensor(lse, assumed_align=4)
436
+ else:
437
+ lse_tensor = None
438
+
439
+ sparse_tensors = None
440
+ if normalized_block_sparse_tensors is not None:
441
+ sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)
442
+
443
+ cute_aux_tensors = None
444
+ aux_tensor_metadata = None
445
+ if aux_tensors is not None:
446
+ cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]
447
+
448
+ if arch // 10 == 9:
449
+ assert page_table is None, "paged KV not supported on SM 9.0"
450
+ assert not is_split_kv, "SplitKV not supported on SM 9.0"
451
+ # fa_fwd = FlashAttentionForwardSm80(
452
+ fa_fwd = FlashAttentionForwardSm90(
453
+ dtype,
454
+ head_dim,
455
+ head_dim_v,
456
+ qhead_per_kvhead,
457
+ is_causal=causal,
458
+ is_local=local,
459
+ pack_gqa=pack_gqa,
460
+ tile_m=m_block_size,
461
+ tile_n=n_block_size,
462
+ # num_stages=1,
463
+ num_stages=2,
464
+ num_threads=num_threads,
465
+ Q_in_regs=False,
466
+ intra_wg_overlap=True,
467
+ mma_pv_is_rs=True,
468
+ mask_mod=mask_mod,
469
+ score_mod=score_mod,
470
+ has_aux_tensors=aux_tensors is not None,
471
+ q_subtile_factor=q_subtile_factor,
472
+ )
473
+ elif arch // 10 in [10, 11]:
474
+ head_dim_padded = int(math.ceil(head_dim / 16) * 16)
475
+ head_dim_v_padded = int(math.ceil(head_dim / 16) * 16)
476
+ use_2cta_instrs = (
477
+ not causal
478
+ and not local
479
+ and not is_split_kv
480
+ and cu_seqlens_q is None
481
+ and seqused_q is None
482
+ and not use_block_sparsity
483
+ and page_size in [None, 128]
484
+ and head_dim_padded == 128
485
+ and head_dim_v_padded == 128
486
+ )
487
+ fa_fwd = FlashAttentionForwardSm100(
488
+ head_dim,
489
+ head_dim_v,
490
+ qhead_per_kvhead=qhead_per_kvhead,
491
+ is_causal=causal,
492
+ is_local=local,
493
+ is_split_kv=is_split_kv,
494
+ pack_gqa=pack_gqa,
495
+ m_block_size=m_block_size,
496
+ n_block_size=n_block_size,
497
+ q_stage=q_stage,
498
+ is_persistent=not causal
499
+ and not local
500
+ and cu_seqlens_q is None
501
+ and seqused_q is None
502
+ and not is_split_kv,
503
+ score_mod=score_mod,
504
+ mask_mod=mask_mod,
505
+ has_aux_tensors=aux_tensors is not None,
506
+ paged_kv_non_tma=page_size not in [None, 128],
507
+ is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
508
+ q_subtile_factor=q_subtile_factor,
509
+ use_2cta_instrs=use_2cta_instrs,
510
+ )
511
+ else:
512
+ raise ValueError(
513
+ f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x"
514
+ )
515
+ # TODO: check @can_implement
516
+ _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
517
+ fa_fwd,
518
+ q_tensor,
519
+ k_tensor,
520
+ v_tensor,
521
+ o_tensor,
522
+ lse_tensor,
523
+ softmax_scale,
524
+ current_stream,
525
+ cu_seqlens_q_tensor,
526
+ cu_seqlens_k_tensor,
527
+ seqused_q_tensor,
528
+ seqused_k_tensor,
529
+ page_table_tensor,
530
+ window_size_left,
531
+ window_size_right,
532
+ learnable_sink_tensor,
533
+ sparse_tensors,
534
+ cute_aux_tensors,
535
+ options="--enable-tvm-ffi",
536
+ )
537
+
538
+ # In "fake mode", we will take torch fake tensors as input and the expected behaviors are:
539
+ # - Use those fake metadata to populate compilation cache
540
+ # - Return "fake" output tensors, which could be needed in follow-up fake operations
541
+ # Thus, we skip the actual kernel invocation here.
542
+ if not is_fake_mode():
543
+ _flash_attn_fwd.compile_cache[compile_key](
544
+ q.detach(),
545
+ k.detach(),
546
+ v.detach(),
547
+ out.detach() if not is_split_kv else out_partial,
548
+ lse_partial if is_split_kv else lse,
549
+ softmax_scale,
550
+ current_stream,
551
+ cu_seqlens_q,
552
+ cu_seqlens_k,
553
+ seqused_q,
554
+ seqused_k,
555
+ page_table,
556
+ window_size_left,
557
+ window_size_right,
558
+ learnable_sink,
559
+ normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
560
+ aux_tensors,
561
+ )
562
+ if is_split_kv:
563
+ _flash_attn_fwd_combine(
564
+ out_partial,
565
+ lse_partial.transpose(-1, -2),
566
+ out,
567
+ lse.transpose(-1, -2) if lse is not None else None,
568
+ cu_seqlens_q,
569
+ seqused_q,
570
+ )
571
+ return out, lse
572
+
573
+
574
+ _flash_attn_fwd.compile_cache = get_jit_cache("fwd")
575
+
576
+
577
+ def _flash_attn_bwd(
578
+ q: torch.Tensor,
579
+ k: torch.Tensor,
580
+ v: torch.Tensor,
581
+ out: torch.Tensor,
582
+ dout: torch.Tensor,
583
+ lse: torch.Tensor,
584
+ softmax_scale: Optional[float] = None,
585
+ causal: bool = False,
586
+ softcap: float = 0.0,
587
+ window_size_left: Optional[int] = None,
588
+ window_size_right: Optional[int] = None,
589
+ m_block_size: int = 64,
590
+ n_block_size: int = 128,
591
+ num_threads: int = 256,
592
+ pack_gqa: bool = False,
593
+ num_stages_Q: int = 2,
594
+ num_stages_dO: int = 2,
595
+ SdP_swapAB: bool = False,
596
+ dKV_swapAB: bool = False,
597
+ dQ_swapAB: bool = False,
598
+ AtomLayoutMSdP: int = 2,
599
+ AtomLayoutNdKV: int = 2,
600
+ AtomLayoutMdQ: int = 2,
601
+ V_in_regs: bool = False,
602
+ cu_seqlens_q: Optional[torch.Tensor] = None,
603
+ cu_seqlens_k: Optional[torch.Tensor] = None,
604
+ seqused_q: Optional[torch.Tensor] = None,
605
+ seqused_k: Optional[torch.Tensor] = None,
606
+ max_seqlen_q: Optional[int] = None,
607
+ max_seqlen_k: Optional[int] = None,
608
+ deterministic: bool = False,
609
+ dq: Optional[torch.Tensor] = None,
610
+ dk: Optional[torch.Tensor] = None,
611
+ dv: Optional[torch.Tensor] = None,
612
+ score_mod: Optional[Callable] = None,
613
+ score_mod_bwd: Optional[Callable] = None,
614
+ mask_mod: Optional[Callable] = None,
615
+ aux_tensors: Optional[list[torch.Tensor]] = None,
616
+ block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
617
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
618
+ arch = _get_device_arch()
619
+ assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
620
+
621
+ num_head, head_dim = q.shape[-2:]
622
+
623
+ if causal:
624
+ window_size_right = 0
625
+ if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
626
+ window_size_left = None
627
+ window_size_right = None
628
+ local = window_size_left is not None or window_size_right is not None
629
+ if local:
630
+ if window_size_left is None and window_size_right == 0:
631
+ causal, local = True, False
632
+ window_size_right = None
633
+ else:
634
+ causal, local = False, True
635
+
636
+ if arch // 10 == 9:
637
+ m_block_size = 80 if not causal else 64
638
+ n_block_size = 128
639
+ num_stages_Q = 2
640
+ num_stages_dO = 2
641
+ num_stages_PdS = 2
642
+ SdP_swapAB = True
643
+ dKV_swapAB = False
644
+ dQ_swapAB = not causal
645
+ AtomLayoutMSdP = 1
646
+ AtomLayoutNdKV = 2
647
+ AtomLayoutMdQ = 1
648
+ cluster_size = 1
649
+ use_2cta_instrs = False
650
+ assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
651
+ is_varlen = (
652
+ cu_seqlens_q is not None
653
+ or cu_seqlens_k is not None
654
+ or seqused_q is not None
655
+ or seqused_k is not None
656
+ )
657
+ assert not is_varlen, "varlen backward is not yet supported on sm90"
658
+ else:
659
+ m_block_size = 128
660
+ n_block_size = 128
661
+ dQ_swapAB = False
662
+ dKV_swapAB = False
663
+ AtomLayoutMdQ = 1
664
+ AtomLayoutNdKV = 1
665
+ disable_2cta = (
666
+ local
667
+ or score_mod is not None
668
+ or score_mod_bwd is not None
669
+ or mask_mod is not None
670
+ )
671
+ cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
672
+ use_2cta_instrs = cluster_size==2
673
+
674
+ q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
675
+ maybe_contiguous(t)
676
+ for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
677
+ ]
678
+ if cu_seqlens_q is None:
679
+ batch_size, seqlen_q = q.shape[:2]
680
+ total_q = batch_size * seqlen_q
681
+ else:
682
+ batch_size = cu_seqlens_q.shape[0] - 1
683
+ total_q = q.shape[0]
684
+ seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q
685
+
686
+ if cu_seqlens_k is None:
687
+ batch_size, seqlen_k = k.shape[:2]
688
+ total_k = batch_size * seqlen_k
689
+ else:
690
+ batch_size = cu_seqlens_k.shape[0] - 1
691
+ total_k = k.shape[0]
692
+ seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
693
+
694
+ num_head_kv = k.shape[-2]
695
+ head_dim_v = v.shape[-1]
696
+
697
+ use_block_sparsity = block_sparse_tensors is not None
698
+
699
+ # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,
700
+ # the base block_m of 128 from forward, and block-sparse size for subtiling.
701
+ if arch // 10 == 9 and use_block_sparsity:
702
+ m_block_size = 64
703
+ # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
704
+ dQ_swapAB = False
705
+
706
+ # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
707
+ subtile_factor = 2
708
+ seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
709
+ seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
710
+ num_n_blocks = seqlen_k_rounded // n_block_size
711
+ if cluster_size == 2 and num_n_blocks % cluster_size != 0:
712
+ seqlen_k_rounded = seqlen_k_rounded + n_block_size
713
+
714
+ if cu_seqlens_k is None:
715
+ assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
716
+ assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
717
+ else:
718
+ assert k.shape == (total_k, num_head_kv, head_dim)
719
+ assert v.shape == (total_k, num_head_kv, head_dim_v)
720
+ assert cu_seqlens_k.shape == (batch_size + 1,), (
721
+ "cu_seqlens_k must have shape (batch_size + 1,)"
722
+ )
723
+
724
+ if cu_seqlens_q is not None:
725
+ assert cu_seqlens_q.shape == (batch_size + 1,), (
726
+ "cu_seqlens_q must have shape (batch_size + 1,)"
727
+ )
728
+
729
+ assert out.shape == (total_q, num_head, head_dim_v)
730
+ assert dout.shape == (total_q, num_head, head_dim_v)
731
+ assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)"
732
+ else:
733
+ assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)
734
+ assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)
735
+ assert lse.shape == (batch_size, num_head, seqlen_q), (
736
+ "lse must have shape (batch_size, num_head, seqlen_q)"
737
+ )
738
+
739
+ assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
740
+ assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (
741
+ "inputs must have the same dtype"
742
+ )
743
+ for t in [cu_seqlens_q, cu_seqlens_k]:
744
+ if t is not None:
745
+ assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
746
+ assert lse.dtype == torch.float32, "lse must be float32"
747
+ assert all(
748
+ t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
749
+ ), "inputs must be on CUDA device"
750
+ assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
751
+ assert head_dim <= 256, "head_dim must be less than or equal to 256"
752
+ alignment = 16 // q.element_size()
753
+ assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
754
+ assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
755
+ if softmax_scale is None:
756
+ softmax_scale = 1.0 / math.sqrt(head_dim)
757
+ qhead_per_kvhead = num_head // num_head_kv
758
+ if pack_gqa is None:
759
+ pack_gqa = qhead_per_kvhead > 1
760
+ # pack_gqa backward not yet supported in bwd
761
+ pack_gqa = False
762
+ if arch // 10 not in [10, 11]:
763
+ assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"
764
+
765
+ if score_mod is not None:
766
+ assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
767
+ assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
768
+ assert cu_seqlens_q is None and cu_seqlens_k is None, (
769
+ "varlen + score_mod not supported in bwd yet"
770
+ )
771
+
772
+ device = q.device
773
+ out_torch_dtype = q.dtype
774
+
775
+ if dq is None:
776
+ dq = torch.empty_like(q)
777
+ else:
778
+ _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)
779
+
780
+ if dk is None:
781
+ dk = torch.empty_like(k)
782
+ else:
783
+ _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)
784
+
785
+ if dv is None:
786
+ dv = torch.empty_like(v)
787
+ else:
788
+ _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)
789
+
790
+ head_dim_rounded = (head_dim + 32 - 1) // 32 * 32
791
+
792
+ if cu_seqlens_q is None:
793
+ dq_accum = torch.empty(
794
+ batch_size,
795
+ num_head,
796
+ seqlen_q_rounded * head_dim_rounded,
797
+ dtype=torch.float32,
798
+ device=device,
799
+ )
800
+ dpsum = torch.empty(
801
+ batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
802
+ )
803
+ lse_log2 = torch.empty(
804
+ batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
805
+ )
806
+ else:
807
+ total_q_rounded_padded = (
808
+ (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size
809
+ )
810
+ dq_accum = torch.empty(
811
+ num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device
812
+ )
813
+ dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
814
+ lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
815
+
816
+ dKV_postprocess = qhead_per_kvhead > 1
817
+ if dKV_postprocess:
818
+ head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
819
+ if cu_seqlens_k is None:
820
+ dk_accum = torch.zeros(
821
+ batch_size,
822
+ num_head_kv,
823
+ seqlen_k_rounded * head_dim_rounded,
824
+ dtype=torch.float32,
825
+ device=device,
826
+ )
827
+ dv_accum = torch.zeros(
828
+ batch_size,
829
+ num_head_kv,
830
+ seqlen_k_rounded * head_dim_v_rounded,
831
+ dtype=torch.float32,
832
+ device=device,
833
+ )
834
+ else:
835
+ cluster_tile_n = cluster_size * n_block_size
836
+ total_k_rounded_padded = (
837
+ (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n
838
+ )
839
+ dk_accum = torch.zeros(
840
+ num_head_kv,
841
+ total_k_rounded_padded * head_dim_rounded,
842
+ dtype=torch.float32,
843
+ device=device,
844
+ )
845
+ dv_accum = torch.zeros(
846
+ num_head_kv,
847
+ total_k_rounded_padded * head_dim_v_rounded,
848
+ dtype=torch.float32,
849
+ device=device,
850
+ )
851
+
852
+ dtype = torch2cute_dtype_map[q.dtype]
853
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
854
+
855
+ if deterministic:
856
+ dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device="cuda")
857
+ else:
858
+ dQ_semaphore = None
859
+
860
+ if deterministic and qhead_per_kvhead > 1:
861
+ dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
862
+ dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
863
+ else:
864
+ dK_semaphore = None
865
+ dV_semaphore = None
866
+
867
+ # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
868
+ compile_key_pre = (
869
+ arch,
870
+ dtype,
871
+ head_dim,
872
+ head_dim_v,
873
+ m_block_size,
874
+ num_threads,
875
+ cu_seqlens_q is None,
876
+ seqused_q is None,
877
+ get_broadcast_dims(out),
878
+ get_broadcast_dims(dout),
879
+ )
880
+ if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
881
+ o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
882
+ dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
883
+ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
884
+ ]
885
+ lse_tensor = to_cute_tensor(lse, assumed_align=4)
886
+ cu_seqlens_q_tensor, seqused_q_tensor = [
887
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
888
+ for t in (cu_seqlens_q, seqused_q)
889
+ ]
890
+ fa_bwd_pre = FlashAttentionBackwardPreprocess(
891
+ dtype,
892
+ head_dim,
893
+ head_dim_v,
894
+ arch,
895
+ m_block_size,
896
+ num_threads=num_threads,
897
+ )
898
+ # TODO: check @can_implement
899
+ _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
900
+ fa_bwd_pre,
901
+ o_tensor,
902
+ do_tensor,
903
+ dpsum_tensor,
904
+ lse_tensor,
905
+ lse_log2_tensor,
906
+ dq_accum_tensor,
907
+ cu_seqlens_q_tensor,
908
+ seqused_q_tensor,
909
+ current_stream,
910
+ options="--enable-tvm-ffi",
911
+ )
912
+ if not is_fake_mode():
913
+ _flash_attn_bwd.compile_cache_pre[compile_key_pre](
914
+ out,
915
+ dout,
916
+ dpsum,
917
+ lse,
918
+ lse_log2,
919
+ dq_accum,
920
+ cu_seqlens_q,
921
+ seqused_q,
922
+ current_stream,
923
+ )
924
+
925
+ # NB num_threads application for 3 kernels
926
+ # There are pre, main, post processing kernels, currenlty num_threads is only actually
927
+ # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do
928
+ # before cache key gen
929
+ num_threads = 384
930
+
931
+ # Backward kernel: compute dk, dv, dq_accum.
932
+ score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
933
+ score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
934
+ mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
935
+ num_aux_tensors = len(aux_tensors) if aux_tensors else 0
936
+ cute_aux_tensors = None
937
+ if aux_tensors is not None:
938
+ cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
939
+
940
+ block_sparse_broadcast_pattern = None
941
+ normalized_block_sparse_tensors = None
942
+ if block_sparse_tensors is not None:
943
+ (
944
+ normalized_block_sparse_tensors,
945
+ block_sparse_broadcast_pattern,
946
+ ) = normalize_block_sparse_config_bwd(
947
+ block_sparse_tensors,
948
+ batch_size=batch_size,
949
+ num_head=num_head,
950
+ seqlen_q=seqlen_q,
951
+ seqlen_k=seqlen_k,
952
+ block_size=(m_block_size, n_block_size),
953
+ subtile_factor=subtile_factor,
954
+ )
955
+
956
+ if arch // 10 == 9:
957
+ compile_key = (
958
+ arch,
959
+ dtype,
960
+ head_dim,
961
+ head_dim_v,
962
+ qhead_per_kvhead,
963
+ causal,
964
+ softcap != 0.0,
965
+ m_block_size,
966
+ n_block_size,
967
+ num_threads,
968
+ pack_gqa,
969
+ num_stages_Q,
970
+ num_stages_dO,
971
+ SdP_swapAB,
972
+ dKV_swapAB,
973
+ dQ_swapAB,
974
+ AtomLayoutMSdP,
975
+ AtomLayoutNdKV,
976
+ AtomLayoutMdQ,
977
+ V_in_regs,
978
+ cu_seqlens_q is None,
979
+ cu_seqlens_k is None,
980
+ seqused_q is None,
981
+ seqused_k is None,
982
+ score_mod_hash,
983
+ score_mod_bwd_hash,
984
+ mask_mod_hash,
985
+ num_aux_tensors,
986
+ use_block_sparsity,
987
+ block_sparse_broadcast_pattern,
988
+ get_broadcast_dims(q),
989
+ get_broadcast_dims(k),
990
+ get_broadcast_dims(v),
991
+ get_broadcast_dims(dout),
992
+ )
993
+ else:
994
+ compile_key = (
995
+ arch,
996
+ dtype,
997
+ head_dim,
998
+ head_dim_v,
999
+ qhead_per_kvhead,
1000
+ causal,
1001
+ window_size_left is not None,
1002
+ window_size_right is not None,
1003
+ softcap != 0.0,
1004
+ m_block_size,
1005
+ n_block_size,
1006
+ num_threads,
1007
+ pack_gqa,
1008
+ cluster_size,
1009
+ use_2cta_instrs,
1010
+ deterministic,
1011
+ score_mod_hash,
1012
+ score_mod_bwd_hash,
1013
+ mask_mod_hash,
1014
+ num_aux_tensors,
1015
+ use_block_sparsity,
1016
+ block_sparse_broadcast_pattern,
1017
+ cu_seqlens_q is None,
1018
+ cu_seqlens_k is None,
1019
+ seqused_q is None,
1020
+ seqused_k is None,
1021
+ get_broadcast_dims(q),
1022
+ get_broadcast_dims(k),
1023
+ get_broadcast_dims(v),
1024
+ get_broadcast_dims(dout),
1025
+ )
1026
+ if compile_key not in _flash_attn_bwd.compile_cache:
1027
+ q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
1028
+ to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
1029
+ ]
1030
+ dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
1031
+ to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
1032
+ ]
1033
+ if dKV_postprocess:
1034
+ dk_accum_tensor, dv_accum_tensor = [
1035
+ to_cute_tensor(t) for t in (dk_accum, dv_accum)
1036
+ ]
1037
+ cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
1038
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1039
+ for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
1040
+ ]
1041
+ dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
1042
+ utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
1043
+ if t is not None else None
1044
+ for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
1045
+ ]
1046
+ fa_bwd_sm80 = FlashAttentionBackwardSm80(
1047
+ dtype,
1048
+ head_dim,
1049
+ head_dim_v,
1050
+ qhead_per_kvhead,
1051
+ m_block_size,
1052
+ n_block_size,
1053
+ num_stages_Q,
1054
+ num_stages_dO,
1055
+ num_threads,
1056
+ pack_gqa,
1057
+ causal,
1058
+ SdP_swapAB,
1059
+ dKV_swapAB,
1060
+ dQ_swapAB,
1061
+ AtomLayoutMSdP,
1062
+ AtomLayoutNdKV,
1063
+ AtomLayoutMdQ,
1064
+ V_in_regs=V_in_regs,
1065
+ )
1066
+ if arch // 10 == 9:
1067
+ fa_bwd_obj = FlashAttentionBackwardSm90(
1068
+ dtype,
1069
+ head_dim,
1070
+ head_dim_v,
1071
+ qhead_per_kvhead,
1072
+ causal,
1073
+ m_block_size,
1074
+ n_block_size,
1075
+ num_stages_Q,
1076
+ num_stages_dO,
1077
+ num_stages_PdS,
1078
+ SdP_swapAB,
1079
+ dKV_swapAB,
1080
+ dQ_swapAB,
1081
+ AtomLayoutMSdP,
1082
+ AtomLayoutNdKV,
1083
+ AtomLayoutMdQ,
1084
+ num_threads,
1085
+ V_in_regs=V_in_regs,
1086
+ score_mod=score_mod,
1087
+ score_mod_bwd=score_mod_bwd,
1088
+ mask_mod=mask_mod,
1089
+ has_aux_tensors=aux_tensors is not None,
1090
+ subtile_factor=subtile_factor,
1091
+ )
1092
+ else:
1093
+ fa_bwd_obj = FlashAttentionBackwardSm100(
1094
+ head_dim,
1095
+ head_dim_v,
1096
+ is_causal=causal,
1097
+ is_local=local,
1098
+ qhead_per_kvhead=qhead_per_kvhead,
1099
+ tile_m=m_block_size,
1100
+ tile_n=n_block_size,
1101
+ cluster_size=cluster_size,
1102
+ use_2cta_instrs=use_2cta_instrs,
1103
+ deterministic=deterministic,
1104
+ score_mod=score_mod,
1105
+ score_mod_bwd=score_mod_bwd,
1106
+ mask_mod=mask_mod,
1107
+ has_aux_tensors=aux_tensors is not None,
1108
+ subtile_factor=subtile_factor,
1109
+ )
1110
+
1111
+ # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
1112
+ sparse_tensors_compile = None
1113
+ if normalized_block_sparse_tensors is not None:
1114
+ sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)
1115
+
1116
+ # TODO: check @can_implement
1117
+ _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
1118
+ fa_bwd_obj,
1119
+ q_tensor,
1120
+ k_tensor,
1121
+ v_tensor,
1122
+ do_tensor,
1123
+ lse_log2_tensor,
1124
+ dpsum_tensor,
1125
+ dq_accum_tensor,
1126
+ dk_tensor if not dKV_postprocess else dk_accum_tensor,
1127
+ dv_tensor if not dKV_postprocess else dv_accum_tensor,
1128
+ softmax_scale,
1129
+ current_stream,
1130
+ cu_seqlens_q_tensor,
1131
+ cu_seqlens_k_tensor,
1132
+ seqused_q_tensor,
1133
+ seqused_k_tensor,
1134
+ None, # softcap - not yet supported in backward
1135
+ window_size_left,
1136
+ window_size_right,
1137
+ dQ_semaphore_tensor,
1138
+ dK_semaphore_tensor,
1139
+ dV_semaphore_tensor,
1140
+ cute_aux_tensors,
1141
+ sparse_tensors_compile,
1142
+ options="--enable-tvm-ffi",
1143
+ )
1144
+ if not is_fake_mode():
1145
+ _flash_attn_bwd.compile_cache[compile_key](
1146
+ q.detach(),
1147
+ k.detach(),
1148
+ v.detach(),
1149
+ dout,
1150
+ lse_log2,
1151
+ dpsum,
1152
+ dq_accum,
1153
+ dk if not dKV_postprocess else dk_accum,
1154
+ dv if not dKV_postprocess else dv_accum,
1155
+ softmax_scale,
1156
+ current_stream,
1157
+ cu_seqlens_q,
1158
+ cu_seqlens_k,
1159
+ seqused_q,
1160
+ seqused_k,
1161
+ None, # softcap - not yet supported in backward
1162
+ window_size_left,
1163
+ window_size_right,
1164
+ dQ_semaphore,
1165
+ dK_semaphore,
1166
+ dV_semaphore,
1167
+ aux_tensors,
1168
+ normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
1169
+ )
1170
+
1171
+ num_threads = 256 if arch // 10 == 9 else 128
1172
+ # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
1173
+ compile_key_post = (
1174
+ arch,
1175
+ dtype,
1176
+ head_dim,
1177
+ m_block_size,
1178
+ num_threads,
1179
+ AtomLayoutMdQ,
1180
+ dQ_swapAB,
1181
+ cu_seqlens_q is None,
1182
+ seqused_q is None,
1183
+ use_2cta_instrs,
1184
+ 1, # no cluster for tile_m
1185
+ get_broadcast_dims(dq_accum),
1186
+ get_broadcast_dims(dq),
1187
+ )
1188
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1189
+ dq_accum_tensor = to_cute_tensor(dq_accum)
1190
+ dq_tensor = to_cute_tensor(dq)
1191
+ cu_seqlens_q_tensor, seqused_q_tensor = [
1192
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1193
+ for t in (cu_seqlens_q, seqused_q)
1194
+ ]
1195
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1196
+ dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB,
1197
+ use_2cta_instrs=use_2cta_instrs,
1198
+ )
1199
+ # TODO: check @can_implement
1200
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1201
+ fa_bwd_post,
1202
+ dq_accum_tensor,
1203
+ dq_tensor,
1204
+ softmax_scale,
1205
+ cu_seqlens_q_tensor,
1206
+ seqused_q_tensor,
1207
+ current_stream,
1208
+ options="--enable-tvm-ffi",
1209
+ )
1210
+
1211
+ if not is_fake_mode():
1212
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1213
+ dq_accum,
1214
+ dq,
1215
+ softmax_scale,
1216
+ cu_seqlens_q,
1217
+ seqused_q,
1218
+ current_stream,
1219
+ )
1220
+
1221
+ if dKV_postprocess:
1222
+ # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
1223
+ compile_key_post = (
1224
+ arch,
1225
+ dtype,
1226
+ head_dim,
1227
+ n_block_size,
1228
+ num_threads,
1229
+ AtomLayoutNdKV,
1230
+ dKV_swapAB,
1231
+ cu_seqlens_k is None,
1232
+ seqused_k is None,
1233
+ False, # even for 2cta, is split along hdim, so always False
1234
+ cluster_size, # cluster is for tile_n
1235
+ get_broadcast_dims(dk_accum),
1236
+ get_broadcast_dims(dk),
1237
+ )
1238
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1239
+ dk_accum_tensor = to_cute_tensor(dk_accum)
1240
+ dk_tensor = to_cute_tensor(dk)
1241
+ cu_seqlens_k_tensor, seqused_k_tensor = [
1242
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1243
+ for t in (cu_seqlens_k, seqused_k)
1244
+ ]
1245
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1246
+ dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
1247
+ cluster_size=cluster_size,
1248
+ )
1249
+ # TODO: check @can_implement
1250
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1251
+ fa_bwd_post,
1252
+ dk_accum_tensor,
1253
+ dk_tensor,
1254
+ softmax_scale,
1255
+ cu_seqlens_k_tensor,
1256
+ seqused_k_tensor,
1257
+ current_stream,
1258
+ options="--enable-tvm-ffi",
1259
+ )
1260
+ if not is_fake_mode():
1261
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1262
+ dk_accum,
1263
+ dk,
1264
+ softmax_scale,
1265
+ cu_seqlens_k,
1266
+ seqused_k,
1267
+ current_stream,
1268
+ )
1269
+ compile_key_post = (
1270
+ arch,
1271
+ dtype,
1272
+ head_dim_v,
1273
+ n_block_size,
1274
+ num_threads,
1275
+ AtomLayoutNdKV,
1276
+ dKV_swapAB,
1277
+ cu_seqlens_k is None,
1278
+ seqused_k is None,
1279
+ False,
1280
+ cluster_size,
1281
+ get_broadcast_dims(dv_accum),
1282
+ get_broadcast_dims(dv),
1283
+ )
1284
+ if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1285
+ dv_accum_tensor = to_cute_tensor(dv_accum)
1286
+ dv_tensor = to_cute_tensor(dv)
1287
+ cu_seqlens_k_tensor, seqused_k_tensor = [
1288
+ to_cute_tensor(t, assumed_align=4) if t is not None else None
1289
+ for t in (cu_seqlens_k, seqused_k)
1290
+ ]
1291
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
1292
+ dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
1293
+ cluster_size=cluster_size,
1294
+ )
1295
+ # TODO: check @can_implement
1296
+ _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1297
+ fa_bwd_post,
1298
+ dv_accum_tensor,
1299
+ dv_tensor,
1300
+ cutlass.Float32(1.0),
1301
+ cu_seqlens_k_tensor,
1302
+ seqused_k_tensor,
1303
+ current_stream,
1304
+ options="--enable-tvm-ffi",
1305
+ )
1306
+ if not is_fake_mode():
1307
+ _flash_attn_bwd.compile_cache_post[compile_key_post](
1308
+ dv_accum,
1309
+ dv,
1310
+ 1.0,
1311
+ cu_seqlens_k,
1312
+ seqused_k,
1313
+ current_stream,
1314
+ )
1315
+
1316
+ return dq, dk, dv
1317
+
1318
+
1319
# Persistent JIT-compile caches for the three backward-pass kernels
# (preprocess, main backward, postprocess), keyed by the compile_key tuples
# built inside _flash_attn_bwd.
_flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre")
_flash_attn_bwd.compile_cache = get_jit_cache("bwd")
_flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post")
1322
+
1323
+
1324
class FlashAttnFunc(torch.autograd.Function):
    """Autograd bridge for fixed-length (batched) flash attention.

    ``forward`` dispatches to the compiled CuTe forward kernel and stashes
    everything ``backward`` needs on ``ctx``; ``backward`` dispatches to the
    compiled backward kernels and returns (dq, dk, dv).
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        mask_mod: Optional[Callable] = None,
        full_block_cnt: Optional[torch.Tensor] = None,
        full_block_idx: Optional[torch.Tensor] = None,
        mask_block_cnt: Optional[torch.Tensor] = None,
        mask_block_idx: Optional[torch.Tensor] = None,
        block_size: Optional[Tuple[int, int]] = None,
        return_lse: bool = False,
    ):
        # Bundle the block-sparse metadata only when the caller supplied any
        # piece of it; otherwise pass None straight through to the kernel.
        sparse_parts = (full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx)
        if all(part is None for part in sparse_parts):
            block_sparse_tensors = None
        else:
            block_sparse_tensors = BlockSparseTensorsTorch(
                full_block_cnt=full_block_cnt,
                full_block_idx=full_block_idx,
                mask_block_cnt=mask_block_cnt,
                mask_block_idx=mask_block_idx,
                block_size=block_size,
            )
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            mask_mod=mask_mod,
            block_sparse_tensors=block_sparse_tensors,
            return_lse=return_lse,
        )
        ctx.save_for_backward(q, k, v, out, lse)
        # Non-tensor configuration consumed by backward().
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        # NOTE(review): mask_mod and the block-sparse metadata used in
        # forward() are not forwarded to the backward kernel here — confirm
        # whether that is intentional (bwd support status) against the kernels.
        q, k, v, out, lse = ctx.saved_tensors
        grads = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            deterministic=ctx.deterministic,
        )
        dq, dk, dv = grads
        # One None per non-tensor forward argument; extra Nones are fine.
        return (dq, dk, dv) + (None,) * 20
1402
+
1403
+
1404
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd bridge for variable-length (cu_seqlens-based) flash attention.

    Mirrors ``FlashAttnFunc`` but threads the varlen bookkeeping tensors
    (cu_seqlens_*, seqused_*, max_seqlen_*) through forward and backward.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_q: Optional[torch.Tensor] = None,
        seqused_k: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        page_table: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        score_mod: Optional[Callable] = None,
        aux_tensors: Optional[list] = None,
        return_lse: bool = False,
    ):
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_lse=return_lse,
        )
        # Saved tensors include the varlen bookkeeping needed by backward().
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        # LSE gradient is not supported yet
        if lse is not None:
            ctx.mark_non_differentiable(lse)
        return out, lse

    @staticmethod
    def backward(ctx, dout, *args):
        (
            q,
            k,
            v,
            out,
            lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
        ) = ctx.saved_tensors
        # softcap is not supported in the varlen backward path.
        assert ctx.softcap == 0.0
        dq, dk, dv = _flash_attn_bwd(
            q,
            k,
            v,
            out,
            dout,
            lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            max_seqlen_q=ctx.max_seqlen_q,
            max_seqlen_k=ctx.max_seqlen_k,
            deterministic=ctx.deterministic,
        )
        # One None per non-tensor forward argument; extra Nones are fine.
        return (dq, dk, dv) + (None,) * 20
1492
+
1493
+
1494
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    return_lse: bool = False,
):
    """Flash attention over fixed-length batched q/k/v.

    Thin convenience wrapper: forwards every argument positionally to
    ``FlashAttnFunc.apply`` (``autograd.Function.apply`` accepts no keywords),
    so the order here must match ``FlashAttnFunc.forward``.
    """
    return FlashAttnFunc.apply(
        q, k, v,
        softmax_scale, causal, window_size, learnable_sink, softcap,
        num_splits, pack_gqa, deterministic, mask_mod,
        full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx,
        block_size, return_lse,
    )
1534
+
1535
+
1536
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    score_mod: Optional[Callable] = None,
    aux_tensors: Optional[list] = None,
    return_lse: bool = False,
):
    """Flash attention over variable-length (cu_seqlens-packed) sequences.

    Thin convenience wrapper around ``FlashAttnVarlenFunc.apply``. Note the
    argument order passed to ``apply`` follows ``FlashAttnVarlenFunc.forward``
    (seqused_* before max_seqlen_*), which differs from this wrapper's own
    parameter order.
    """
    return FlashAttnVarlenFunc.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k,
        max_seqlen_q, max_seqlen_k, page_table,
        softmax_scale, causal, window_size, learnable_sink, softcap,
        num_splits, pack_gqa, deterministic,
        score_mod, aux_tensors, return_lse,
    )
1582
+
1583
+
1584
def _flash_attn_fwd_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
    lse: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
    semaphore_to_reset: Optional[torch.Tensor] = None,
) -> None:
    """Forward combine kernel for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs. Results are written into
    ``out`` (and ``lse`` if given) in place.

    Args:
        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or
            (num_splits, total_q, nheads, headdim) if there's cu_seqlens
        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or
            (num_splits, total_q, nheads) if there's cu_seqlens
        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens
        lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        num_splits_dynamic_ptr: Dynamic number of splits per batch
        semaphore_to_reset: Semaphore for synchronization

    Returns:
        None
    """
    # Input validation
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
        "out_partial must be fp16, bf16, or fp32"
    )
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
    assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
    assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
    assert lse_partial.shape == out_partial.shape[:-1]

    # Determine if this is variable length based on dimensions:
    # 4-D out_partial means (num_splits, total_q, nheads, headdim).
    is_varlen = out_partial.dim() == 4

    # Validate output tensor shapes and types (outputs drop the leading num_splits dim).
    assert out.shape == out_partial.shape[1:], "out shape mismatch"
    if lse is not None:
        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
        assert lse.dtype == torch.float32, "lse must be fp32"

    # Validate optional int32 bookkeeping tensors
    for t, name in [
        (cu_seqlens, "cu_seqlens"),
        (seqused, "seqused"),
        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    ]:
        if t is not None:
            assert t.dtype == torch.int32, f"{name} must be int32"
            assert t.is_cuda, f"{name} must be on CUDA device"
            assert t.is_contiguous(), f"{name} must be contiguous"

    head_dim = out_partial.shape[-1]
    num_splits = out_partial.shape[0]
    assert num_splits <= 256
    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    # so that kBlockM is smaller and we have more parallelism.
    # NOTE(review): the code below picks 64/128 by threshold only and does not
    # perform the 96->128 / 192->256 rounding described above — confirm intent.
    k_block_size = 64 if head_dim <= 64 else 128
    # We want kBlockM to be as small as possible to maximize parallelism.
    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
    if m_block_size == 8:
        # If kBlockM == 8 then the minimum number of splits is 32.
        # TODO: we can deal w this by using 128 threads instead
        log_max_splits = max(log_max_splits, 5)

    # Launch on PyTorch's current CUDA stream.
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Create combine kernel configuration
    dtype = torch2cute_dtype_map[out.dtype]
    dtype_partial = torch2cute_dtype_map[out_partial.dtype]

    # Cache key: every parameter that affects codegen of the combine kernel.
    compile_key = (
        dtype,
        dtype_partial,
        head_dim,
        m_block_size,
        k_block_size,
        log_max_splits,
        cu_seqlens is not None,
        seqused is not None,
        lse is not None,
    )

    if compile_key not in _flash_attn_fwd_combine.compile_cache:
        # Wrap torch tensors as CuTe tensors; leading_dim differs for varlen
        # because the batch dimension is folded into total_q.
        out_partial_tensor = to_cute_tensor(
            out_partial, leading_dim=4 if not is_varlen else 3
        )
        lse_partial_tensor = to_cute_tensor(
            lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
        )
        out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
        lse_tensor = (
            to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
            if lse is not None
            else None
        )

        optional_tensors = [
            to_cute_tensor(t, assumed_align=4, leading_dim=0)
            if t is not None
            else None
            for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
        ]
        cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
            optional_tensors
        )
        fa_combine = FlashAttentionForwardCombine(
            dtype=dtype,
            dtype_partial=dtype_partial,
            head_dim=head_dim,
            m_block_size=m_block_size,
            k_block_size=k_block_size,
            log_max_splits=log_max_splits,
        )

        # Check if implementation is supported before paying compile cost.
        if not fa_combine.can_implement(
            dtype,
            dtype_partial,
            head_dim,
            m_block_size,
            k_block_size,
            log_max_splits,
            num_threads=256,
        ):
            raise RuntimeError(
                "FlashAttention combine kernel cannot be implemented with given parameters"
            )

        _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
            fa_combine,
            out_partial_tensor,
            lse_partial_tensor,
            out_tensor,
            lse_tensor,
            cu_seqlens_tensor,
            seqused_tensor,
            num_splits_dynamic_tensor,
            semaphore_tensor,
            current_stream,
            options="--enable-tvm-ffi",
        )
    # Skip the actual launch under fake-tensor tracing (e.g. torch.compile).
    if not is_fake_mode():
        _flash_attn_fwd_combine.compile_cache[compile_key](
            out_partial,
            lse_partial,
            out,
            lse,
            cu_seqlens,
            seqused,
            num_splits_dynamic_ptr,
            semaphore_to_reset,
            current_stream,
        )
1751
+
1752
+
1753
# Persistent JIT-compile cache for the forward combine kernel.
_flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine")
1754
+
1755
+
1756
def flash_attn_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    return_lse: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Combine partial split-attention results into the final output.

    Merges per-split partial outputs and log-sum-exp values (as produced by
    split attention, first dimension = num_splits) into a single output and,
    optionally, a combined LSE tensor. Permuting from user format to kernel
    format happens inside the kernel.

    Args:
        out_partial: fp32 partial outputs, either
            (num_splits, batch_size, seqlen, num_heads, head_size) or, for
            variable-length input, (num_splits, total_q, num_heads, head_size).
        lse_partial: fp32 partial LSE, (num_splits, batch_size, seqlen, num_heads)
            or (num_splits, total_q, num_heads) for varlen.
        out: Optional pre-allocated output tensor; allocated here if None.
        out_dtype: Optional output dtype; defaults to out_partial's dtype.
        cu_seqlens: Cumulative sequence lengths for variable-length sequences.
        seqused: Used sequence lengths per batch element.
        return_lse: When False, no combined LSE is produced and None is returned
            in its place.

    Returns:
        (out, lse): out has shape (batch_size, seqlen, num_heads, head_size)
        or (total_q, num_heads, head_size) for varlen; lse has shape
        (batch_size, seqlen, num_heads) or (total_q, num_heads), or None when
        return_lse is False.
    """
    # Validate input ranks/dtypes before touching shapes.
    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
    assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"

    # A 4-D out_partial means varlen packing (no explicit batch dimension).
    varlen = out_partial.dim() == 4
    if varlen:
        num_splits, total_q, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, total_q, num_heads), (
            "lse_partial shape mismatch for varlen"
        )
        out_shape = (total_q, num_heads, head_size)
        # LSE is allocated head-major then transposed so seqlen is the leading
        # user-visible dim while memory layout stays head-major.
        lse_alloc_shape = (num_heads, total_q)
        lse_swap = (0, 1)
    else:
        num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
        assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
            "lse_partial shape mismatch"
        )
        out_shape = (batch_size, seqlen, num_heads, head_size)
        lse_alloc_shape = (batch_size, num_heads, seqlen)
        lse_swap = (1, 2)

    target_dtype = out_partial.dtype if out_dtype is None else out_dtype
    device = out_partial.device

    if out is None:
        out = torch.empty(out_shape, dtype=target_dtype, device=device)

    lse = None
    if return_lse:
        lse = torch.empty(lse_alloc_shape, dtype=torch.float32, device=device).transpose(*lse_swap)

    _flash_attn_fwd_combine(
        out_partial,
        lse_partial,
        out,
        lse,
        cu_seqlens,
        seqused,
    )
    return out, lse
build/torch-cuda/mask.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Optional, Callable
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Float32, Int32, const_expr
9
+
10
+ from .quack import layout_utils
11
+ from . import utils
12
+ from .seqlen_info import SeqlenInfoQK
13
+
14
+
15
@cute.jit
def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
    """Set entries of X whose (arch-transformed) column index is >= col_limit to -inf.

    The predicate is materialized as a bitmask per group of 24 columns so the
    compiler can lower the per-element test to a single R2P instruction.
    ``rank1`` selects whether X is treated as a flat fragment or as (rows, cols).
    """
    # Bit manipulation, compiles down to the R2P instruction
    # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using.
    # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ...,
    # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    if const_expr(arch == 90):
        col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2)
    else:
        col_limit_transformed = col_limit
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Don't need to clamp to 32 since the shr.u32 instruction does that already
        col_limit_right_s = max(col_limit_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << col_limit_right_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            in_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            if const_expr(rank1):
                X[c] = X[c] if in_bound else -Float32.inf
                # This is the equivalent of:
                # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf
            else:
                # Apply the same column predicate to every row of the fragment.
                for r in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[r, c] = X[r, c] if in_bound else -Float32.inf
43
+
44
+
45
@cute.jit
def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None:
    """Set entries of X whose (transformed) row index is < row_limit_top to -inf.

    Transposed counterpart of ``mask_r2p``: bits below the limit are set in the
    mask and the corresponding elements are masked OUT (top-masking), again
    shaped so the compiler emits R2P.
    """
    # Bit manipulation, compiles down to the R2P instruction
    # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127
    # or 0, 1, ..., 15, 32, ..., 47, 64, ...
    # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
    # Here we hardcode for the case of 2 warp groups.
    num_wg = 2
    row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min(
        row_limit_top % (num_rep * num_wg), num_rep
    )
    ncol = cute.size(X.shape)
    # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
        # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
        mask = (1 << row_limit_top_s) - 1
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            # NOTE: here a set bit means "outside the valid region" (index below the top limit).
            out_bound = cutlass.Boolean(mask & (1 << i))
            c = s * 24 + i
            X[c] = -Float32.inf if out_bound else X[c]
            # tidx = cute.arch.thread_idx()[0] % 256
            # if tidx == 128:
            #     cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
70
+
71
+
72
@cute.jit
def mask_r2p_dual_bound(
    X: cute.Tensor,
    col_limit_left: Int32,  # Inclusive lower bound
    col_limit_right: Int32,  # Exclusive upper bound
) -> None:
    """
    Dual-bound masking using two bitmasks for SM100, following mask_r2p.
    Masks elements where: NOT (col_limit_left <= col < col_limit_right)

    Uses bit manipulation to create a range mask:
        mask_right = (1 << right) - 1        -> bits (right-1)..0 are 1
        mask_left  = (1 << left) - 1         -> bits (left-1)..0 are 1
        mask_range = mask_right & ~mask_left -> bits (right-1)..left are 1
    """
    ncol = const_expr(cute.size(X.shape))

    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        right_s = max(col_limit_right - s * 24, 0)
        left_s = max(col_limit_left - s * 24, 0)

        # otherwise cute dsl complains about python int too large to convert into c long
        right_s = min(right_s, 24)
        left_s = min(left_s, 24)

        # bits (right-1)..left are 1
        mask_right = (1 << right_s) - 1
        mask_left = (1 << left_s) - 1
        mask_range = mask_right & ~mask_left

        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            in_bound = cutlass.Boolean(mask_range & (1 << i))
            c = s * 24 + i
            X[c] = X[c] if in_bound else -Float32.inf
107
+
108
+
109
@dataclass(frozen=True)
class AttentionMask:
    """Applies attention masking (seqlen bounds, causal, local/sliding-window,
    or a FlexAttention-style ``mask_mod``) to an accumulator tile of scores by
    writing -inf into masked-out entries.

    Masked entries become -inf so the subsequent softmax assigns them zero
    probability. ``swap_AB`` flips which coordinate component of the identity
    tensor corresponds to the Q (row) vs KV (column) axis.
    """

    # Tile extents along the Q (m) and KV (n) axes.
    tile_m: cutlass.Constexpr[int]
    tile_n: cutlass.Constexpr[int]
    # Source of the effective seqlen_q / seqlen_k values.
    seqlen_info: SeqlenInfoQK
    # Sliding-window sizes; None disables the corresponding bound.
    window_size_left: Optional[Int32] = None
    window_size_right: Optional[Int32] = None
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # only pass in if we're doing PackGQA
    swap_AB: cutlass.Constexpr[bool] = False

    @property
    def seqlen_q(self) -> Int32:
        # Effective query sequence length, forwarded from seqlen_info.
        return self.seqlen_info.seqlen_q

    @property
    def seqlen_k(self) -> Int32:
        # Effective key/value sequence length, forwarded from seqlen_info.
        return self.seqlen_info.seqlen_k

    @cute.jit
    def apply_mask(
        self,
        acc_S: cute.Tensor,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        thr_mma: cute.TiledMma,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
    ) -> None:
        """Mask the accumulator tile ``acc_S`` in place (sm90-style path).

        Exactly one of the mask families is selected at compile time:
        seqlen-only, ``mask_mod`` (FlexAttention), or causal/local. Masked
        entries are overwritten with -Float32.inf.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)
        # We use t0ScS as these indices are known at compile time. We then must subtract the
        # column limit by the thread column offset.
        t0ScS_mn = layout_utils.reshape_acc_to_mn(
            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
        )
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                # The compiler now choses not to use R2P
                r2p = const_expr(False and not self.swap_AB)
                if const_expr(not r2p):
                    # traverse column index.
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
                else:
                    mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)

        elif const_expr(
            not mask_causal and not mask_local and mask_mod is not None
        ):  # FlexAttention mask mod
            nrow = const_expr(cute.size(tScS_mn.shape[0]))
            ncol = const_expr(cute.size(tScS_mn.shape[1]))
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            wrap_aux_indices = const_expr(
                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
            )

            for r in cutlass.range_constexpr(nrow):
                # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
                local_row = tScS_mn[r, 0][ROW]
                global_row_idx = local_row + m_block * self.tile_m
                row_for_mod = global_row_idx
                head_idx_for_mod = head_idx
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: decode packed row into (logical row, head offset).
                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
                row_for_seqlen = row_for_mod
                if const_expr(wrap_aux_indices):
                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

                for col in cutlass.range_constexpr(ncol):
                    col_idx_local = t0ScS_mn[0, col][COL]
                    # Convert to absolute column index
                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
                    col_for_mod = global_col_idx
                    if const_expr(wrap_aux_indices):
                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

                    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    if const_expr(mask_seqlen):
                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
                            global_col_idx >= self.seqlen_k
                        )
                        if out_of_bounds:
                            acc_S_mn[r, col] = -cutlass.Float32.inf
                        else:
                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
                    else:
                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf

        else:  # Causal or local
            if const_expr(not self.swap_AB):
                # If PackGQA, we split the work of compute divmod among threads in the same row
                threads_per_row = thr_mma.tv_layout_C.shape[0][0]
                mma_m_idx = None
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (
                        "threads_per_row must divide WARP_SIZE"
                    )
                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
                    tidx = thr_mma.thr_idx
                    mma_m_idx = (
                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
                    ) // self.qhead_per_kvhead_packgqa
                causal_row_offset = (
                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
                )
                if const_expr(mask_causal):
                    r2p = const_expr(not self.swap_AB)  # R2P trick, see apply_mask_sm100
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        # get the column index limit based on current row. Only consider the row index, so the column index sets to 0.
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            # Fetch the precomputed divmod result from the owning lane.
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        col_limit_right = row_idx + causal_row_offset
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        if const_expr(not r2p):
                            # traverse column index.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if t0ScS_mn[0, c][1] >= col_limit_right
                                    else acc_S_mn[r, c]
                                )
                        else:
                            mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True)
                else:  # Local
                    local_row_offset_right = (
                        causal_row_offset + self.window_size_right
                        if const_expr(self.window_size_right is not None)
                        else None
                    )
                    local_row_offset_left = (
                        causal_row_offset - 1 - self.window_size_left
                        if const_expr(self.window_size_left is not None)
                        else None
                    )
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        if const_expr(self.window_size_right is not None):
                            col_limit_right = row_idx + local_row_offset_right
                        else:
                            col_limit_right = self.tile_n
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        col_limit_left = (
                            row_idx + local_row_offset_left
                            if const_expr(self.window_size_left is not None)
                            else 0
                        )
                        # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                        # traverse column index.
                        for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                            col_idx = t0ScS_mn[0, c][1]
                            # only consider the column index, so the row index sets to 0.
                            if col_idx >= col_limit_right or col_idx < col_limit_left:
                                acc_S_mn[r, c] = -Float32.inf
            else:  # swap_AB
                assert self.qhead_per_kvhead_packgqa == 1
                thr_row_offset = tScS_mn[0][ROW]
                causal_row_offset = (
                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
                )
                if const_expr(mask_causal):
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else col0 - causal_row_offset
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if t0ScS_mn[r, 0][ROW] < row_limit_top
                                else acc_S_mn[r, c]
                            )
                else:
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit
                            else col0 - causal_row_offset - self.window_size_right
                        )
                        # TODO: do we need col_limit_sink?
                        row_limit_bot = col0 - causal_row_offset + self.window_size_left
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            row_idx = t0ScS_mn[r, 0][ROW]
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if row_idx < row_limit_top or row_idx > row_limit_bot
                                else acc_S_mn[r, c]
                            )

    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        thr_mma: cute.TiledMma,
        thr_tmem_load: cute.TiledCopy,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        head_divmod=None,
        check_q_boundary: bool = False,
    ) -> None:
        """Mask the accumulator fragment ``acc_S`` in place (sm100 tmem-load path).

        Same mask families as ``apply_mask``, but the fragment is a flat tensor
        whose per-element coordinates come from the tmem-load partitioning
        (``tScS_t2r``); the R2P bitmask helpers are used where possible.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS = tScS[(None, None), 0, 0]
        tScS_t2r = thr_tmem_load.partition_D(tScS)
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        r2p = True
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        # if tScS_t2r[i][1] >= seqlenk_col_limit:
                        #     acc_S[i] = -Float32.inf
                        # For some reason the 2 lines above generate really bad SASS
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
                else:
                    mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)

        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case w/ mask_mod
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)

            ncol = const_expr(cute.size(tScS_t2r.shape))
            for i in cutlass.range_constexpr(ncol):
                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
                global_row = row_coord + m_block * self.tile_m
                global_col = col_coord + n_block * self.tile_n

                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: split packed row into (logical row, head offset).
                    assert head_divmod is not None
                    mask_row, head_offset = divmod(global_row, head_divmod)
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                else:
                    head_idx_for_mod = head_idx
                    mask_row = global_row

                mask_row_for_mod = mask_row
                if const_expr(has_fastdiv and aux_tensors is not None):
                    if check_q_boundary:
                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
                global_col_for_mod = global_col
                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
                mask_value = mask_mod(
                    batch_idx_ssa,
                    head_idx_ssa,
                    mask_row_ssa,
                    kv_idx_ssa,
                    self.seqlen_info,
                    aux_tensors,
                )
                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                acc_S[i] = acc_S[i] if cond else -Float32.inf
                if const_expr(mask_seqlen):
                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
                    if check_q_boundary:
                        acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

        else:  # Causal or local
            causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
            row_idx = tScS_t2r[0][0] + m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa != 1):
                row_idx = row_idx // self.qhead_per_kvhead_packgqa
            if const_expr(mask_causal):
                col_limit_right = row_idx + causal_row_offset
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                # if cute.arch.thread_idx()[0] % 32 == 0:
                #     cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)
                ncol = const_expr(cute.size(tScS_t2r.shape))
                if const_expr(not r2p):
                    for i in cutlass.range(ncol, unroll_full=True):
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
                else:
                    mask_r2p(acc_S, col_limit_right, arch=100, rank1=True)
            else:
                local_row_offset_right = (
                    causal_row_offset + self.window_size_right
                    if const_expr(self.window_size_right is not None)
                    else None
                )
                local_row_offset_left = (
                    causal_row_offset - 1 - self.window_size_left
                    if const_expr(self.window_size_left is not None)
                    else None
                )
                if const_expr(self.window_size_right is not None):
                    col_limit_right = row_idx + local_row_offset_right
                else:
                    col_limit_right = self.tile_n
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                col_limit_left = (
                    row_idx + local_row_offset_left
                    if const_expr(self.window_size_left is not None)
                    else 0
                )
                if const_expr(not r2p):
                    # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        col_idx = tScS_t2r[i][1]
                        acc_S[i] = (
                            -Float32.inf
                            if col_idx >= col_limit_right or col_idx < col_limit_left
                            else acc_S[i]
                        )
                else:
                    # XOR-based R2P dual bound masking
                    mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)

    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        mask_local: cutlass.Constexpr,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        is_full_block: bool = False,
        check_m_boundary: bool = True,
    ) -> None:
        """
        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.

        Coordinate convention:
            - ROW corresponds to Q (m_block)
            - COL corresponds to KV (n_block)

        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
            When iterating m_blocks in forward order, only the last m_block may be partial.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        # assert t0ScS_t2r[0][COL] == 0, "col0 == 0" # tmp comment for 2-cta bwd
        thr_col_offset = tScS_t2r[0][COL]
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset

        if const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case with mask_mod (backward)
            #
            # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).
            # These already account for swap_AB.
            #
            # FULL blocks: mask_mod returns True for all elements, so skip it.
            #   Still need seqlen bounds check (elements may be OOB on last m_block).
            # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
            if is_full_block:
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        # Entire tile is OOB for K
                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                            acc_S[i] = -cutlass.Float32.inf
                    elif check_m_boundary:
                        # Last m_block: check Q and K boundaries
                        ncol = const_expr(cute.size(tScS_t2r.shape))
                        for i in cutlass.range_constexpr(ncol):
                            row_coord = tScS_t2r[i][ROW]
                            col_coord = tScS_t2r[i][COL]
                            global_q = row_coord + m_block * self.tile_m
                            global_kv = col_coord + n_block * self.tile_n
                            q_out_of_bounds = global_q >= self.seqlen_q
                            kv_out_of_bounds = global_kv >= self.seqlen_k
                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
            else:
                # Partial block
                has_fastdiv = const_expr(
                    fastdiv_mods is not None
                    and fastdiv_mods[0] is not None
                    and fastdiv_mods[1] is not None
                )
                wrap_aux_indices = const_expr(
                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
                )
                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)

                ncol = const_expr(cute.size(tScS_t2r.shape))
                for i in cutlass.range_constexpr(ncol):
                    row_coord = tScS_t2r[i][ROW]
                    col_coord = tScS_t2r[i][COL]
                    global_q = row_coord + m_block * self.tile_m
                    global_kv = col_coord + n_block * self.tile_n

                    q_idx_for_mod = global_q
                    kv_idx_for_mod = global_kv
                    if const_expr(wrap_aux_indices):
                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])

                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)

                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf

                    if const_expr(mask_seqlen):
                        # check_m_boundary=False skips q check for non-boundary m_blocks
                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
                        kv_out_of_bounds = global_kv >= self.seqlen_k
                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]

        elif const_expr(not mask_causal and not mask_local):
            if const_expr(mask_seqlen):
                if seqlenk_col_limit <= 0:
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = -cutlass.Float32.inf
        else:  # Causal or local
            thr_row_offset = tScS_t2r[0][ROW]
            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
            causal_offset = seqlenq_row_limit - seqlenk_col_limit
            if const_expr(mask_causal):
                # tidx = cute.arch.thread_idx()[0] % 256
                # if tidx < 32:
                #     cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
                row_limit_top = causal_offset
                if const_expr(mask_seqlen):
                    # If col is beyond the column limit, we want to mask out the entire
                    # column, by setting row limit to be self.tile_m.
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = (
                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
                        )
                else:
                    num_rep = cute.size(tScS_t2r, mode=[0])  # 16 or 32
                    mask_r2p_transposed(acc_S, row_limit_top, num_rep)
            else:
                if const_expr(self.window_size_right is not None):
                    row_limit_top = causal_offset - self.window_size_right
                else:
                    row_limit_top = 0
                if const_expr(self.window_size_left is not None):
                    row_limit_bot = causal_offset + self.window_size_left
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                    row_idx = t0ScS_t2r[i][ROW]
                    local_mask = row_idx < row_limit_top
                    if const_expr(self.window_size_left is not None):
                        local_mask |= row_idx > row_limit_bot
                    acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 0,
3
+ "python-depends": [
4
+ "einops",
5
+ "tvm-ffi",
6
+ "nvidia-cutlass-dsl"
7
+ ]
8
+ }
build/torch-cuda/mma_sm100_desc.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+ # Ported Cutlass code from C++ to Python:
3
+ # https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp
4
+ # https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp
5
+
6
+ from enum import IntEnum
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Enumerations that match the HW encodings (values MUST stay identical)
13
+ # ---------------------------------------------------------------------------
14
+
15
+
16
class Major(IntEnum):
    """Operand major-ness ("layout" in the ISA docs); 1-bit HW encoding."""

    K = 0
    MN = 1


class ScaleIn(IntEnum):
    """Operand negate flags; 1-bit HW encoding."""

    One = 0
    Neg = 1


class Saturate(IntEnum):
    """Accumulator saturation flag; 1-bit HW encoding."""

    False_ = 0
    True_ = 1


class CFormat(IntEnum):
    """Accumulator element format; 2-bit field (descriptor bits 4-5)."""

    F16 = 0
    F32 = 1
    S32 = 2


class F16F32Format(IntEnum):
    """A/B element format for the f16/f32 MMA family; 3-bit field."""

    F16 = 0
    BF16 = 1
    TF32 = 2


class S8Format(IntEnum):
    """A/B element format for the 8-bit integer MMA family."""

    UINT8 = 0
    INT8 = 1


class MXF8F6F4Format(IntEnum):
    """A/B element format for the FP8/FP6/FP4 MMA family.

    NOTE: value 2 is intentionally skipped — these are the raw HW encodings.
    """

    E4M3 = 0
    E5M2 = 1
    E2M3 = 3
    E3M2 = 4
    E2M1 = 5


class MaxShift(IntEnum):
    """Max-shift field encoding (descriptor bits 30-31)."""

    NoShift = 0
    MaxShift8 = 1
    MaxShift16 = 2
    MaxShift32 = 3
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # CUTLASS-type → encoding helpers
65
+ # ---------------------------------------------------------------------------
66
+
67
+
68
def to_UMMA_format(cutlass_type) -> int:
    """
    Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B.

    Comparison is by class identity (`is`), so only the exact CUTLASS scalar
    classes are accepted.

    Raises:
        TypeError: if the scalar type has no UMMA A/B encoding here.
    """
    if cutlass_type is cutlass.Int8:
        return S8Format.INT8
    # Unsigned 8-bit (if available in your CUTLASS build)
    if cutlass_type is cutlass.Uint8:
        return S8Format.UINT8
    # FP-16 / BF-16
    if cutlass_type is cutlass.Float16:
        return F16F32Format.F16
    if cutlass_type is cutlass.BFloat16:
        return F16F32Format.BF16
    # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits)
    if cutlass_type is cutlass.TFloat32:
        return F16F32Format.TF32
    # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them.
    # NOTE(review): only E4M3/E5M2 are mapped; E2M3/E3M2/E2M1 exist in
    # MXF8F6F4Format but have no CUTLASS class wired up here yet.
    if cutlass_type is cutlass.FloatE4M3FN:
        return MXF8F6F4Format.E4M3
    if cutlass_type is cutlass.FloatE5M2:
        return MXF8F6F4Format.E5M2
    raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}")
91
+
92
+
93
def to_C_format(cutlass_type) -> int:
    """
    Map a CUTLASS scalar class to the 2-bit accumulator encoding.

    Classes hash by identity, so the dict lookup below is equivalent to a
    chain of `is` comparisons against the supported scalar classes.
    """
    acc_encoding = {
        cutlass.Float16: CFormat.F16,
        cutlass.Float32: CFormat.F32,
        cutlass.Int32: CFormat.S32,
    }
    try:
        return acc_encoding[cutlass_type]
    except KeyError:
        raise TypeError(
            f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}"
        ) from None
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # The constructor – accepts only CUTLASS scalar classes
108
+ # ---------------------------------------------------------------------------
109
+
110
+
111
def make_instr_desc(
    a_type,  # CUTLASS scalar class, e.g. cutlass.Int8
    b_type,
    c_type,
    M: int,  # 64, 128 or 256
    N: int,  # 8 … 256 (multiple of 8)
    a_major: Major,
    b_major: Major,
    a_neg: ScaleIn = ScaleIn.One,
    b_neg: ScaleIn = ScaleIn.One,
    c_sat: Saturate = Saturate.False_,
    is_sparse: bool = False,
    max_shift: MaxShift = MaxShift.NoShift,
) -> int:
    """
    Build the 32-bit instruction descriptor for Blackwell MMA.
    All matrix/accumulator **types must be CUTLASS scalar classes** –
    passing integers is forbidden.

    Bit layout packed below (LSB first): sparse_id2 [0:2), sparse_flag [2],
    saturate [3], c_format [4:6), a_format [7:10), b_format [10:13),
    a_negate [13], b_negate [14], a_major [15], b_major [16], n_dim [17:23),
    m_dim [24:29), max_shift [30:32).

    Raises:
        TypeError: if a type is not a supported CUTLASS scalar class.
        ValueError: if M or N is outside the supported range.
    """
    # --- encode element formats -------------------------------------------------
    a_fmt = int(to_UMMA_format(a_type))
    b_fmt = int(to_UMMA_format(b_type))
    c_fmt = int(to_C_format(c_type))

    # --- range checks on M/N -----------------------------------------------------
    if M not in (64, 128, 256):
        raise ValueError("M must be 64, 128 or 256")
    if N < 8 or N > 256 or (N & 7):
        raise ValueError("N must be a multiple of 8 in the range 8…256")

    # HW stores the tile dims scaled down: M in units of 16, N in units of 8.
    m_dim = M >> 4  # 5-bit field
    n_dim = N >> 3  # 6-bit field

    # fmt: off
    # --- pack the bit-fields -----------------------------------------------------
    desc = 0
    desc |= (0 & 0x3) << 0                  # sparse_id2 (always 0 here)
    desc |= (int(is_sparse) & 0x1) << 2     # sparse_flag
    desc |= (int(c_sat) & 0x1) << 3         # saturate
    desc |= (c_fmt & 0x3) << 4              # c_format
    desc |= (a_fmt & 0x7) << 7              # a_format (bit 6 is padding)
    desc |= (b_fmt & 0x7) << 10             # b_format
    desc |= (int(a_neg) & 0x1) << 13        # a_negate
    desc |= (int(b_neg) & 0x1) << 14        # b_negate
    desc |= (int(a_major) & 0x1) << 15      # a_major
    desc |= (int(b_major) & 0x1) << 16      # b_major
    desc |= (n_dim & 0x3F) << 17            # n_dim (6 bits)
    desc |= (m_dim & 0x1F) << 24            # m_dim (5 bits, bit 23 is padding)
    desc |= (int(max_shift) & 0x3) << 30    # max_shift (2 bits)
    # fmt: on

    return desc & 0xFFFF_FFFF  # ensure 32-bit result
163
+
164
+
165
def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):
    """Build the 32-bit instruction descriptor from a tcgen05 MmaOp."""
    major_mode = cute.nvgpu.tcgen05.mma.OperandMajorMode
    a_major = Major.K if op.a_major_mode == major_mode.K else Major.MN
    b_major = Major.K if op.b_major_mode == major_mode.K else Major.MN
    m_size, n_size = op.shape_mnk[0], op.shape_mnk[1]
    return make_instr_desc(
        op.a_dtype,
        op.b_dtype,
        op.acc_dtype,
        m_size,
        n_size,
        a_major,
        b_major,
    )
175
+
176
+
177
class LayoutType(IntEnum):
    """Shared-memory swizzle family; occupies the top-3 descriptor bits [61:64)."""

    SWIZZLE_NONE = 0  # (a.k.a. "INTERLEAVE" in older docs)
    SWIZZLE_128B_BASE32B = 1
    SWIZZLE_128B = 2
    SWIZZLE_64B = 4
    SWIZZLE_32B = 6
    # values 3,5,7 are reserved / illegal for UMMA
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Helpers – figure out the SWIZZLE_* family from the tensor layout
188
+ # ---------------------------------------------------------------------------
189
+
190
+
191
def _layout_type(swizzle: cute.Swizzle) -> LayoutType:
    """Classify a cute.Swizzle triple into the UMMA SWIZZLE_* family."""
    bits, base, shift = swizzle.num_bits, swizzle.num_base, swizzle.num_shift

    if base == 4:  # Swizzle<B,4,3>: the bit count selects the family
        if shift != 3:
            raise ValueError("Unexpected swizzle shift – want S==3 for M==4")
        families = (
            LayoutType.SWIZZLE_NONE,
            LayoutType.SWIZZLE_32B,
            LayoutType.SWIZZLE_64B,
            LayoutType.SWIZZLE_128B,
        )
        if not 0 <= bits < len(families):
            # Same exception as the original dict lookup on an invalid B.
            raise KeyError(bits)
        return families[bits]

    if base == 5:  # Swizzle<2,5,2> (the only legal triple for M==5)
        if (bits, shift) != (2, 2):
            raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B")
        return LayoutType.SWIZZLE_128B_BASE32B

    # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout
    raise ValueError("Unsupported swizzle triple for UMMA smem descriptor")
210
+
211
+
212
def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int:
    """
    Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit
    smem-descriptor, without the smem start address.
    layout must correspond to layout of an uint128 tensor.

    The layout is first divided by the swizzle-atom tile; the divided layout
    must be "canonical" (congruent to ((1,1),(1,1)) with the expected strides),
    otherwise a ValueError is raised.
    """
    # ------------------------------------------------------------------ meta
    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family

    VERSION = 1  # bits 46–47
    LBO_MODE = 0  # bit 52
    BASE_OFFSET = 0  # bits 49–51 (CUTLASS always 0)

    # ---------------------------------------------------------- strides (units: uint128_t = 16 B)
    # MN extent (in uint128 units) of one swizzle atom for each family.
    swizzle_atom_mn_size = {
        LayoutType.SWIZZLE_NONE: 1,
        LayoutType.SWIZZLE_32B: 2,
        LayoutType.SWIZZLE_64B: 4,
        LayoutType.SWIZZLE_128B: 8,
        LayoutType.SWIZZLE_128B_BASE32B: 8,
    }[layout_type]

    if major is Major.MN:
        swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8
        # Split layout into (atom, rest) along both modes and validate shape/strides.
        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if stride_10 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.")
        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]
        # LBO/SBO roles swap between the swizzled and non-swizzled cases.
        if layout_type is LayoutType.SWIZZLE_NONE:
            stride_byte_offset, leading_byte_offset = stride_01, stride_11
        else:
            stride_byte_offset, leading_byte_offset = stride_11, stride_01
    else:
        if layout_type == LayoutType.SWIZZLE_128B_BASE32B:
            raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K")
        if not cute.size(layout.shape[0]) % 8 == 0:
            raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.")
        canonical_layout = cute.logical_divide(layout, (8, 2))
        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):
            raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.")
        stride_00 = canonical_layout.stride[0][0]
        if stride_00 != swizzle_atom_mn_size:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_10 = canonical_layout.stride[1][0]
        if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:
            raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.")
        stride_01 = canonical_layout.stride[0][1]
        stride_byte_offset, leading_byte_offset = stride_01, stride_10

    # ------------------------------------------------------------------ pack
    desc = 0
    # leading_byte_offset_ [16:30)
    desc |= (leading_byte_offset & 0x3FFF) << 16
    # stride_byte_offset_ [32:46)
    desc |= (stride_byte_offset & 0x3FFF) << 32
    # version_ [46:48)
    desc |= (VERSION & 0x3) << 46
    # base_offset_ [49:52)
    desc |= (BASE_OFFSET & 0x7) << 49
    # lbo_mode_ [52:53)
    desc |= (LBO_MODE & 0x1) << 52
    # layout_type_ [61:64)
    desc |= (int(layout_type) & 0x7) << 61

    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width
283
+
284
+
285
def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:
    """Encode an smem pointer as the descriptor's 14-bit start-address field."""
    # Keep address bits [4, 18) then drop the 4 LSBs (16-byte granularity),
    # producing the 14-bit value stored at descriptor bits [0, 14).
    raw_addr = start_addr.toint()
    return (raw_addr & 0x3FFFF) >> 4
288
+
289
+
290
def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:
    """Build the smem-descriptor base (without the start address) for tensor sA."""
    swizzle = sA.iterator.type.swizzle_type
    # Recast the first layout mode into uint128 units, which is what
    # make_smem_desc_base expects.
    layout_u128 = cute.recast_layout(128, sA.element_type.width, sA.layout[0])
    return make_smem_desc_base(layout_u128, swizzle, major)
build/torch-cuda/named_barrier.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+
3
+ import enum
4
+
5
+
6
class NamedBarrierFwd(enum.IntEnum):
    """Named-barrier IDs for the forward kernel.

    `enum.auto()` makes IDs start at 1; barrier 0 is reserved for
    sync_threads(). Member order therefore defines the barrier IDs.
    """

    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()
    WarpSchedulerWG1 = enum.auto()
    WarpSchedulerWG2 = enum.auto()
    WarpSchedulerWG3 = enum.auto()
    PFull = enum.auto()
    PEmpty = enum.auto()


class NamedBarrierBwd(enum.IntEnum):
    """Named-barrier IDs for the backward kernel (IDs again start at 1)."""

    Epilogue = enum.auto()
    WarpSchedulerWG1 = enum.auto()
    WarpSchedulerWG2 = enum.auto()
    WarpSchedulerWG3 = enum.auto()
    PdS = enum.auto()
    dQFullWG0 = enum.auto()
    dQFullWG1 = enum.auto()
    dQEmptyWG0 = enum.auto()
    dQEmptyWG1 = enum.auto()


class NamedBarrierBwdSm100(enum.IntEnum):
    """Named-barrier IDs for the SM100 (Blackwell) backward kernel."""

    EpilogueWG1 = enum.auto()
    EpilogueWG2 = enum.auto()
    Compute = enum.auto()
    dQaccReduce = enum.auto()
    TmemPtr = enum.auto()
+ TmemPtr = enum.auto()
build/torch-cuda/pack_gqa.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+ from .quack import layout_utils
8
+ from . import utils
9
+
10
+
11
class PackGQA:
    """Pack-GQA helpers: treat the (qhead_per_kvhead, seqlen_q) pair as one
    flattened row dimension so that all query heads sharing a KV head are
    handled inside a single tile of rows.

    A flattened row index `idx` decomposes as
    ``idx = m_idx * qhead_per_kvhead + h_idx`` (see compute_ptr), i.e. query
    heads vary fastest within the packed dimension.
    """

    def __init__(
        self,
        m_block_size: cutlass.Constexpr[int],
        head_dim_padded: cutlass.Constexpr[int],
        check_hdim_oob: cutlass.Constexpr[bool],
        # NOTE: was annotated Constexpr[bool]; it is used as an integer
        # divisor/multiplier throughout, so int is the correct annotation.
        qhead_per_kvhead: cutlass.Constexpr[int],
    ):
        self.m_block_size = m_block_size
        self.head_dim_padded = head_dim_padded
        self.check_hdim_oob = check_hdim_oob
        self.qhead_per_kvhead = qhead_per_kvhead

    @cute.jit
    def compute_ptr(
        self,
        tensor: cute.Tensor,
        cRows: cute.Tensor,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        threads_per_row: cutlass.Constexpr[int],
        num_threads: cutlass.Constexpr[int],
    ):
        """Precompute per-thread 64-bit gmem pointers for the packed rows this
        tile touches; one pointer per (row-group) handled by the thread's row
        lane. Consumers redistribute them across the row group via
        shuffle_sync."""
        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)
        tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64)
        for i in cutlass.range_constexpr(num_ptr_per_thread):
            row = i * num_threads + cRows[tidx % threads_per_row][0]
            idx = block * self.m_block_size + row
            # Split the flattened index into (head, sequence-position).
            m_idx = idx // self.qhead_per_kvhead
            h_idx = idx - m_idx * self.qhead_per_kvhead
            tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint()
        return tPrPtr

    @cute.jit
    def load_Q(
        self,
        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Gather packed-GQA Q rows from gmem into smem tile sQ, one row at a
        time through per-row pointers (rows of one tile are not contiguous in
        gmem once heads are packed)."""
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tQsQ = gmem_thr_copy.partition_D(sQ)
        tQcQ = gmem_thr_copy.partition_S(cQ)
        # Thread-0 partition: used for row-limit compares so the comparison is
        # against a compile-time-known base coordinate.
        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)
        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])
        tQcQ_row = tQcQ[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # Fetch the row pointer from whichever lane computed it.
            q_ptr_i64 = utils.shuffle_sync(
                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            q_gmem_ptr = cute.make_ptr(
                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            # Row-bound check against the packed sequence length.
            if (
                t0QcQ[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]
            ):
                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tQsQ.shape[0][0])
                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):
                    ki = tQcQ[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        mQ_cur_copy[None, ki],
                        tQsQ[None, m, k],
                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
        # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs

    @cute.jit
    def store_LSE(
        self,
        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)
        tLSErLSE: cute.Tensor,  # (m_block_size, head_dim_padded)
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Scatter per-row LSE values to gmem through per-row pointers; only
        the thread owning accumulator column 0 of a row writes it."""
        thr_mma = tiled_mma.get_slice(tidx)
        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        taccOcO = thr_mma.partition_C(caccO)
        taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        assert cute.size(tLSErLSE) <= threads_per_row
        num_threads = tiled_mma.size
        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):
            lse_ptr_i64 = utils.shuffle_sync(
                tPrLSEPtr[m // threads_per_row],
                m % threads_per_row,
                width=threads_per_row,
            )
            lse_gmem_ptr = cute.make_ptr(
                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
            )
            row = block * self.m_block_size + taccOcO_row[m][0]
            # Only the thread corresponding to column 0 writes out the lse to gmem
            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:
                mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,))
                mLSE_copy[0] = tLSErLSE[m]

    @cute.jit
    def store_O(
        self,
        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)
        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy
        gmem_tiled_copy: cute.TiledCopy,
        tidx: cutlass.Int32,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        """Scatter the output tile from registers to gmem; mirror image of
        load_Q (same per-row pointer + shuffle scheme, copy direction
        reversed)."""
        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tOcO = gmem_thr_copy.partition_S(cO)
        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)
        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
        tOcO_row = tOcO[0, None, 0]
        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]
        assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
        num_threads = gmem_tiled_copy.size
        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)
        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
            o_ptr_i64 = utils.shuffle_sync(
                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row
            )
            o_gmem_ptr = cute.make_ptr(
                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            if (
                t0OcO[0, m, 0][0]
                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]
            ):
                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))
                elems_per_load = cute.size(tOrO.shape[0][0])
                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))
                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):
                    ki = tOcO[0, 0, k][1] // elems_per_load
                    cute.copy(
                        gmem_thr_copy,
                        tOrO[None, m, k],
                        mO_cur_copy[None, ki],
                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
+ )
build/torch-cuda/paged_kv.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+ from dataclasses import dataclass
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass.cute.nvgpu import cpasync
7
+ from cutlass import Int32, const_expr
8
+
9
+ from . import utils
10
+ from .quack.cute_dsl_utils import ParamsBase
11
+ from cutlass.cute import FastDivmodDivisor
12
+
13
+ import math
14
+
15
+
16
+ @dataclass
17
@dataclass
class PagedKVManager(ParamsBase):
    """Paged-KV-cache loader: resolves page-table entries to gmem pointers and
    issues cp.async copies of K/V tiles into shared memory.

    Built via :meth:`create`; the dataclass fields are the precomputed,
    per-(batch, head, thread) state that the load methods consume.
    """

    # Batch-sliced page table: mPageTable[i] is the page id for logical page i.
    mPageTable: cute.Tensor
    # Head-sliced paged K/V pools (see create() for the slicing applied).
    mK_paged: cute.Tensor
    mV_paged: cute.Tensor
    thread_idx: Int32

    # Fast divmod by the page size, used to split a row index into
    # (page index, offset within page).
    page_size_divmod: FastDivmodDivisor
    seqlen_k: Int32
    leftpad_k: Int32
    n_block_size: Int32
    num_threads: cutlass.Constexpr[Int32]
    head_dim_padded: cutlass.Constexpr[Int32]
    head_dim_v_padded: cutlass.Constexpr[Int32]

    gmem_threads_per_row: cutlass.Constexpr[Int32]
    # How many page-table entries each thread caches in registers.
    page_entry_per_thread: Int32
    async_copy_elems: Int32

    gmem_tiled_copy_KV: cute.TiledCopy
    gmem_thr_copy_KV: cute.TiledCopy
    # Register caches filled by load_page_table().
    tPrPage: cute.Tensor
    tPrPageOffset: cute.Tensor
    # Head-dim OOB predicates for K and V copies.
    tKpK: cute.Tensor
    tVpV: cute.Tensor

    @staticmethod
    def create(
        mPageTable: cute.Tensor,
        mK_paged: cute.Tensor,
        mV_paged: cute.Tensor,
        page_size_divmod: FastDivmodDivisor,
        bidb: Int32,
        bidh: Int32,
        thread_idx: Int32,
        seqlen_k: Int32,
        leftpad_k: Int32,
        n_block_size: cutlass.Constexpr[Int32],
        head_dim_padded: cutlass.Constexpr[Int32],
        head_dim_v_padded: cutlass.Constexpr[Int32],
        num_threads: cutlass.Constexpr[Int32],
        dtype: Type[cutlass.Numeric],
    ):
        """Build the per-(batch, head) manager: derive the cp.async tiled-copy
        geometry, slice the page table and K/V pools to this batch/head, and
        precompute the head-dim predicates."""
        universal_copy_bits = 128  # one cp.async instruction moves 16 B
        async_copy_elems = universal_copy_bits // dtype.width
        dtype_bytes = dtype.width // 8
        gmem_k_block_size = math.gcd(
            head_dim_padded,
            head_dim_v_padded,
            128 // dtype_bytes,
        )
        assert gmem_k_block_size % async_copy_elems == 0
        gmem_threads_per_row = gmem_k_block_size // async_copy_elems
        assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        thr_layout = cute.make_ordered_layout(
            (num_threads // gmem_threads_per_row, gmem_threads_per_row),
            order=(1, 0),
        )
        val_layout = cute.make_layout((1, async_copy_elems))
        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
        page_entry_per_thread = n_block_size // num_threads

        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)

        # Slice to this batch's page table and this head's K/V pools.
        mPageTable = mPageTable[bidb, None]
        mK_paged = mK_paged[None, None, bidh, None]
        mV_paged = mV_paged[None, None, bidh, None]

        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))
        tKcK = gmem_thr_copy_KV.partition_S(cK)
        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])

        if const_expr(head_dim_padded == head_dim_v_padded):
            # Same head-dim -> the K predicate is reusable for V.
            tVpV = tKpK
        else:
            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
            tVcV = gmem_thr_copy_KV.partition_S(cV)
            # NOTE: V uses shape[0] as the limit — the V pool has the head dim
            # in mode 0 (compute_X_ptr indexes V as (0, page_offset, page)).
            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])

        return PagedKVManager(
            mPageTable,
            mK_paged,
            mV_paged,
            thread_idx,
            page_size_divmod,
            seqlen_k,
            leftpad_k,
            n_block_size,
            num_threads,
            head_dim_padded,
            head_dim_v_padded,
            gmem_threads_per_row,
            page_entry_per_thread,
            async_copy_elems,
            gmem_tiled_copy_KV,
            gmem_thr_copy_KV,
            tPrPage,
            tPrPageOffset,
            tKpK,
            tVpV,
        )

    @cute.jit
    def load_page_table(self, n_block: Int32):
        """Cache the (page id, in-page offset) pair for every row of block
        `n_block` into this thread's registers; out-of-range rows get page 0."""
        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
            row = (
                i * self.num_threads
                + (self.thread_idx % self.gmem_threads_per_row)
                * (self.num_threads // self.gmem_threads_per_row)
                + (self.thread_idx // self.gmem_threads_per_row)
            )
            row_idx = n_block * self.n_block_size + row

            page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)

            is_valid = (
                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size
            ) and row_idx < self.seqlen_k
            page = self.mPageTable[page_idx] if is_valid else 0

            self.tPrPage[i] = page
            self.tPrPageOffset[i] = page_offset

    @cute.jit
    def compute_X_ptr(self, K_or_V: str):
        """Turn cached (page, offset) pairs into 64-bit gmem row pointers for
        the K or V pool. K is indexed (seq, hdim, page); V (hdim, seq, page)."""
        tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
        for i in cutlass.range(self.page_entry_per_thread, unroll=1):
            page = self.tPrPage[i]
            page_offset = self.tPrPageOffset[i]
            if const_expr(K_or_V == "K"):
                tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint()
            else:
                tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint()
        return tPrXPtr

    @cute.jit
    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
        """cp.async the K or V tile for block `n_block` into smem tensor sX,
        predicated on sequence-length validity (and head-dim validity via the
        precomputed tKpK/tVpV — note: the `should_load` predicate here covers
        the row bound only)."""
        assert K_or_V in ("K", "V")

        tPrXPtr = self.compute_X_ptr(K_or_V)

        # Finesse sX layout to be (M, N).
        sX_pi = cute.make_tensor(
            sX.iterator,
            cute.make_layout(
                (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
                stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
            ),
        )

        if const_expr(K_or_V == "V"):
            # Need to transpose V
            sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))

        head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
        cX = cute.make_identity_tensor((self.n_block_size, head_dim))
        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
        tXcX = self.gmem_thr_copy_KV.partition_S(cX)
        tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX)

        seqlenk_row_limit = (
            self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0
        )
        for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
            row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit
            should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean)
            should_load.fill(row_valid)

            # Broadcast the row pointer from the lane that computed it.
            x_ptr_i64 = utils.shuffle_sync(
                tPrXPtr[m // self.gmem_threads_per_row],
                m % self.gmem_threads_per_row,
                width=self.gmem_threads_per_row,
            )
            x_gmem_ptr = cute.make_ptr(
                self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
            )
            mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,)))
            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))

            for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
                ki = tXcX[0, 0, k][1] // self.async_copy_elems
                mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki]
                tXsX_k = tXsX[None, m, k]
                # Re-layout the gmem fragment to match the smem fragment.
                mX_paged_cur_copy_ki = cute.make_tensor(
                    mX_paged_cur_copy_ki.iterator, tXsX_k.layout
                )
                cute.copy(
                    self.gmem_tiled_copy_KV,
                    mX_paged_cur_copy_ki,
                    tXsX_k,
                    pred=should_load,
                )
build/torch-cuda/pipeline.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ # import math
4
+ from typing import Optional
5
+ from dataclasses import dataclass
6
+
7
+ import cutlass.cute as cute
8
+ from cutlass import Boolean, Int32, const_expr
9
+ from cutlass.cutlass_dsl import if_generate, dsl_user_op
10
+ from cutlass.pipeline import PipelineState
11
+ from cutlass.pipeline import PipelineUserType
12
+ from cutlass.pipeline import NamedBarrier as NamedBarrierOg
13
+ from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
14
+ from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
15
+ from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
16
+ from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
17
+ from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
18
+
19
+
20
class PipelineStateSimple:
    """
    Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
    Use a single Int32 to store both the index and phase bit, then we use divmod to get the
    index and phase. If stages is a power of 2, divmod turns into bit twiddling.
    """

    def __init__(self, stages: int, phase_index: Int32):
        # assert stages < 2**16
        # self._log_stages = int(math.log2(stages))
        # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2."
        # Packed representation: _phase_index == phase * stages + index
        # (for stages == 1 it degenerates to just the phase bit).
        self._stages = stages
        self._phase_index = phase_index

    def clone(self) -> "PipelineStateSimple":
        """Return an independent copy with the same stages and packed state."""
        return PipelineStateSimple(self.stages, self._phase_index)

    @property
    def stages(self) -> int:
        # return 1 << self._log_stages
        return self._stages

    @property
    def index(self) -> Int32:
        """Current slot index in the circular buffer: packed value mod stages."""
        # return self._phase_index & 0xFFFF
        # return self._phase_index & ((1 << self._log_stages) - 1)
        if const_expr(self._stages == 1):
            return Int32(0)
        else:
            return self._phase_index % self._stages

    @property
    def phase(self) -> Int32:
        """Phase counter: packed value divided by stages (not reduced mod 2)."""
        # return self._phase_index >> 16
        # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
        # take modulo 2. But in practice just passing the phase in without modulo works fine.
        # return (self._phase_index >> self._log_stages) % 2
        # return self._phase_index >> self._log_stages
        if const_expr(self._stages == 1):
            return self._phase_index
        else:
            return self._phase_index // self._stages

    def advance(self):
        """Move to the next buffer slot; the phase flips implicitly whenever
        the packed counter crosses a multiple of stages."""
        if const_expr(self._stages == 1):
            self._phase_index ^= 1
        else:
            self._phase_index += 1

        # def then_body(phase_index):
        #     # XOR the phase bit and set the index to 0
        #     return (phase_index & 0xFFFF0000) ^ (1 << 16)

        # def else_body(phase_index):
        #     return phase_index

        # self._phase_index = if_generate(
        #     (self._phase_index & 0xFFFF) == self.stages,
        #     then_body,
        #     else_body,
        #     [self._phase_index],
        #     [Int32],
        # )

    def __extract_mlir_values__(self):
        # DSL plumbing: expose the single dynamic value (the packed counter)
        # to the MLIR tracer; `stages` is static and not extracted.
        phase_index = self._phase_index
        return [phase_index.ir_value()]

    def __new_from_mlir_values__(self, values):
        # DSL plumbing: rebuild this state from traced MLIR values.
        return PipelineStateSimple(self.stages, Int32(values[0]))
+
91
+
92
def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer
    and have a flipped phase bit of 1.

    Args:
        type: PipelineUserType.Producer or PipelineUserType.Consumer.
        stages: number of circular-buffer stages.

    Returns:
        A PipelineStateSimple positioned at index 0 with the appropriate phase.

    Raises:
        ValueError: if `type` is neither Producer nor Consumer.
    """
    if type is PipelineUserType.Producer:
        # Packed value phase*stages + index == stages, i.e. phase 1, index 0.
        return PipelineStateSimple(stages, Int32(stages))
    elif type is PipelineUserType.Consumer:
        # Phase 0, index 0.
        return PipelineStateSimple(stages, Int32(0))
    else:
        # Was `assert False, ...`: asserts are stripped under `python -O`,
        # which would let an invalid user type fall through and return None.
        raise ValueError("Error: invalid PipelineUserType specified for make_pipeline_state.")
+ assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
103
+
104
+
105
@dataclass(frozen=True)
class NamedBarrier(NamedBarrierOg):
    """NamedBarrierOg extended with *_w_index variants that offset the barrier id."""

    @staticmethod
    def create(*args, **kwargs):
        obj = NamedBarrierOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", NamedBarrier)
        return obj

    @dsl_user_op
    def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        """
        The aligned flavor of arrive is used when all threads in the CTA will execute the
        same instruction. See PTX documentation.
        """
        # Arrive (without waiting) on named barrier `barrier_id + index`.
        cute.arch.barrier_arrive(
            barrier_id=self.barrier_id + index,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
        # Arrive and block until all num_threads have arrived on barrier
        # `barrier_id + index`.
        cute.arch.barrier(
            barrier_id=self.barrier_id + index,
            number_of_threads=self.num_threads,
            loc=loc,
            ip=ip,
        )
135
+
136
+
137
@dataclass(frozen=True)
class PipelineAsync(PipelineAsyncOg):
    """PipelineAsyncOg extended with *_w_index(_phase) variants.

    These take the buffer index (and phase) directly instead of a PipelineState,
    for callers that track index/phase themselves.
    """

    @staticmethod
    def create(*args, **kwargs):
        obj = PipelineAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineAsync)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be empty unless a prior try_acquire already
        # succeeded (token != 0).
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        # Mark the slot as full.
        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be full unless a prior try_wait already
        # succeeded (token != 0).
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        # Mark the slot as empty again.
        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
188
+
189
+
190
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineTmaAsyncOg):
    """
    Override producer_acquire to take in extra_tx_count parameter.
    """

    @staticmethod
    def create(*args, **kwargs):
        obj = PipelineTmaAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineTmaAsync)
        return obj

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        # Skip the empty-wait if a prior try_acquire already succeeded (token != 0).
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        if const_expr(extra_tx_count == 0):
            # Default transaction count baked into the sync object.
            self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)
        else:
            # Expect extra bytes on top of the default transaction count (e.g. when
            # additional data is loaded into the same buffer slot).
            tx_count = self.sync_object_full.tx_count + extra_tx_count
            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
227
+
228
+
229
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineTmaUmmaOg):
    """
    Override producer_acquire to take in extra_tx_count parameter.
    """

    @staticmethod
    def create(*args, **kwargs):
        obj = PipelineTmaUmmaOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineTmaUmma)
        return obj

    @dsl_user_op
    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        extra_tx_count: int = 0,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        # Skip the empty-wait if a prior try_acquire already succeeded (token != 0).
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        # Only the leader CTA of the (cluster) group arrives on the full barrier.
        if const_expr(extra_tx_count == 0):
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive(
                    state.index, self.producer_mask, loc=loc, ip=ip
                ),
                loc=loc,
                ip=ip,
            )
        else:
            # Expect extra bytes on top of the default transaction count.
            tx_count = self.sync_object_full.tx_count + extra_tx_count
            if_generate(
                self.is_leader_cta,
                lambda: self.sync_object_full.arrive_and_expect_tx(
                    state.index, tx_count, loc=loc, ip=ip
                ),
                loc=loc,
                ip=ip,
            )

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        # Leader CTA arrives on the full barrier with the default tx count.
        if_generate(
            self.is_leader_cta,
            lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be full unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA consumer release buffer empty, cta_group needs to be provided.
        """
        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
331
+
332
+
333
@dataclass(frozen=True)
class PipelineUmmaAsync(PipelineUmmaAsyncOg):
    """PipelineUmmaAsyncOg extended with *_w_index(_phase) variants."""

    @staticmethod
    def create(*args, **kwargs):
        obj = PipelineUmmaAsyncOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineUmmaAsync)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be empty unless a prior try_acquire already succeeded.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA producer commit buffer full, cta_group needs to be provided.
        """
        self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be full unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        # Mark the slot as empty again (no cta_group on the empty barrier here).
        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
386
+
387
+
388
@dataclass(frozen=True)
class PipelineAsyncUmma(PipelineAsyncUmmaOg):
    """PipelineAsyncUmmaOg extended with *_w_index(_phase) variants."""

    @staticmethod
    def create(*args, **kwargs):
        obj = PipelineAsyncUmmaOg.create(*args, **kwargs)
        # Can't assign to __class__ directly since the dataclass is frozen
        object.__setattr__(obj, "__class__", PipelineAsyncUmma)
        return obj

    @dsl_user_op
    def producer_acquire_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_acquire_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be empty unless a prior try_acquire already succeeded.
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
        # Mark the slot as full.
        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)

    @dsl_user_op
    def consumer_wait_w_index_phase(
        self,
        index: Int32,
        phase: Int32,
        try_wait_token: Optional[Boolean] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Wait for the slot to be full unless a prior try_wait already succeeded.
        if_generate(
            try_wait_token is None or try_wait_token == 0,
            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
        """
        UMMA consumer release buffer empty, cta_group needs to be provided.
        """
        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
build/torch-cuda/quack/__init__.py ADDED
File without changes
build/torch-cuda/quack/activation.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ from typing import Tuple
5
+ from functools import partial
6
+
7
+ import cutlass.cute as cute
8
+ from cutlass import Float32, Boolean, const_expr
9
+ from cutlass.cutlass_dsl import T, dsl_user_op
10
+ from cutlass._mlir.dialects import llvm, nvvm
11
+
12
+
13
+ F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
14
+
15
+
16
# Packed two-lane f32 subtraction built from the generic packed-f32x2 helper
# (no accumulate operand, hence src_c=None).
sub_packed_f32x2 = partial(
    cute.arch.calc_packed_f32x2_op,
    src_c=None,
    calc_func=nvvm.sub_packed_f32x2,
)
21
+
22
+
23
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Approximate hardware tanh via inline PTX `tanh.approx.f32` (MUFU unit)."""
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(a).ir_value(loc=loc, ip=ip)],
            "tanh.approx.f32 $0, $1;",
            "=f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
36
+
37
+
38
@dsl_user_op
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x); scalar or packed f32 pair."""
    if const_expr(not isinstance(x, tuple)):
        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
        return 0.5 + 0.5 * tanh(0.5 * x)
    else:
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        # 0.5 * tanh(x/2) + 0.5, both lanes at once.
        return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
47
+
48
+
49
@dsl_user_op
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Sigmoid backward from the forward output: dout * out * (1 - out)."""
    # return dout * out * (1.0 - out)
    # Rewritten so the (1 - out) subtraction folds into the multiply chain.
    return dout * (out - out * out)
53
+
54
+
55
@dsl_user_op
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """relu(x) = max(x, 0); scalar or packed f32 pair (applied lane-wise)."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0))
    else:
        return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
61
+
62
+
63
@dsl_user_op
@cute.jit
def drelu(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """ReLU backward: returns (dx, relu(x)) with dx = dout where x > 0, else 0."""
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
    else:
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
        return dx, relu(x)
76
+
77
+
78
@dsl_user_op
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReLU-squared: max(x, 0) * x  (equals x^2 for x > 0, else 0)."""
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * x
    else:
        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
        return cute.arch.mul_packed_f32x2(relu_x, x)
85
+
86
+
87
@dsl_user_op
@cute.jit
def drelu_sq(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
    Returns: (dx, relu_sq_out) where:
        - dx = dout * 2 * x if x > 0, else 0
        - relu_sq_out = max(x, 0) * x
    """
    if const_expr(not isinstance(x, tuple)):
        relu_x = relu(x)
        relu_sq_out = relu_x * x
        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
        # (dout * relu_x already zeroes the negative side).
        dx = 2.0 * (dout * relu_x)
        return dx, relu_sq_out
    else:
        relu_x = relu(x)
        relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
        dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
        return dx, relu_sq_out
110
+
111
+
112
@dsl_user_op
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """
    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
    if const_expr(not isinstance(x, tuple)):
        return 0.5 * (
            x
            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
        )
    else:
        # Packed path: z = x * (c1 + c2 * x^2), result = 0.5 * (x * tanh(z) + x).
        x_sq = cute.arch.mul_packed_f32x2(x, x)
        x_sq_scaled = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
        return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
136
+
137
+
138
@dsl_user_op
def dgelu_tanh_approx(
    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
    and sech^2(z) = 1 - tanh^2(z)
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322

    if const_expr(not isinstance(x, tuple)):
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = x * x
        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
        gelu_out = x * half_tanh_z_plus_one

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z)
        sech2_z = 1 - tanh_z * tanh_z
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

        dx = dout * dgelu
        return dx, gelu_out
    else:
        # Compute z = x * (c1 + c2 * x^2)
        x_sq = cute.arch.mul_packed_f32x2(x, x)
        x_sq_scaled = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
        tanh_z = (tanh(z[0]), tanh(z[1]))
        half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
        gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)

        # Compute gradient
        # sech^2(z) = 1 - tanh^2(z), via fma: tanh_z * (-tanh_z) + 1
        sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
        # dz/dx = c1 + 3 * c2 * x^2
        dz_dx = cute.arch.fma_packed_f32x2(
            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
        )
        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
        x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
        dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

        dx = cute.arch.mul_packed_f32x2(dout, dgelu)
        return dx, gelu_out
199
+
200
+
201
@dsl_user_op
@cute.jit
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """softplus(x) = log(1 + exp(x)), switching to the identity for x > 20
    where exp(x) would dominate/overflow and softplus(x) ~= x."""
    if const_expr(not isinstance(x, tuple)):
        use_linear = Boolean(x > 20.0)
        return (
            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
            if not use_linear
            else x
        )
    else:
        # Packed path computed in base 2: exp(x) = 2^(x*log2(e)), log(y) = log2(y)*ln(2).
        log2_e = math.log2(math.e)
        x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
        x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
        x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
        log_x_exp_p1 = (
            cute.math.log2(x_exp_p1[0], fastmath=True),
            cute.math.log2(x_exp_p1[1], fastmath=True),
        )
        ln2 = math.log(2.0)
        softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
        use_linear_0 = Boolean(x[0] > 20.0)
        use_linear_1 = Boolean(x[1] > 20.0)
        return (
            softplus_x[0] if not use_linear_0 else x[0],
            softplus_x[1] if not use_linear_1 else x[1],
        )
228
+
229
+
230
@dsl_user_op
@cute.jit
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
    """Softplus backward from the forward output: dout * (1 - exp(-out)).

    Uses the same >20 linear-region cutoff as the forward (out ~= x there, d/dx = 1).
    """
    use_linear = Boolean(out > 20.0)
    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
    dx = dout - dout * cute.math.exp(-out, fastmath=True)
    return dx if not use_linear else dout
237
+
238
+
239
@dsl_user_op
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
    """
    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.

    If already_halved, the caller has pre-scaled x by 0.5 (saves the FMUL).
    """
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x if const_expr(not already_halved) else x
        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
        return x_half * tanh(x_half) + x_half
    else:
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
        return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
253
+
254
+
255
@dsl_user_op
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """SwiGLU: silu(x) * y (x is the gate, y the up-projection)."""
    if const_expr(not isinstance(x, tuple)):
        return silu(x) * y
    else:
        return cute.arch.mul_packed_f32x2(silu(x), y)
261
+
262
+
263
@dsl_user_op
def dswiglu(
    x: F32_or_F32x2,
    y: F32_or_F32x2,
    dout: F32_or_F32x2,
    *,
    already_halved: bool = False,
    loc=None,
    ip=None,
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)

    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

    This has been optimized to use fewer instructions (i.e. we expand things out
    to use FFMA instead of FADD and FMUL).
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
        # FMUL, MUFU.TANH, then FFMA
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = x * sigmoid_x  # FMUL
        else:
            # x is pre-scaled by 0.5; note silu(x) = x*tanh(x/2) + x in that form.
            tanh_x = tanh(x)  # MUFU.TANH
            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
            silu_x = x * tanh_x + x  # FFMA
        silu_x_dout = silu_x * dout  # FMUL
        # d_silu(x) * dout
        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
        dx = d_silu_x_dout * y  # FMUL
        dy = silu_x_dout
        swiglu_out = silu_x * y  # FMUL
        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(x) and silu(x)
        if const_expr(not already_halved):
            sigmoid_x = sigmoid(x)
            silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
        else:
            tanh_x = (tanh(x[0]), tanh(x[1]))
            sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
            silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
        sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
        )
        d_silu_x_dout = cute.arch.fma_packed_f32x2(
            sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
        )
        dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
        dy = silu_x_dout
        swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
        return dx, dy, swiglu_out
327
+
328
+
329
@dsl_user_op
def swiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> F32_or_F32x2:
    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
    x * sigmoid(alpha * x) * (y + 1)
    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
    """
    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
    if const_expr(not isinstance(x, tuple)):
        x_half = 0.5 * x
        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
        silu_x = x_half * tanh(alpha * x_half) + x_half
        # silu_x * (y + 1) expanded to a single FFMA.
        return silu_x * y + silu_x
    else:
        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
        alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
        return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
350
+
351
+
352
@dsl_user_op
def dswiglu_oai(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    Swiglu OAI backward pass: computes gradients w.r.t. x and y
    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
    Returns: (dx, dy, swiglu_oai_out)

    Derivative of x * sigmoid(alpha * x) w.r.t. x:
    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
        alpha_x_half = (0.5 * alpha) * x  # FMUL
        # MUFU.TANH, then FFMA
        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
        silu_x = x * sigmoid_alpha_x  # FMUL
        silu_x_dout = silu_x * dout  # FMUL
        # FFMA, FFMA, FMUL
        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
        dy = silu_x_dout
        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
        return dx, dy, swiglu_out
    else:
        # Compute sigmoid(alpha * x)
        alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
        sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
        silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
        silu_x_minus_product = cute.arch.fma_packed_f32x2(
            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
        )
        sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
        )
        d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
        dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
        dy = silu_x_dout
        swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
        return dx, dy, swiglu_out
398
+
399
+
400
@dsl_user_op
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GLU: Gated Linear Unit
    glu(x, y) = sigmoid(x) * y
    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
    """
    if const_expr(not isinstance(x, tuple)):
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        return sigmoid_x * y  # FMUL
    else:
        sigmoid_x = sigmoid(x)
        return cute.arch.mul_packed_f32x2(sigmoid_x, y)
412
+
413
+
414
@dsl_user_op
def dglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
    Returns: (dx, dy, glu_out) where:
        - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
        - dy = dout * sigmoid(x)
        - glu_out = sigmoid(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
        sigmoid_x_dout = sigmoid_x * dout  # FMUL
        glu_out = sigmoid_x * y  # FMUL
        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
        #    = (y - glu_out) * sigmoid_x_dout
        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
        dy = sigmoid_x_dout
        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
        return dx, dy, glu_out
    else:
        sigmoid_x = sigmoid(x)
        sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
        glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
        # dx = (y - glu_out) * sigmoid_x_dout
        y_minus_glu_out = sub_packed_f32x2(y, glu_out)
        dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
        dy = sigmoid_x_dout
        return dx, dy, glu_out
448
+
449
+
450
@dsl_user_op
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """ReGLU: ReLU Gated Linear Unit
    reglu(x, y) = relu(x) * y = max(x, 0) * y
    """
    if const_expr(not isinstance(x, tuple)):
        return cute.arch.fmax(x, Float32(0.0)) * y
    else:
        relu_x = relu(x)
        return cute.arch.mul_packed_f32x2(relu_x, y)
460
+
461
+
462
@dsl_user_op
@cute.jit
def dreglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
    Returns: (dx, dy, reglu_out) where:
        - dx = dout * y if x > 0, else 0
        - dy = dout * relu(x)
        - reglu_out = relu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        x_pos = Boolean(x > 0)
        relu_x = cute.arch.fmax(x, Float32(0.0))
        dx = (dout * y) if x_pos else Float32(0.0)
        dy = dout * relu_x
        reglu_out = relu_x * y
        return dx, dy, reglu_out
    else:
        x0_pos = Boolean(x[0] > 0)
        x1_pos = Boolean(x[1] > 0)
        relu_x = relu(x)
        dout_y = cute.arch.mul_packed_f32x2(dout, y)
        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
        dy = cute.arch.mul_packed_f32x2(dout, relu_x)
        reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
        return dx, dy, reglu_out
491
+
492
+
493
@dsl_user_op
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
    """GeGLU: GELU Gated Linear Unit
    geglu(x, y) = gelu(x) * y
    Uses the tanh approximation of GELU
    """
    if const_expr(not isinstance(x, tuple)):
        return gelu_tanh_approx(x) * y
    else:
        return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
503
+
504
+
505
@dsl_user_op
def dgeglu(
    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
    """
    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
    Returns: (dx, dy, geglu_out) where:
        - dx = dout * y * d_gelu(x)
        - dy = dout * gelu(x)
        - geglu_out = gelu(x) * y
    """
    if const_expr(not isinstance(x, tuple)):
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = dgelu_x_dout * y
        dy = gelu_x * dout
        geglu_out = gelu_x * y
        return dx, dy, geglu_out
    else:
        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
        # Compute gradients for geglu
        dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
        dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
        geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
        return dx, dy, geglu_out
533
+
534
+
535
# ============================================================================
# Activation name -> function maps
# ============================================================================

# Forward (single-input) activations; key None disables the activation.
act_fn_map = {
    None: None,
    "silu": silu,
    "relu": relu,
    "relu_sq": relu_sq,
    "gelu_tanh_approx": gelu_tanh_approx,
}

# Backward (single-input) activations; each returns (dx, recomputed forward).
# NOTE(review): no "silu" entry here even though act_fn_map has one — its
# backward appears only fused as "swiglu" below; confirm that is intentional.
dact_fn_map = {
    None: None,
    "relu": drelu,
    "relu_sq": drelu_sq,
    "gelu_tanh_approx": dgelu_tanh_approx,
}

# Gated (two-input) activations: fn(x_gate, y_up) -> out.
gate_fn_map = {
    "swiglu": swiglu,
    "swiglu_oai": swiglu_oai,
    "reglu": reglu,
    "geglu": geglu,
    "glu": glu,
}

# Backward gated activations: fn(x, y, dout) -> (dx, dy, recomputed out).
dgate_fn_map = {
    "swiglu": dswiglu,
    "swiglu_oai": dswiglu_oai,
    "reglu": dreglu,
    "geglu": dgeglu,
    "glu": dglu,
}
build/torch-cuda/quack/compile_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ from typing import Optional
4
+
5
+ import cutlass.cute as cute
6
+
7
+
8
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Create a fake (symbolic) tensor for ahead-of-time compilation.

    The leading (contiguous) dimension gets stride 1; every other dimension gets
    a symbolic 64-bit stride guaranteed to be a multiple of ``divisibility``.

    :param dtype: element type, or None to indicate "no tensor"
    :param shape: tensor shape (tuple of ints)
    :param divisibility: guaranteed divisor of the symbolic strides; also drives
        the assumed alignment (``divisibility`` elements' worth of bytes)
    :param leading_dim: index of the stride-1 dimension; negative counts from the end
    :return: a fake tensor, or None when ``dtype`` is None
    """
    # Guard first: nothing to build when there is no dtype.
    if dtype is None:
        return None
    if leading_dim < 0:
        leading_dim = len(shape) + leading_dim
    stride = tuple(
        cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
        for i in range(len(shape))
    )
    # Alignment in bytes: divisibility elements, dtype.width bits each.
    return cute.runtime.make_fake_tensor(
        dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
    )
build/torch-cuda/quack/copy_utils.py ADDED
@@ -0,0 +1,1007 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ from typing import Optional, Type, Tuple, Callable, Sequence
4
+ from functools import partial
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+
9
+ from cutlass import Int32, Int16, Boolean, const_expr
10
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
11
+ from cutlass.cute.nvgpu.tcgen05.mma import CtaGroup # noqa
12
+ from cutlass.cutlass_dsl import dsl_user_op
13
+ import cutlass.pipeline
14
+ from cutlass._mlir.dialects import llvm
15
+ from cutlass._mlir import ir
16
+ from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
17
+
18
+
19
# Mask that clears bit 24 of an smem barrier address — presumably the "peer CTA"
# bit, used so transaction-byte updates land on CTA0's barrier in 2-CTA MMA
# (see its use in tma_gather4_load below).
Sm100MmaPeerBitMask = 0xFEFFFFFF
20
+
21
+
22
@dsl_user_op
def cvt_copy(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    retile: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` (must be in registers) to ``dst``, converting element type first if needed.

    If ``retile`` is set, ``src`` is retiled with ``tiled_copy`` before the copy.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    if const_expr(src.element_type != dst.element_type):
        # Materialize a converted copy of src in registers before the actual copy.
        src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
        src_cvt.store(src.load().to(dst.element_type))
        src = src_cvt
    if const_expr(retile):
        src = tiled_copy.retile(src)
    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
42
+
43
+
44
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Load ``src`` into a freshly allocated register tensor of the same shape and type."""
    dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, dst, loc=loc, ip=ip)
    return dst
49
+
50
+
51
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Load ``src`` into a register tensor of shape ``dst_shape``, retiling the
    destination with ``tiled_copy`` for the copy; returns the destination tensor.
    """
    # Will also accept dst_shape being a tensor, in which case we write into that tensor
    if const_expr(not isinstance(dst_shape, cute.Tensor)):
        dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
    else:
        dst = dst_shape
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
67
+
68
+
69
@dsl_user_op
def load_t2r(
    thr_copy: cute.ThrCopy, shape: cute.Shape, src: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Load ``src`` into a new register tensor shaped by ``thr_copy``'s D-partition of ``shape``."""
    cDst = cute.make_identity_tensor(shape)
    dst = cute.make_rmem_tensor(thr_copy.partition_D(cDst).shape, src.element_type, loc=loc, ip=ip)
    cute.copy(thr_copy, src, dst, loc=loc, ip=ip)
    return dst
77
+
78
+
79
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Make a copy atom moving ``num_copy_elems`` elements of ``dtype``, capped at 128 bits."""
    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
    # Async variant issues cp.async (gmem -> smem); otherwise a plain universal copy.
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
86
+
87
+
88
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy ``src`` to ``dst`` with a copy atom vectorized from ``src``'s layout."""
    # Vector width from the innermost mode; assumes src was partitioned so that
    # src.shape[0][0] is the per-thread contiguous element count — confirm at call sites.
    num_copy_elems = src.shape[0][0]
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
102
+
103
+
104
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a tiled copy over a flat 1D thread arrangement.

    ``num_threads`` threads each move ``num_copy_elems`` contiguous elements of
    ``dtype`` per copy; ``is_async`` selects cp.async (gmem -> smem) atoms.
    """
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=dtype.width * num_copy_elems)
    return cute.make_tiled_copy_tv(
        atom, cute.make_layout(num_threads), cute.make_layout(num_copy_elems)
    )
113
+
114
+
115
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Build a 2D tiled copy where ``threads_per_row`` threads span the contiguous
    (second) mode, each moving ``num_copy_elems`` elements of ``dtype`` per copy.
    """
    assert num_threads % threads_per_row == 0
    if is_async:
        op = cpasync.CopyG2SOp()
    else:
        op = cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=dtype.width * num_copy_elems)
    num_rows = num_threads // threads_per_row
    thread_layout = cute.make_ordered_layout((num_rows, threads_per_row), order=(1, 0))
    value_layout = cute.make_layout((1, num_copy_elems))
    return cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
132
+
133
+
134
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a boolean predicate tensor for the k dimension: True where the k
    coordinate of ``tAcA`` is below ``limit``.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_rmem_tensor(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    # Compile-time-unrolled loops over the rest-of-vector and rest-of-k modes.
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
148
+
149
+
150
+ # def tiled_copy_2d(
151
+ # dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
152
+ # ) -> cute.TiledCopy:
153
+ # num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
154
+ # copy_elems = num_copy_bits // dtype.width
155
+ # copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
156
+ # copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
157
+ # gmem_threads_per_row = major_mode_size // copy_elems
158
+ # assert num_threads % gmem_threads_per_row == 0
159
+ # thr_layout = cute.make_ordered_layout(
160
+ # (num_threads // gmem_threads_per_row, gmem_threads_per_row),
161
+ # order=(1, 0),
162
+ # )
163
+ # val_layout = cute.make_layout((1, copy_elems))
164
+ # return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
165
+
166
+
167
+ # Ragged tensor trick for TMA: encodes variable-length sequences into a higher-rank
168
+ # tensor so that TMA's out-of-bounds checking handles sequence boundaries.
169
+ #
170
+ # Given a tensor T with a ragged dimension (variable-length across batches), we create
171
+ # a higher-rank tensor where the ragged dim is replaced with a fixed size `big_int`, and
172
+ # extra dim(s) are appended. When indexing into a specific sequence at (offset, length),
173
+ # `offset_ragged_tensor` computes coordinates such that:
174
+ # ragged_coord = big_int - length (OOB check clamps reads past the sequence end)
175
+ # extra_coord(s) = f(offset, length) (selects the correct memory region)
176
+ #
177
+ # ptr_shift=True: 1-extra-dim approach (adds 1 dim, supports up to 4D input):
178
+ # Shape: (*before, big_int, *after, max_int)
179
+ # Stride: (*original_strides, stride_r) where stride_r = T.stride[ragged_dim]
180
+ # Pointer shifted backward by big_int * stride_r elements.
181
+ # Address for coords (big_int - length) in ragged dim, (offset + length) in extra dim:
182
+ # addr = (base - big_int * s_r) + (big_int - length) * s_r + (offset + length) * s_r
183
+ # = base + offset * s_r [correct]
184
+ # Works for epilogue TMA store. Does NOT work for TMA load with large big_int
185
+ # — the shifted pointer must land in physically mapped GPU memory.
186
+ #
187
+ # ptr_shift=False: 2-extra-dim approach (adds 2 dims, supports up to 3D input):
188
+ # Shape: (*before, big_int, *after, max_int, max_int)
189
+ # Stride: (*before_strides, stride_r, *after_strides, 2^34 - stride_r, stride_r)
190
+ # No pointer shift. Uses 64-bit address wraparound to cancel the ragged offset.
191
+ # Let W = 2^34 - stride_r. Address for coords (big_int - length) in ragged dim,
192
+ # big_int in extra dim 0, (offset + length) in extra dim 1:
193
+ # addr = base + (big_int - length) * s_r + big_int * W + (offset + length) * s_r
194
+ # = base + big_int * (s_r + W) - length * s_r + (offset + length) * s_r
195
+ # = base + big_int * 2^34 + offset * s_r
196
+ # Since big_int = 2^30: big_int * 2^34 = 2^64 ≡ 0 (mod 2^64), so:
197
+ # addr = base + offset * s_r [correct]
198
+ # Works for all TMA paths since the base pointer is never shifted.
199
+ #
200
+ # Ragged tensor was adapted from the implementation from Triton, but here we have an option that
201
+ # only needs 1 extra dimension instead of 2.
202
+ # https://github.com/triton-lang/triton/blob/main/python/triton/tools/ragged_tma.py
203
# Fixed extent used in place of the variable-length (ragged) dimension.
BIG_INT = 2**30
# Largest 32-bit signed value; extent of the appended extra dimension(s).
MAX_INT = 2**31 - 1
# 2**34: multiplier such that BIG_INT * BIG_INT_INV == 2**64 wraps to 0 in 64-bit
# address arithmetic (the 2-extra-dim scheme described in the comment above).
BIG_INT_INV = 2**64 // BIG_INT
206
+
207
+
208
@dsl_user_op
def create_ragged_tensor_for_tma(
    T: cute.Tensor,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Encode the ragged dimension of ``T`` into a higher-rank tensor for TMA.

    See the scheme description in the comment block above: ``ptr_shift=True``
    adds one extra dim and shifts the base pointer back by BIG_INT strides;
    ``ptr_shift=False`` adds two extra dims and relies on 64-bit wraparound.
    """
    rank = cute.rank(T)
    if ragged_dim < 0:
        ragged_dim += rank
    if ptr_shift:
        assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions"
        # Ragged dim replaced by BIG_INT; one extra dim with the ragged stride.
        new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,)
        new_stride = T.stride + (T.stride[ragged_dim],)
        # Shift the base pointer back by BIG_INT * stride_r elements.
        ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1)
        new_ptr = cute.domain_offset(ptr_offset, T).iterator
        return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride))
    else:
        assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions"
        stride_r = T.stride[ragged_dim]
        # Two extra dims whose strides sum (with the ragged dim's) to a multiple
        # of 2^64, so the ragged offset cancels without moving the base pointer.
        new_shape = (
            T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT)
        )
        new_stride = (
            T.stride[:ragged_dim]
            + (stride_r,)
            + T.stride[ragged_dim + 1 :]
            + (BIG_INT_INV - stride_r, stride_r)
        )
        return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride))
240
+
241
+
242
@dsl_user_op
def offset_ragged_tensor(
    T: cute.Tensor,
    offset: Int32,
    length: Int32,
    ragged_dim: int = 0,
    ptr_shift: bool = False,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Index a tensor built by ``create_ragged_tensor_for_tma`` at one sequence.

    ``offset``/``length`` select the sequence; the ragged coordinate is set to
    (big_int - length) so TMA's out-of-bounds check clamps reads past the end.
    ``ptr_shift`` must match the value used when creating the ragged tensor.
    """
    rank = cute.rank(T)
    if ragged_dim < 0:
        ragged_dim += rank
    big_int = cute.size(T, mode=[ragged_dim])
    offset_val = big_int - length
    if ptr_shift:
        # 1-extra-dim: rank = original_rank + 1
        assert rank >= ragged_dim + 2
        offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2)
        index_tuple = (None,) * (rank - 1) + (offset + length,)
    else:
        # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims
        assert rank >= ragged_dim + 3
        offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3)
        index_tuple = (None,) * (rank - 2) + (big_int, offset + length)
    return cute.domain_offset(offset_tuple, T[index_tuple])
269
+
270
+
271
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """XOR-swizzle an integer address: fold the ``b`` bits located above position
    ``m + s`` back down by ``s`` positions (CuTe Swizzle<b, m, s> semantics).
    """
    selected_bits = ((1 << b) - 1) << (m + s)
    return ptr_int ^ ((ptr_int & selected_bits) >> s)
275
+
276
+
277
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new pointer whose integer address has the pointer's own swizzle applied."""
    swz = ptr.type.swizzle_type
    ptr_int = swizzle_int(ptr.toint(), swz.num_bits, swz.num_base, swz.num_shift)
    return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
281
+
282
+
283
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Fold the pointer's swizzle into the tensor's layout, returning a tensor
    whose addressing no longer depends on the (swizzled) pointer value.
    """
    outer = tensor.layout
    width = tensor.element_type.width
    swizzle_type = tensor.iterator.type.swizzle_type
    inner = cute.make_swizzle(swizzle_type.num_bits, swizzle_type.num_base, swizzle_type.num_shift)
    # Need to recast the swizzle from byte units (e.g. <3, 4, 3>) to element units
    # (e.g. <3, 3, 3> for 16 bits and <3, 2, 3> for 32 bits)
    new_layout = cute.recast_layout(
        width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
    )
    # recast_ptr to remove the pointer swizzle
    return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
295
+
296
+
297
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Destination-partition ``tensor`` with the swizzle applied to the pointer and
    the layout separately (position-independent addressing)."""
    return cute.make_tensor(
        swizzle_ptr(thr_copy.partition_D(tensor).iterator),
        thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
    )
304
+
305
+
306
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Source-partition ``tensor`` with the swizzle applied to the pointer and
    the layout separately (position-independent addressing)."""
    return cute.make_tensor(
        swizzle_ptr(thr_copy.partition_S(tensor).iterator),
        thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
    )
313
+
314
+
315
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.

    Parameters:
    -----------
    layout_c : LayoutEnum
        The layout enum of the output tensor D.

    elem_ty_c : Type[Numeric]
        The element type for output tensor D.

    Returns:
    --------
    A CopyAtom wrapping LdMatrix8x8x16b (x4) for 16-bit element types, otherwise
    a universal SIMT copy atom.
    """

    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width == 16:
        return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
    else:
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
346
+
347
+
348
def get_smem_store_atom(
    arch: cutlass.Constexpr[int],
    element_type: Type[cute.Numeric],
    transpose: bool = False,
    major_mode_size: Optional[int] = None,
) -> cute.CopyAtom:
    """Pick an rmem->smem store atom: stmatrix for 16-bit types on SM90+, else a
    universal copy (2 elements per copy when not transposed, 1 when transposed)."""
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        # Use as many 8x8 matrices per stmatrix as the major-mode size allows (4, 2, or 1).
        num_matrices = (
            4
            if major_mode_size is None or major_mode_size % 16 == 0
            else (2 if major_mode_size % 8 == 0 else 1)
        )
        return cute.make_copy_atom(
            warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
            element_type,
        )
370
+
371
+
372
def get_smem_load_atom(
    arch: cutlass.Constexpr[int],
    element_type: Type[cute.Numeric],
    transpose: bool = False,
    major_mode_size: Optional[int] = None,
) -> cute.CopyAtom:
    """Pick an smem->rmem load atom: ldmatrix for 16-bit types on SM90+, else a
    universal copy (mirrors get_smem_store_atom)."""
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
        )
    else:
        # Use as many 8x8 matrices per ldmatrix as the major-mode size allows (4, 2, or 1).
        num_matrices = (
            4
            if major_mode_size is None or major_mode_size % 16 == 0
            else (2 if major_mode_size % 8 == 0 else 1)
        )
        return cute.make_copy_atom(
            warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=num_matrices),
            element_type,
        )
394
+
395
+
396
def get_smem_store_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
    major_mode_size: Optional[int] = None,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Build a store of the MMA accumulator C into smem ``sC``; returns
    (copy_fn, per-thread copy slice, partitioned smem destination)."""
    dtype = sC.element_type
    copy_atom = get_smem_store_atom(arch, dtype, transpose, major_mode_size=major_mode_size)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sC = thr_copy.partition_D(sC)
    else:
        tRS_sC = partition_D_position_independent(thr_copy, sC)

    def copy_fn(src: cute.Tensor, dst_idx: Optional[Int32] = None, **new_kwargs):
        # dst_idx selects a pipeline stage (last mode of tRS_sC) when provided.
        dst_tensor = tRS_sC if const_expr(dst_idx is None) else tRS_sC[None, None, None, dst_idx]
        cvt_copy(tiled_copy, src, dst_tensor, retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sC
419
+
420
+
421
def get_smem_load_C(
    tiled_mma: cute.TiledMma,
    sC: cute.Tensor,
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Build a load of smem ``sC`` into registers; returns
    (copy_fn, per-thread copy slice, partitioned smem source)."""
    dtype = sC.element_type
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sC = thr_copy.partition_S(sC)
    else:
        tSR_sC = partition_S_position_independent(thr_copy, sC)
    # Register destination shape matches what the corresponding store-side
    # partitioning would produce, so loads round-trip with get_smem_store_C.
    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape

    def copy_fn(src_idx: Optional[Int32] = None, **new_kwargs):
        # src_idx selects a pipeline stage (last mode of tSR_sC) when provided.
        src_tensor = tSR_sC if const_expr(src_idx is None) else tSR_sC[None, None, None, src_idx]
        return load_s2r_retile(tiled_copy, src_tensor, dst_shape=tRS_shape, **new_kwargs)

    return copy_fn, thr_copy, tSR_sC
446
+
447
+
448
def epilog_smem_copy_atom(
    tiled_mma: cute.TiledMma, epi_tile: cute.Shape, transpose: bool = False
) -> cute.TiledCopy:
    """Build the epilogue copy-atom tiling from an stmatrix atom over the MMA's C layout."""
    copy_atom_C = cute.make_copy_atom(
        # 4 matrices when the epilogue tile's N extent allows, else 2.
        warp.StMatrix8x8x16bOp(transpose, num_matrices=4 if epi_tile[1] % 16 == 0 else 2),
        cutlass.Float16,  # this is just to get the right source layout
    )
    tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma)
    return tiled_copy_C_atom
457
+
458
+
459
def get_smem_store_epi(
    tiled_mma: cute.TiledMma,
    epi_tile: cute.Shape,
    sC: Optional[cute.Tensor],
    tidx: Int32,
    arch: int,
    transpose: bool = False,
    position_independent=False,
) -> Tuple[Optional[Callable], cute.TiledCopy, Optional[cute.Tensor], cute.Tensor]:
    """Build the epilogue rmem->smem store; returns
    (copy_fn or None, per-thread copy slice, partitioned smem dest or None, rmem staging tensor).

    When ``sC`` is None only the staging tensor and thread slice are produced
    (copy_fn and the smem partition are None).
    """
    dtype = sC.element_type if const_expr(sC is not None) else cutlass.Float16
    tiled_copy_C_atom = epilog_smem_copy_atom(tiled_mma, epi_tile)
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_S(copy_atom, tiled_copy_C_atom)
    thr_copy = tiled_copy.get_slice(tidx)
    tRS_sC = None
    if const_expr(sC is not None):
        if const_expr(not position_independent):
            tRS_sC = thr_copy.partition_D(sC)
        else:
            tRS_sC = partition_D_position_independent(thr_copy, sC)
    sC_shape = sC.shape[:2] if sC is not None else epi_tile
    # (R2S, R2S_M, R2S_N, PIPE_C)
    tRS_rC_shape = thr_copy.partition_S(cute.make_identity_tensor(sC_shape)).shape
    # Register staging tensor in the accumulator dtype.
    tRS_rC = cute.make_rmem_tensor(tRS_rC_shape, tiled_mma.op.acc_dtype)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # dst_idx selects the smem pipeline stage.
        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], **new_kwargs)

    return copy_fn if const_expr(sC is not None) else None, thr_copy, tRS_sC, tRS_rC
488
+
489
+
490
def get_smem_store_A(
    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Build a store of the MMA A operand from registers into smem ``sA``; returns
    (copy_fn, per-thread copy slice, partitioned smem destination)."""
    dtype = sA.element_type
    # stmatrix transpose is needed when A is MN-major for this MMA.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_store_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tRS_sA = thr_copy.partition_D(sA)
    else:
        tRS_sA = partition_D_position_independent(thr_copy, sA)

    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
        # dst_idx selects the smem pipeline stage.
        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)

    return copy_fn, thr_copy, tRS_sA
507
+
508
+
509
def get_smem_load_A(
    tiled_mma: cute.TiledMma,
    sA: cute.Tensor,
    tidx: Int32,
    arch: int,
    with_dst_tensor: bool = False,
    position_independent=False,
) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
    """Build a load of the MMA A operand from smem ``sA`` into registers; returns
    (copy_fn, per-thread copy slice, partitioned smem source).

    ``with_dst_tensor`` selects a copy_fn variant that writes into a caller-provided
    register tensor instead of allocating one.
    """
    dtype = sA.element_type
    # ldmatrix transpose is needed when A is MN-major for this MMA.
    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
    copy_atom = get_smem_load_atom(arch, dtype, transpose)
    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
    thr_copy = tiled_copy.get_slice(tidx)
    if const_expr(not position_independent):
        tSR_sA = thr_copy.partition_S(sA)
    else:
        tSR_sA = partition_S_position_independent(thr_copy, sA)
    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])

    def copy_fn(src_idx: Int32, **new_kwargs):
        return load_s2r_retile(
            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
        )

    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)

    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
537
+
538
+
539
@dsl_user_op
def cpasync_reduce_bulk_add_f32(
    smem_ptr: cute.Pointer,
    gmem_ptr: cute.Pointer,
    store_bytes: int | Int32,
    *,
    loc=None,
    ip=None,
):
    """Issue a bulk async f32 reduction-add of ``store_bytes`` bytes from smem to gmem
    (PTX ``cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32``)."""
    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
    # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST
    llvm.inline_asm(
        None,
        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;",
        "l,r,r",
        # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],
        # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;",
        # "l,r,r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
562
+
563
+
564
@dsl_user_op
def get_tma_desc_addr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
    """
    Get the address of the TMA descriptor embedded in a TMA Copy Atom.

    Extracts the constant memory address of the TMA descriptor for use with
    custom PTX instructions.

    :param tma_atom: TMA Copy Atom from make_tiled_tma_atom
    :return: Pointer to TMA descriptor in constant memory

    Example:
        >>> desc_ptr = get_tma_desc_addr(tma_atom)
    """
    exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
    tma_desc_ptr_type = ir.Type.parse(
        "!cute.ptr<!cute_nvgpu.tma_descriptor_tiled, generic, align<128>>"
    )
    return _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip)
583
+
584
+
585
@dsl_user_op
def tma_gather4_load(
    tma_desc_ptr: cute.Pointer,
    dst_smem_ptr: cute.Pointer,
    mbarrier_ptr: cute.Pointer,
    col_idx: Int32,
    row_indices: Sequence[Int32],
    *,
    num_cta: int = 1,
    multicast_mask=None,
    loc=None,
    ip=None,
) -> None:
    """
    Perform TMA gather4 load from global memory to shared memory.

    Issues PTX instruction:
        cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
            [dstMem], [tensorMap, {col_idx, row0, row1, row2, row3}], [smem_bar];

    This loads 4 rows (specified by row_indices) from a 2D tensor at the given
    column index into shared memory, using the TMA descriptor.

    :param tma_desc_ptr: Pointer to TMA descriptor in constant memory (128-byte aligned)
    :type tma_desc_ptr: Pointer
    :param dst_smem_ptr: Destination address in shared memory
    :type dst_smem_ptr: Pointer
    :param mbarrier_ptr: Pointer to mbarrier in shared memory for completion tracking
    :type mbarrier_ptr: Pointer
    :param col_idx: Column index
    :type col_idx: Int32
    :param row_indices: Sequence of exactly 4 row indices
    :type row_indices: Sequence[Int32]
    :param num_cta: Number of CTAs participating (default: 1)
    :type num_cta: int
    :param multicast_mask: Optional multicast mask (currently rejected — not supported yet)
    :type multicast_mask: Int16

    Requirements:
    - row_indices must contain exactly 4 elements
    - Compute capability >= SM_100 (Blackwell)
    - TMA descriptor must be properly initialized for 2D tensor

    Example:
        >>> from cutlass.cute.nvgpu import cpasync
        >>> from cutlass.cute import core
        >>>
        >>> # Create TMA descriptor
        >>> tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(...)
        >>> tma_desc_ptr = get_tma_desc_addr(tma_atom)
        >>>
        >>> # Compute indices (typically from kernel logic)
        >>> col_idx = core.get(...) or 5  # Int32 value
        >>> row_indices = [core.get(...) for _ in range(4)]  # 4 Int32 values
        >>>
        >>> # Gather 4 rows at computed column
        >>> tma_gather4_load(
        ...     tma_desc_ptr=tma_desc_ptr,
        ...     dst_smem_ptr=smem_ptr,
        ...     mbarrier_ptr=barrier_ptr,
        ...     col_idx=col_idx,
        ...     row_indices=row_indices
        ... )
    """
    if len(row_indices) != 4:
        raise ValueError(f"gather4 requires exactly 4 row indices, got {len(row_indices)}")
    col_val = Int32(col_idx).ir_value()
    row_vals = [Int32(row_idx).ir_value() for row_idx in row_indices]
    # Convert pointers to integer addresses
    desc_addr = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value()
    dst_addr = dst_smem_ptr.toint(loc=loc, ip=ip).ir_value()
    mbar_addr = mbarrier_ptr.toint(loc=loc, ip=ip)
    if num_cta > 1:
        # Executed by both CTAs. Set peer bit to 0 so that the
        # transaction bytes will update CTA0's barrier.
        mbar_addr = mbar_addr & Sm100MmaPeerBitMask
    mbar_addr = mbar_addr.ir_value()
    # Handle multicast_mask - may already be ir.Value or Python int
    multicast_mask_val = None
    if multicast_mask is not None:
        multicast_mask_val = Int16(multicast_mask).ir_value()
    assert multicast_mask_val is None, "multicast is not supported yet"
    # Emit inline PTX for TMA gather4
    # PTX: cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes
    # [dstMem], [tensorMap, {col, row0, row1, row2, row3}], [smem_bar];
    ptx = (
        f"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::{num_cta} "
        "[$0], [$1, {$2, $3, $4, $5, $6}], [$7];"
    )

    llvm.inline_asm(
        None,
        [
            dst_addr,
            desc_addr,
            col_val,
            row_vals[0],
            row_vals[1],
            row_vals[2],
            row_vals[3],
            mbar_addr,
        ],
        ptx,
        "r,l,r,r,r,r,r,r",  # constraints: dst(r), desc(l), col + 4 rows + mbar (6x r)
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
695
+
696
+
697
def cpasync_bulk_get_copy_fn(
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    single_stage: bool = False,
    **kwargs,
) -> Callable:
    """Return a closure that issues a bulk cp.async gmem->smem copy (one elected
    thread), indexed by stage unless ``single_stage`` is set."""
    # Group all but the last mode (the stage mode) unless single_stage.
    group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))
    group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    src = cute.group_modes(src_tensor, 0, group_rank_src)
    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)

    def copy_bulk(src_idx, dst_idx, tma_bar_ptr: cute.Pointer, **new_kwargs):
        atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        # Only one thread per CTA issues the bulk copy.
        with cute.arch.elect_one():
            cute.copy(
                atom,
                src[None, src_idx],
                dst[None, dst_idx],
                mbar_ptr=tma_bar_ptr,
                **new_kwargs,
                **kwargs,
            )

    def copy_bulk_single_stage(tma_bar_ptr: cute.Pointer, **new_kwargs):
        atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), src.element_type)
        with cute.arch.elect_one():
            cute.copy(atom, src, dst, mbar_ptr=tma_bar_ptr, **new_kwargs, **kwargs)

    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage
727
+
728
+
729
@dsl_user_op
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    single_stage: bool = False,
    *,
    loc=None,
    ip=None,
    **kwargs,
) -> Tuple[Callable, cute.Tensor, cute.Tensor]:
    """TMA-partition src/dst and return (copy_fn, smem_partition, gmem_partition).

    Works for both load (gmem src) and store (smem src); the direction is
    detected from which tensor lives in smem. NOTE(review): the original return
    annotation said ``Callable`` but a 3-tuple is returned; annotation corrected.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    # Group all but the last mode (the stage mode) unless single_stage.
    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, group_rank_smem),
        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
        loc=loc,
        ip=ip,
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    @dsl_user_op
    def copy_tma(src_idx, dst_idx, *, loc=None, ip=None, **new_kwargs):
        cute.copy(
            atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs, loc=loc, ip=ip
        )

    @dsl_user_op
    def copy_tma_single_stage(*, loc=None, ip=None, **new_kwargs):
        cute.copy(atom, src, dst, **new_kwargs, **kwargs, loc=loc, ip=ip)

    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
776
+
777
+
778
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a TMA copy closure to a producer pipeline.

    The returned closure maps the producer state's stage index to the copy's
    destination index and passes the matching mbarrier from the pipeline.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        stage_barrier = pipeline.producer_get_barrier(producer_state)
        copy(
            src_idx=src_idx,
            dst_idx=producer_state.index,
            tma_bar_ptr=stage_barrier,
            **new_kwargs,
        )

    return copy_fn
788
+
789
+
790
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a gmem->smem copy closure that gathers rows of A via an index tensor.

    Row indices are read once from ``gsAIdx`` and cached in registers; the
    returned ``copy_fn(src_idx, dst_idx, pred)`` then copies one K-tile of the
    gathered rows into stage ``dst_idx`` of ``sA``, with optional K-predication.
    """
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    tAsA = thr_copy_A.partition_D(sA)
    # k-major: exactly one load per row per tile
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, compare t0AcA to limit_m - tAcA[0][0],
    # since tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. This keeps the compared
    # coordinate (t0AcA) known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache the gather indices for A.
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # OOB rows fall back to row 0, which is always safe to load

    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would
            # give shape ((elems_per_load), thread_per_row), but tAsA expects
            # ((elems_per_load, 1), thread_per_row): append a unit mode first,
            # tiled-divide, then slice.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row.
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
856
+
857
+
858
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build copy/prefetch closures that gather columns (K) of A by index.

    Returns ``(copy_fn, prefetch_fn)``: the prefetch closure reads the column
    indices for one K-tile into registers — from gmem or, synchronized through
    a pipeline, from smem, depending on where ``gsAIdx`` lives — and
    ``copy_fn`` uses the cached indices to copy the gathered columns into a
    stage of ``sA``.
    """
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Compare thread-0 coordinates (compile-time known) against shifted limits
    # instead of per-thread coordinates: tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A.
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # Convoluted but necessary reshaping: for tile_M=128, flat_divide gives
    # (8, 16, K), then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free.
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    k_idx[k] = -1
        return k_idx, tApA_k

    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    prefetch_fn = prefetch_from_gmem_fn if const_expr(gAIdx is not None) else prefetch_from_smem_fn
    return copy_fn, prefetch_fn
968
+
969
+
970
@cute.jit
def gather_m_get_tma_copy_fn(
    tma_atom: cute.CopyAtom,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # ((4, 32), (64, 1), STAGE)
    sAIdx: cute.Tensor,  # (tile_M)
    warp_idx: Int32,
    num_warps: int,
    num_cta: int = 1,
) -> Callable:
    """Build a TMA gather4 copy closure selecting rows of A via smem indices.

    Row indices are loaded smem->rmem once (128-bit vectorized, 4 per thread);
    the returned ``copy_fn(src_idx, dst_idx, tma_bar_ptr)`` issues one
    ``tma.gather4`` load per group of 4 rows into stage ``dst_idx`` of ``sA``.
    """
    tile_M = cute.size(sAIdx, mode=[0])
    tile_K = cute.size(sA[None, None, 0]) // tile_M
    assert tile_M % 4 == 0
    # cta_group = 1 if tma_atom.op.cta_group == CtaGroup.ONE else 2
    cta_group = num_cta  # Somehow all tma_atom has CtaGroup.ONE inside the kernel

    copy_AIdx_s2r = cute.make_tiled_copy_tv(
        cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
        cute.make_layout(num_warps),  # thr_layout
        cute.make_layout(4),  # val_layout
    )
    warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
    tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx)
    # ((4, 1), 8, (64, 1), STAGE)
    tSR_sA = warp_copy_AIdx_s2r.partition_S(sA)
    tSR_rAIdx = load_s2r(tSR_sAIdx)
    tma_desc_ptr = get_tma_desc_addr(tma_atom)
    tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)

    def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
        col_idx = tile_K * src_idx
        for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
            row_indices = [tSR_rAIdx[v, m] for v in range(4)]
            smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
            with cute.arch.elect_one():
                tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)

    return copy_fn
build/torch-cuda/quack/cute_dsl_utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple, get_origin
4
+ from functools import lru_cache
5
+ from dataclasses import dataclass, fields
6
+
7
+ import torch
8
+
9
+ try:
10
+ from triton.tools.disasm import extract
11
+ except ImportError:
12
+ extract = None
13
+
14
+ import cutlass
15
+ import cutlass.cute as cute
16
+ from cutlass import Int32, Int64, Float16, BFloat16, Float32
17
+ from cutlass.base_dsl.typing import JitArgument
18
+ from cutlass.base_dsl.tvm_ffi_builder import spec
19
+ from cutlass.cutlass_dsl import NumericMeta
20
+
21
+
22
+ StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
23
+
24
+
25
+ load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
26
+ cute_compile_og = cute.compile
27
+
28
+
29
+ # Patch TVM-FFI converter to handle Constexpr type annotations as compile-time constants.
30
+ # Fields annotated with cutlass.Constexpr[T] are emitted as ConstNone (not runtime args).
31
+ # At call time, pass None for these fields; the compile-time value is baked in.
32
+ import cutlass.cute._tvm_ffi_args_spec_converter as _converter_module # noqa
33
+
34
+ _original_convert_single_arg = _converter_module._convert_single_arg
35
+
36
+
37
def _patched_convert_single_arg(arg, arg_name, arg_type, ctx):
    """Converter override for the TVM-FFI argument spec.

    Arguments annotated ``cutlass.Constexpr[...]`` are compile-time constants:
    emit ``ConstNone`` so they are not treated as runtime arguments. NamedTuple
    values whose annotation lacks ``_fields`` (e.g. annotated as plain ``tuple``)
    are redirected so the converter uses the NamedTuple's own type hints.
    """
    if arg_type is not None and get_origin(arg_type) is cutlass.Constexpr:
        return spec.ConstNone(arg_name)
    arg_is_namedtuple = isinstance(arg, tuple) and hasattr(type(arg), "_fields")
    annotation_lacks_fields = arg_type is None or not hasattr(arg_type, "_fields")
    if arg_is_namedtuple and annotation_lacks_fields:
        return _original_convert_single_arg(arg, arg_name, type(arg), ctx)
    return _original_convert_single_arg(arg, arg_name, arg_type, ctx)


# Install the patch.
_converter_module._convert_single_arg = _patched_convert_single_arg
52
+
53
+
54
# Map from torch dtypes to the corresponding cutlass numeric types.
torch2cute_dtype_map = {
    torch.float16: Float16,
    torch.bfloat16: BFloat16,
    torch.float32: Float32,
    torch.int32: Int32,
    torch.int64: Int64,
}
61
+
62
+
63
@lru_cache
def get_max_active_clusters(cluster_size):
    """Cached query of the max number of concurrently active clusters of this size."""
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
66
+
67
+
68
@lru_cache
def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return the cached (major, minor) CUDA compute capability for *device*."""
    return torch.cuda.get_device_capability(device)
71
+
72
+
73
def _partition_fields(obj):
    """Split dataclass fields of *obj* into ``(static_dict, dynamic_dict)``.

    Static means the value's type is in ``StaticTypes`` (compile-time constants);
    everything else is dynamic and participates in MLIR value extraction.
    """
    static, dynamic = {}, {}
    for field in fields(obj):
        value = getattr(obj, field.name)
        target = static if isinstance(value, StaticTypes) else dynamic
        target[field.name] = value
    return static, dynamic
79
+
80
+
81
def _new_from_mlir_values(self, values):
    """Rebuild a Params/Arguments dataclass from a flat list of MLIR values.

    Static (constexpr) fields are reused from ``self``; each dynamic field
    consumes the value count previously recorded in ``self._values_pos``.
    """
    constexpr_fields, non_constexpr_fields = _partition_fields(self)
    remaining = values
    for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
        non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, remaining[:n_items])
        remaining = remaining[n_items:]
    return self.__class__(**non_constexpr_fields, **constexpr_fields)
87
+
88
+
89
def _namedtuple_new_from_mlir_values(self, values):
    """Generic ``__new_from_mlir_values__`` for NamedTuples.

    Applied to NamedTuple classes via the ``@mlir_namedtuple`` decorator.

    Fields that are None or static (``StaticTypes``) are preserved from ``self``
    (the compile-time template); only non-static fields consume MLIR values,
    each taking exactly as many values as its MLIR type expands to (via
    ``cutlass.new_from_mlir_values``).

    Constexpr fields (annotated ``cutlass.Constexpr[T]``) are baked into the
    compiled kernel via the converter patch above; pass None for them at call
    time.
    """
    from cutlass.base_dsl.typing import get_mlir_types

    remaining = list(values)
    rebuilt = []
    for field_val in self:
        if field_val is None or isinstance(field_val, StaticTypes):
            # Static/absent fields carry over unchanged from the template.
            rebuilt.append(field_val)
        else:
            n_items = len(get_mlir_types(field_val))
            rebuilt.append(cutlass.new_from_mlir_values(field_val, remaining[:n_items]))
            remaining = remaining[n_items:]
    return self.__class__(*rebuilt)
113
+
114
+
115
def mlir_namedtuple(cls):
    """Class decorator adding MLIR value reconstruction to a NamedTuple.

    Usage::

        @mlir_namedtuple
        class MyArgs(NamedTuple):
            tensor_arg: cute.Tensor
            const_arg: cutlass.Constexpr[int] = 0
    """
    cls.__new_from_mlir_values__ = _namedtuple_new_from_mlir_values
    return cls
127
+
128
+
129
@dataclass
class ParamsBase:
    """Base for device-side parameter dataclasses bridged to/from MLIR values."""

    def __extract_mlir_values__(self):
        # Flatten dynamic fields into a single value list, remembering the
        # per-field value counts so reconstruction can split them back.
        _, dynamic_fields = _partition_fields(self)
        values = []
        self._values_pos = []
        for obj in dynamic_fields.values():
            obj_values = cutlass.extract_mlir_values(obj)
            values.extend(obj_values)
            self._values_pos.append(len(obj_values))
        return values

    __new_from_mlir_values__ = _new_from_mlir_values
141
+
142
+
143
@dataclass
class ArgumentsBase(JitArgument):
    """Base for host-side argument dataclasses passed to JIT-compiled kernels."""

    def __c_pointers__(self):
        # Only dynamic (non-constexpr) fields are materialized as C pointers.
        _, dynamic_fields = _partition_fields(self)
        c_ptrs = []
        for obj in dynamic_fields.values():
            if hasattr(obj, "__c_pointers__"):
                c_ptrs.extend(obj.__c_pointers__())
        return c_ptrs

    def __get_mlir_types__(self):
        # Collect MLIR types of dynamic fields and record per-field counts for
        # later reconstruction via _new_from_mlir_values.
        _, dynamic_fields = _partition_fields(self)
        types = []
        self._values_pos = []
        for obj in dynamic_fields.values():
            if hasattr(obj, "__get_mlir_types__"):
                obj_types = obj.__get_mlir_types__()
                types.extend(obj_types)
                self._values_pos.append(len(obj_types))
            else:
                # Fields without an MLIR representation contribute zero values.
                self._values_pos.append(0)
        return types

    __new_from_mlir_values__ = _new_from_mlir_values
build/torch-cuda/quack/layout_utils.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+ from cutlass import Int32, const_expr
8
+
9
+
10
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Return a view of *a* with its first two modes swapped (typically smem)."""
    swapped_shape = (a.shape[1], a.shape[0], *a.shape[2:])
    mode_order = (1, 0, *range(2, cute.rank(a)))
    return cute.composition(a, cute.make_ordered_layout(swapped_shape, order=mode_order))
15
+
16
+
17
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a view of *a* keeping only the layout modes listed in *mode*."""
    selected_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, selected_layout)
19
+
20
+
21
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a broadcast (stride-0) mode of extent *size* at position *dim*."""
    new_shape = (*a.shape[:dim], size, *a.shape[dim:])
    new_stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
    return cute.make_tensor(a.iterator, cute.make_layout(new_shape, stride=new_stride))
25
+
26
+
27
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """In-place permutation of 16-bit C register fragments across a lane quad.

    Pairs of b16 values are recast to u32, exchanged within groups of 4 lanes
    via width-4 shuffles, and re-interleaved with PRMT byte permutes.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    lane_03 = quad_idx == 0 or quad_idx == 3
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2], lower_map = [1, 2, 0, 3]; table indexing isn't
    # supported in the DSL, so the maps are computed arithmetically.
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # Shuffle masks by width: 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100,
    # 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        t_u32[i * 2 + 0] = cute.arch.prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = cute.arch.prmt(upper0, lower0, selector_lower)
58
+
59
+
60
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    to
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    so STSM (instead of STS.64) can store C registers without bank conflict.
    """
    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3], right_map = [2, 0, 3, 1]; table indexing isn't
    # supported in the DSL, so the maps are computed arithmetically.
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # Shuffle masks by width: 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100,
    # 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
101
+
102
+
103
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
    T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
    a | b | c | d | e | f | g | h
    to
    T0 | T1 | T2 | T3
    a b | c d | e f | g h
    so LDSM (instead of LDS.64) can load C registers without bank conflict.
    This is the exact inverse of ``permute_Cregs_b32_for_stsm``.
    """
    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3], right_map = [1, 3, 0, 2]; table indexing isn't
    # supported in the DSL, so the maps are computed arithmetically.
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # Shuffle masks by width: 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100,
    # 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # Inverse of permute_Cregs_b32_for_stsm: undo the swap, then the shuffles.
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
145
+
146
+
147
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Concatenate layouts mode-wise: one result mode per input layout."""
    shapes = tuple(l.shape for l in layouts)
    strides = tuple(l.stride for l in layouts)
    return cute.make_layout(shapes, stride=strides)
153
+
154
+
155
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
    """Regroup an accumulator layout into explicit (M, N, ...) modes.

    For Sm80: ((2, 2), MMA_M, MMA_N, ...) -> ((2, MMA_M), (2, MMA_N), ...).
    For Sm90: ((2, 2, V), MMA_M, MMA_N, ...) -> ((2, MMA_M), (2, V, MMA_N), ...).
    With ``transpose`` the M and N groups are swapped.
    """
    # Build a column-major reference layout of the same shape, regroup it, and
    # compose back onto the original so the original strides are preserved.
    canonical = cute.make_layout(acc_layout.shape)
    m_shape = (canonical.shape[0][1], canonical.shape[1])
    n_shape = (canonical.shape[0][0], *canonical.shape[0][2:], canonical.shape[2])
    m_stride = (canonical.stride[0][1], canonical.stride[1])
    n_stride = (canonical.stride[0][0], *canonical.stride[0][2:], canonical.stride[2])
    shape = (m_shape, n_shape, *canonical.shape[3:])
    stride = (m_stride, n_stride, *canonical.stride[3:])
    if const_expr(transpose):
        shape = (shape[1], shape[0], *shape[2:])
        stride = (stride[1], stride[0], *stride[2:])
    acc_layout_mn = cute.make_layout(shape, stride=stride)
    return cute.composition(acc_layout, acc_layout_mn)
184
+
185
+
186
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """View an accumulator fragment with the (M, N) layout from convert_layout_acc_mn."""
    mn_layout = convert_layout_acc_mn(acc.layout, transpose=transpose)
    return cute.make_tensor(acc.iterator, mn_layout)
188
+
189
+
190
def reshape_acc_to_mn(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """Alias of :func:`make_acc_tensor_mn_view`, kept for API compatibility.

    Delegates instead of duplicating the implementation so the two names
    cannot drift apart.
    """
    return make_acc_tensor_mn_view(acc, transpose=transpose)
192
+
193
+
194
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Convert a GEMM-0 accumulator layout into the A-fragment layout for GEMM 1.

    Sm80 (mma shape 16x8x16): (4, MMA_M, MMA_N) -> ((4, 2), MMA_M, MMA_N / 2).
    Sm90 FP16/BF16: ((2, 2, N / 8), MMA_M, MMA_N) -> ((2, 2, 2), MMA_M, (N / 16, MMA_N));
    if N / 8 is odd this degenerates to ((2, 2, 1), MMA_M, (N / 8, MMA_N)).
    TODO: Sm90 FP8.
    """
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1
        # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        split = cute.logical_divide(acc_layout, ((None, None, div), None, None))
        rA_mma_view = cute.make_layout(
            (
                (split.shape[0][0], split.shape[0][1], split.shape[0][2][0]),
                split.shape[1],
                (split.shape[0][2][1], split.shape[2]),
            ),
            stride=(
                (split.stride[0][0], split.stride[0][1], split.stride[0][2][0]),
                split.stride[1],
                (split.stride[0][2][1], split.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        split = cute.logical_divide(acc_layout, (None, None, 2))
        rA_mma_view = cute.make_layout(
            (
                (split.shape[0], split.shape[2][0]),
                split.shape[1],
                split.shape[2][1],
            ),
            stride=(
                (split.stride[0], split.stride[2][0]),
                split.stride[1],
                split.stride[2][1],
            ),
        )
    return rA_mma_view
234
+
235
+
236
def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
    """View *acc* with the A-fragment layout a second GEMM expects."""
    frgA_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frgA_layout)
238
+
239
+
240
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Tensor | cute.Layout:
    """Regroup a layout into (nonzero-stride modes, zero-stride modes) per ref_layout.

    Modes whose corresponding stride in ``ref_layout`` is nonzero are grouped
    into the first output mode; zero-stride (broadcast) modes go into the
    second. Accepts either a Tensor or a Layout and returns the same kind.
    (The original return annotation claimed ``cute.Layout`` only, which was
    wrong for Tensor input — fixed here.)
    """
    layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
    # Group the modes with non-zero stride in the ref_layout together,
    # and the modes with zero stride together.
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)
    nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
    zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
    # Edge case: all modes may have zero stride; keep a unit first mode then.
    new_shape = (
        tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
        tuple(layout_flat[i].shape for i in zero_modes),
    )
    new_stride = (
        tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
        tuple(layout_flat[i].stride for i in zero_modes),
    )
    out_layout = cute.make_layout(new_shape, stride=new_stride)
    if const_expr(isinstance(input, cute.Tensor)):
        return cute.make_tensor(input.iterator, out_layout)
    else:
        return out_layout
264
+
265
+
266
def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Broadcast a staged smem vector and partition it as an MMA C operand.

    The (len, STAGE) vector is expanded with a stride-0 mode (columns for a
    column vector, rows for a row vector), partitioned with ``partition_C``,
    viewed in (M, N) form, and the singleton broadcast mode is sliced away.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    stage = sVec.shape[1]
    if const_expr(is_colvec):
        shape = (sVec.shape[0], expand_shape, stage)
        stride = (1, 0, sVec.stride[1])
    else:
        shape = (expand_shape, sVec.shape[0], stage)
        stride = (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
    return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
281
+
282
+
283
def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Broadcast a staged smem vector and partition it as an MMA A operand.

    Same scheme as ``mma_partition_C_vec`` but using ``partition_A``: expand
    with a stride-0 mode, partition, view in (M, N) form, slice away the
    singleton broadcast mode.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    stage = sVec.shape[1]
    if const_expr(is_colvec):
        shape = (sVec.shape[0], expand_shape, stage)
        stride = (1, 0, sVec.stride[1])
    else:
        shape = (expand_shape, sVec.shape[0], stage)
        stride = (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tA_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
    return tA_sVec[None, 0, None] if const_expr(is_colvec) else tA_sVec[0, None, None]
build/torch-cuda/quack/sm90_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Type, Union, Optional
4
+
5
+ import cutlass
6
+ import cutlass.cute as cute
7
+ import cutlass.utils.hopper_helpers as sm90_utils_og
8
+ from cutlass.cute.nvgpu import warpgroup
9
+ from cutlass.cutlass_dsl import Numeric, dsl_user_op
10
+ from cutlass import Float32, Int32, Boolean, const_expr
11
+ from cutlass.utils import LayoutEnum
12
+
13
+
14
@dsl_user_op
def make_smem_layout(
    dtype: Type[Numeric],
    layout: LayoutEnum,
    tile: cute.Tile,
    stage: Optional[int] = None,
    major_mode_size: Optional[int] = None,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """Build a (possibly multi-stage) swizzled smem layout for one tile on SM90.

    :param dtype: element type, used to pick the swizzle atom
    :param layout: LayoutEnum telling whether the tile is M- or N-major
    :param tile: the (M, N) tile to cover
    :param stage: if given, a stage mode of this size is appended for pipelining
    :param major_mode_size: override for the contiguous-mode size; defaults to the
        tile extent along the major mode
    :return: the tiled smem layout (swizzled layouts come back as ComposedLayout)
    """
    shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
    if const_expr(major_mode_size is None):
        # The swizzle atom is chosen from the size of the contiguous (major) mode.
        major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
    smem_layout_atom = warpgroup.make_smem_layout_atom(
        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
        dtype,
    )
    # Tiling order: iterate the major mode first so the atom repeats contiguously.
    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
    smem_layout_staged = cute.tile_to_shape(
        smem_layout_atom,
        cute.append(shape, stage) if const_expr(stage is not None) else shape,
        order=order if const_expr(stage is not None) else order[:2],
    )
    return smem_layout_staged
39
+
40
+
41
# Alias for source compatibility with blackwell_helpers.py, which exposes a
# separate epilogue-layout builder; on SM90 the same construction is used.
make_smem_layout_epi = make_smem_layout
43
+
44
+
45
@dsl_user_op
def partition_for_epilogue(
    cT: cute.Tensor,
    epi_tile: cute.Tile,
    tiled_copy: cute.TiledCopy,
    tidx: Int32,
    reference_src: bool,  # do register tensors reference the src or dst layout of the tiled copy
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Divide cT into epilogue tiles and partition the result for thread tidx.

    The returned tensor has shape (CPY, CPY_M, CPY_N, EPI_M, EPI_N); whether the
    source or destination side of tiled_copy is partitioned is selected by
    reference_src.
    """
    this_thread = tiled_copy.get_slice(tidx)
    cT_tiled = cute.flat_divide(cT, epi_tile)
    # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
    if const_expr(reference_src):
        partitioned = this_thread.partition_S(cT_tiled, loc=loc, ip=ip)
    else:
        partitioned = this_thread.partition_D(cT_tiled, loc=loc, ip=ip)
    return partitioned
63
+
64
+
65
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: cutlass.Constexpr[bool] = False,
    wg_wait: cutlass.Constexpr[int] = 0,
    # A_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Issue a warpgroup MMA over all K sub-tiles, accumulating into acc.

    :param acc: accumulator fragment, read-modified unless zero_init
    :param tCrA: A operand fragments, shape (..., ..., K_tiles)
    :param tCrB: B operand fragments, shape (..., ..., K_tiles)
    :param zero_init: if True, the first K iteration overwrites acc instead of accumulating
    :param wg_wait: argument to wait_group after commit; negative skips the wait
        so the caller can overlap and wait later
    :param swap_AB: swap the operand roles (implemented by recursing with A/B exchanged)
    """
    if const_expr(swap_AB):
        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
    else:
        warpgroup.fence()
        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
        # Otherwise the compiler complains "operand #0 does not dominate this use"
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # After the first K step the accumulator holds partial sums, so every
            # subsequent step must accumulate regardless of zero_init.
            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
        warpgroup.commit_group()
        if const_expr(wg_wait >= 0):
            warpgroup.wait_group(wg_wait)
90
+
91
+
92
def gemm_zero_init(
    tiled_mma: cute.TiledMma,
    shape: cute.Shape,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> cute.Tensor:
    """Allocate a fresh Float32 accumulator, run gemm with zero_init, and return it.

    A_idx / B_idx optionally select one pipeline stage (last mode) of the operand
    fragments; swap_AB exchanges the operand roles by recursing with a reversed
    output shape.
    """
    if const_expr(swap_AB):
        # Canonicalize: swapped call with operands, indices, and shape exchanged.
        return gemm_zero_init(
            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
        )
    acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
    if const_expr(A_idx is None):
        operand_a = tCrA
    else:
        operand_a = tCrA[None, None, None, A_idx]
    if const_expr(B_idx is None):
        operand_b = tCrB
    else:
        operand_b = tCrB[None, None, None, B_idx]
    gemm(tiled_mma, acc, operand_a, operand_b, zero_init=True, wg_wait=wg_wait)
    return acc
112
+
113
+
114
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: Boolean,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    wg_wait: int = -1,
    swap_AB: bool = False,
) -> None:
    """Run gemm into an existing accumulator, optionally selecting one pipeline
    stage (last mode) of each operand via A_idx / B_idx.

    swap_AB exchanges operand roles by recursing with operands and indices swapped.
    """
    if const_expr(swap_AB):
        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
        return
    if const_expr(A_idx is not None):
        operand_a = tCrA[None, None, None, A_idx]
    else:
        operand_a = tCrA
    if const_expr(B_idx is not None):
        operand_b = tCrB[None, None, None, B_idx]
    else:
        operand_b = tCrB
    gemm(tiled_mma, acc, operand_a, operand_b, zero_init=zero_init, wg_wait=wg_wait)
131
+
132
+
133
def partition_fragment_ABC(
    thr_mma: cute.ThrMma,
    shape_mnk: cute.Shape,
    sA: Optional[cute.Tensor],
    sB: Optional[cute.Tensor],
    swap_AB: bool = False,
):
    """Create the accumulator and operand fragments for a warpgroup MMA.

    :param thr_mma: per-thread MMA
    :param shape_mnk: (M, N, K) problem shape of the tile
    :param sA: smem tensor for A; may be None when the MMA sources A from registers
    :param sB: smem tensor for B; may be None when swap_AB and B comes from registers
    :param swap_AB: compute B@A instead of A@B — the accumulator is shaped (N, M)
        and A/B take each other's operand slots in the MMA
    :return: (acc, tCrA, tCrB)
    """
    # Whether the MMA atom reads its first operand from registers (RS) vs smem (SS).
    is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
    if const_expr(not swap_AB):
        acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
        if const_expr(not is_rs):
            assert sA is not None
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
        else:
            # RS MMA: A lives in registers; only its shape is needed here.
            tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
        assert sB is not None
        tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
    else:
        # Swapped: accumulator is (N, M); B fills the MMA's A slot and vice versa.
        acc = cute.make_rmem_tensor(
            thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
        )
        if const_expr(not is_rs):
            assert sB is not None
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
        else:  # B in rmem
            tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
        assert sA is not None
        tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
    return acc, tCrA, tCrB
build/torch-cuda/seqlen_info.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from dataclasses import dataclass
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+ from cutlass import Int32, const_expr
7
+
8
+ """
9
+ This consolidates all the info related to sequence length. This is so that we can do all
10
+ the gmem reads once at the beginning of each tile, rather than having to repeat these reads
11
+ to compute various things like n_block_min, n_block_max, etc.
12
+ """
13
+
14
+
15
@dataclass(frozen=True)
class SeqlenInfo:
    """Per-batch sequence offset and length, resolved once from gmem."""

    offset: cutlass.Int32
    seqlen: cutlass.Int32

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_static: cutlass.Int32,
        cu_seqlens: Optional[cute.Tensor] = None,
        seqused: Optional[cute.Tensor] = None,
    ):
        """Resolve offset/seqlen for one batch.

        Precedence for the length: seqused entry, then cu_seqlens difference,
        then the static length. The offset is nonzero only in varlen mode.
        """
        if const_expr(cu_seqlens is not None):
            start = cu_seqlens[batch_idx]
        else:
            start = 0
        if const_expr(seqused is not None):
            length = seqused[batch_idx]
        else:
            if const_expr(cu_seqlens is not None):
                length = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
            else:
                length = seqlen_static
        return SeqlenInfo(start, length)
35
+
36
+
37
@dataclass(frozen=True)
class SeqlenInfoQK:
    """Sequence-length info for both Q and K sides of attention.

    All gmem reads (cu_seqlens / seqused) happen once in create(); the
    has_* flags are compile-time constants recording which varlen inputs
    were provided.
    """

    offset_q: cutlass.Int32
    offset_k: cutlass.Int32
    # Offsets rounded so each batch starts on a tile boundary; used by kernels
    # that index padded/tile-aligned intermediate buffers.
    padded_offset_q: cutlass.Int32
    padded_offset_k: cutlass.Int32
    seqlen_q: cutlass.Int32
    seqlen_k: cutlass.Int32
    has_cu_seqlens_q: cutlass.Constexpr[bool]
    has_cu_seqlens_k: cutlass.Constexpr[bool]
    has_seqused_q: cutlass.Constexpr[bool]
    has_seqused_k: cutlass.Constexpr[bool]

    @staticmethod
    def create(
        batch_idx: cutlass.Int32,
        seqlen_q_static: cutlass.Int32,
        seqlen_k_static: cutlass.Int32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
        tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
    ):
        """Resolve all offsets/lengths for one batch.

        Length precedence per side: seqused entry, then cu_seqlens difference,
        then the static length.
        """
        offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
        offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
        # Round the varlen offset up to a tile boundary; adding batch_idx * tile
        # before dividing gives each batch its own tile-aligned slot.
        padded_offset_q = (
            0
            if const_expr(mCuSeqlensQ is None)
            else (offset_q + batch_idx * tile_m) // tile_m * tile_m
        )
        padded_offset_k = (
            0
            if const_expr(mCuSeqlensK is None)
            else (offset_k + batch_idx * tile_n) // tile_n * tile_n
        )
        if const_expr(mSeqUsedQ is not None):
            seqlen_q = mSeqUsedQ[batch_idx]
        else:
            seqlen_q = (
                seqlen_q_static
                if const_expr(mCuSeqlensQ is None)
                else mCuSeqlensQ[batch_idx + 1] - offset_q
            )
        if const_expr(mSeqUsedK is not None):
            seqlen_k = mSeqUsedK[batch_idx]
        else:
            seqlen_k = (
                seqlen_k_static
                if const_expr(mCuSeqlensK is None)
                else mCuSeqlensK[batch_idx + 1] - offset_k
            )
        has_cu_seqlens_q: bool = mCuSeqlensQ is not None
        has_cu_seqlens_k: bool = mCuSeqlensK is not None
        has_seqused_q: bool = mSeqUsedQ is not None
        has_seqused_k: bool = mSeqUsedK is not None
        return SeqlenInfoQK(
            offset_q,
            offset_k,
            padded_offset_q,
            padded_offset_k,
            seqlen_q,
            seqlen_k,
            has_cu_seqlens_q,
            has_cu_seqlens_k,
            has_seqused_q,
            has_seqused_k,
        )

    def offset_batch_Q(
        self,
        mQ: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """Slice (fixed-length) or offset (varlen) mQ to this batch.

        Seqlen must be the first dimension of mQ; in fixed-length mode, `dim`
        is the batch dimension to slice away.
        """
        if const_expr(not self.has_cu_seqlens_q):
            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
            return mQ[idx]
        else:
            offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
            # Mode 0 may itself be hierarchical; offset only its first component.
            offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
            idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
            return cute.domain_offset(idx, mQ)

    def offset_batch_K(
        self,
        mK: cute.Tensor,
        batch_idx: Int32,
        dim: int,
        padded: cutlass.Constexpr[bool] = False,
    ) -> cute.Tensor:
        """Slice (fixed-length) or offset (varlen) mK to this batch.

        Seqlen must be the first dimension of mK.
        """
        if const_expr(not self.has_cu_seqlens_k):
            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
            return mK[idx]
        else:
            offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
            idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
            return cute.domain_offset(idx, mK)
build/torch-cuda/softmax.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ import operator
5
+ from typing import Tuple
6
+ from dataclasses import dataclass
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+ from cutlass import Float32
11
+
12
+ from .quack import layout_utils
13
+ from . import utils
14
+ from .quack.cute_dsl_utils import ParamsBase
15
+ from .seqlen_info import SeqlenInfoQK
16
+
17
+
18
@dataclass
class Softmax(ParamsBase):
    """Online (streaming) softmax state for attention: running row max and row sum
    kept in registers, updated block-by-block in log2 space (scale_log2 is the
    softmax scale pre-multiplied by log2(e) — presumably; confirm against callers).
    """

    scale_log2: Float32
    num_rows: cutlass.Constexpr[int]
    row_max: cute.Tensor
    row_sum: cute.Tensor
    arch: cutlass.Constexpr[int] = 80
    softmax_scale: Float32 | None = None

    @staticmethod
    def create(
        scale_log2: Float32,
        num_rows: cutlass.Constexpr[int],
        arch: cutlass.Constexpr[int] = 80,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate rmem state (row_max, row_sum) and build a Softmax instance."""
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)

    def reset(self) -> None:
        """Re-initialize the running statistics (-inf max, zero sum)."""
        self.row_max.fill(-Float32.inf)
        self.row_sum.fill(0.0)

    def _compute_row_max(
        self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Max-reduce one row, optionally seeded with a previous running max."""
        return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)

    def _compute_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None
    ) -> Float32:
        """Sum-reduce one (already exponentiated) row, optionally seeded."""
        return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)

    @cute.jit
    def online_softmax(
        self,
        acc_S: cute.Tensor,
        is_first: cutlass.Constexpr[bool] = False,
        check_inf: cutlass.Constexpr[bool] = True,
    ) -> cute.Tensor:
        """Apply online softmax and return the row_scale to rescale O.

        acc_S is overwritten in place with exp2(scale * S - scale * rowmax);
        row_max/row_sum are updated. The returned row_scale is the factor by
        which the caller must rescale the partial O accumulator.

        :param acc_S: acc_S tensor
        :type acc_S: cute.Tensor
        :param is_first: is first n_block
        :type is_first: cutlass.Constexpr
        :param check_inf: replace a -inf row max with 0 to avoid NaNs from inf - inf
        """
        # Change acc_S to M,N layout view.
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
        row_scale = cute.make_fragment_like(self.row_max, Float32)

        row_max = self.row_max
        row_sum = self.row_sum
        scale_log2 = self.scale_log2
        arch = self.arch

        # Each iteration processes one row of acc_S
        for r in cutlass.range(cute.size(row_max), unroll_full=True):
            acc_S_row = acc_S_mn[r, None].load()  # (n_block_size)

            row_max_cur = utils.fmax_reduce(
                acc_S_row,
                init_val=row_max[r] if cutlass.const_expr(not is_first) else None,
                arch=arch,
            )

            # Reduce across the 4 threads that share a row of the MMA accumulator.
            row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4)
            # Update row_max before changing row_max_cur to safe value for -inf
            row_max_prev = row_max[r]
            row_max[r] = row_max_cur

            if cutlass.const_expr(check_inf):
                row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur

            if cutlass.const_expr(is_first):
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = cute.math.exp2(
                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
                )
                acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
                # First block: nothing accumulated yet, so O needs no rescale.
                row_scale[r] = 1.0
            else:
                row_max_cur_scaled = row_max_cur * scale_log2
                acc_S_row_exp = cute.math.exp2(
                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
                )
                # row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled)
                row_scale[r] = cute.math.exp2(
                    (row_max_prev - row_max_cur) * scale_log2, fastmath=True
                )
                # Fold the previous running sum (rescaled to the new max) into this block's sum.
                acc_S_row_sum = utils.fadd_reduce(
                    acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
                )

            row_sum[r] = acc_S_row_sum
            acc_S_mn[r, None].store(acc_S_row_exp)

        return row_scale

    @cute.jit
    def finalize(
        self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None
    ) -> cute.Tensor:
        """Finalize the online softmax by computing the scale and logsumexp.

        Returns row_scale = final_scale / row_sum (per row); row_sum is
        overwritten in place with the logsumexp (natural log). An optional
        attention-sink logit sink_val is folded into the sum first.
        """
        if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):
            assert cute.size(sink_val) == cute.size(self.row_sum)
        row_sum = self.row_sum
        row_max = self.row_max
        scale_log2 = self.scale_log2

        # quad reduction for row_sum as we didn't do it during each iteration of online softmax
        row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))
        row_scale = cute.make_fragment_like(row_max, Float32)

        for r in cutlass.range(cute.size(row_sum), unroll_full=True):
            if cutlass.const_expr(sink_val is not None):
                sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
                LOG2_E = math.log2(math.e)
                # Note: the sink logit is NOT multiplied by the softmax scale, only by log2(e).
                row_sum[r] += cute.math.exp2(
                    sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
                )

            # if row_sum is zero or nan, set acc_O_mn_row to 1.0
            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
            row_scale[r] = (
                cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)
            ) * final_scale
            row_sum_cur = row_sum[r]
            LN2 = math.log(2.0)
            # logsumexp in natural log: (max * scale_log2 + log2(sum)) * ln(2).
            row_sum[r] = (
                (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2
                if not acc_O_mn_row_is_zero_or_nan
                else -Float32.inf
            )
        return row_scale

    @cute.jit
    def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:
        """Scale each row of acc_O by the given scale tensor.
        :param acc_O: input tensor
        :type acc_O: cute.Tensor
        :param row_scale: row_scale tensor
        :type row_scale: cute.Tensor
        """
        acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
        assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
        for r in cutlass.range(cute.size(row_scale), unroll_full=True):
            acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
167
+
168
+
169
@dataclass
class SoftmaxSm100(Softmax):
    """SM100 (Blackwell) variant of the online softmax: single-row state,
    optional skip of O-rescaling when the max barely moved (rescale_threshold),
    and packed-f32x2 FMA / exp2-emulation paths for the scale-and-exp step.
    """

    # If > 0, skip rescaling when (old_max - new_max) * scale_log2 stays within
    # this threshold, keeping the previous max instead.
    rescale_threshold: cutlass.Constexpr[float] = 0.0

    @staticmethod
    def create(
        scale_log2: Float32,
        rescale_threshold: cutlass.Constexpr[float] = 0.0,
        softmax_scale: Float32 | None = None,
    ):
        """Allocate single-row rmem state and build a SoftmaxSm100 (arch fixed at 100)."""
        num_rows = 1
        arch = 100
        row_max = cute.make_rmem_tensor(num_rows, Float32)
        row_sum = cute.make_rmem_tensor(num_rows, Float32)
        return SoftmaxSm100(
            scale_log2,
            num_rows,
            row_max,
            row_sum,
            arch,
            softmax_scale,
            rescale_threshold=rescale_threshold,
        )

    @cute.jit
    def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:
        """Update the running max from one row of scores.

        Returns (row_max_safe, acc_scale): the -inf-sanitized new max, and the
        factor to rescale the O accumulator (0.0 on the first block, 1.0 when
        the rescale is skipped under rescale_threshold).
        """
        if cutlass.const_expr(is_first):
            row_max_new = self._compute_row_max(acc_S_row)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale = 0.0
        else:
            row_max_old = self.row_max[0]
            row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
            acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
            acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
            if cutlass.const_expr(self.rescale_threshold > 0.0):
                # Max moved by less than the threshold: keep the old max and skip
                # the (expensive) rescale of the accumulator.
                if acc_scale_ >= -self.rescale_threshold:
                    row_max_new = row_max_old
                    row_max_safe = row_max_old
                    acc_scale = 1.0
        self.row_max[0] = row_max_new
        return row_max_safe, acc_scale

    def update_row_sum(
        self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False
    ) -> None:
        """Fold one exponentiated row into the running sum, rescaling the old sum
        by row_scale (skipped on the first block)."""
        init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None
        # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale)
        self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)
        # tmp = self._compute_row_sum(acc_S_row_exp)
        # self.row_sum[0] = self.row_sum[0] * row_scale + tmp

    @cute.jit
    def scale_subtract_rowmax(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
    ):
        """In place: acc_S_row <- acc_S_row * scale_log2 - row_max * scale_log2,
        two elements at a time via packed f32x2 FMA."""
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        row_max_scaled = row_max * self.scale_log2
        for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):
            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (-row_max_scaled, -row_max_scaled),
            )

    @cute.jit
    def apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        acc_S_row_converted: cute.Tensor,
        ex2_emu_freq: cutlass.Constexpr[int] = 0,
        ex2_emu_res: cutlass.Constexpr[int] = 4,
        ex2_emu_start_frg: cutlass.Constexpr[int] = 0,
    ):
        """Exponentiate acc_S_row (exp2) in 32-element fragments and store the
        converted result into acc_S_row_converted.

        When ex2_emu_freq > 0, a subset of element pairs (ex2_emu_res out of every
        ex2_emu_freq, skipping the last fragment and fragments before
        ex2_emu_start_frg) uses utils.ex2_emulation_2 instead of the hardware
        exp2 — presumably to spread work off the SFU; confirm with the kernel
        tuning notes.
        """
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                if cutlass.const_expr(ex2_emu_freq == 0):
                    acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                    acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                else:
                    if cutlass.const_expr(
                        k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res
                        or j >= frg_cnt - 1
                        or j < ex2_emu_start_frg
                    ):
                        acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                        acc_S_row_frg[k + 1, j] = cute.math.exp2(
                            acc_S_row_frg[k + 1, j], fastmath=True
                        )
                    else:
                        # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
                        acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
                            acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]
                        )
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )

    @cute.jit
    def scale_apply_exp2_convert(
        self,
        acc_S_row: cute.Tensor,
        row_max: Float32,
        acc_S_row_converted: cute.Tensor,
    ):
        """Fused scale-subtract-rowmax (packed FMA) followed by exp2 and
        conversion/store into acc_S_row_converted, fragment by fragment."""
        assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
        minus_row_max_scaled = -row_max * self.scale_log2
        for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
                (acc_S_row[i], acc_S_row[i + 1]),
                (self.scale_log2, self.scale_log2),
                (minus_row_max_scaled, minus_row_max_scaled),
            )

        # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):
        #     acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(
        #         (acc_S_row[i], acc_S_row[i + 1]),
        #         (self.scale_log2, self.scale_log2),
        #         (minus_row_max_scaled, minus_row_max_scaled),
        #     )
        #     acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True)
        #     acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True)

        frg_tile = 32
        assert frg_tile % 2 == 0
        frg_cnt = cute.size(acc_S_row) // frg_tile
        assert cute.size(acc_S_row) % frg_tile == 0
        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))
        acc_S_row_converted_frg = cute.logical_divide(
            acc_S_row_converted, cute.make_layout(frg_tile)
        )
        for j in cutlass.range_constexpr(frg_cnt):
            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
                # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (
                #     cute.arch.fma_packed_f32x2(
                #         (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),
                #         (self.scale_log2, self.scale_log2),
                #         (minus_row_max_scaled, minus_row_max_scaled),
                #     )
                # )
                # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
                acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
                acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
            acc_S_row_converted_frg[None, j].store(
                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
            )
330
+
331
+
332
@cute.jit
def floor_if_packed(
    q_idx,
    qhead_per_kvhead: cutlass.Constexpr[int],
):
    """Recover the logical query index from a Pack-GQA packed q_idx.

    With Pack-GQA, qhead_per_kvhead query heads are interleaved along the
    query dimension, so the logical query position is q_idx // qhead_per_kvhead.
    When qhead_per_kvhead == 1 the index is already logical and is returned
    unchanged.

    Note: the original return annotation (`-> cute.Tensor`) was wrong — callers
    store the result into Int32 register vectors, so this returns an index of
    the same type as q_idx, not a tensor.
    """
    if cutlass.const_expr(qhead_per_kvhead == 1):
        return q_idx
    return q_idx // qhead_per_kvhead
341
+
342
+
343
@cute.jit
def apply_score_mod_inner(
    score_tensor,
    index_tensor,
    score_mod: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info: SeqlenInfoQK,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Shared implementation for applying score modification.

    Scores are processed in vec_size chunks: each chunk is scaled by
    softmax_scale, the per-element (q_idx, kv_idx) coordinates are resolved
    (unpacking Pack-GQA and wrapping via fast-divmod when aux tensors will be
    indexed), score_mod is called on the SSA vectors, and the modified scores
    are written back into score_tensor in place.

    Args:
        score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100)
        index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100)
        score_mod: The score modification function to apply
        batch_idx: Batch index
        head_idx: Head index
        softmax_scale: Scale to apply
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values
            If None, compute q_idx per-element
        qhead_per_kvhead: Pack-GQA replication factor. Divide q_idx by this
            when greater than 1 so score mods see logical heads.
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed)
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)

    n_vals = cutlass.const_expr(cute.size(score_tensor.shape))
    score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # SSA values for batch (constant across all elements)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))

    # Handle q_idx based on whether it's constant
    q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    # since a thread may process multiple query head indices
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            score_vec[j] = score_tensor[i + j] * softmax_scale

            # Extract head offset from packed q_idx for Pack-GQA
            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            # If we will do loads we mod, in order to not read OOB
            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing
                # NOTE(review): this branch tests constant_q_idx without
                # cutlass.const_expr, unlike every other use in this function.
                # constant_q_idx is a Constexpr so the plain Python `if` should
                # trace identically — confirm this is intentional.
                if constant_q_idx is None:
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        # Convert to SSA for score_mod call
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()
        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical
            q_idx_const = constant_q_idx
            q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,))

        # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise
        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        post_mod_scores = score_mod(
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write back modified scores
        score_vec.store(post_mod_scores)
        for j in cutlass.range(vec_size, unroll_full=True):
            score_tensor[i + j] = score_vec[j]
471
+
472
+
473
@cute.jit
def apply_score_mod_bwd_inner(
    grad_tensor,
    score_tensor,
    index_tensor,
    score_mod_bwd: cutlass.Constexpr,
    batch_idx,
    head_idx,
    softmax_scale,
    vec_size: cutlass.Constexpr,
    qk_acc_dtype: cutlass.Constexpr,
    aux_tensors,
    fastdiv_mods,
    seqlen_info,
    constant_q_idx: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    transpose_indices: cutlass.Constexpr[bool] = False,
):
    """Apply backward score modification (joint graph).

    Args:
        grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores)
        score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally
        index_tensor: Index positions (same as forward)
        score_mod_bwd: The backward score modification function (joint graph)
        batch_idx: Batch index
        head_idx: Head index
        softmax_scale: Scale to apply to score_tensor
        vec_size: Vector size for processing elements
        qk_acc_dtype: Data type for accumulator
        aux_tensors: Optional aux_tensors for FlexAttention
        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
        seqlen_info: Sequence length info
        constant_q_idx: If provided, use this constant for all q_idx values
        qhead_per_kvhead: Pack-GQA replication factor
        transpose_indices: If True, swap q_idx/kv_idx in index_tensor
    """
    # Index positions in the index_tensor tuple
    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx
    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx
    if cutlass.const_expr(transpose_indices):
        q_idx_pos = cutlass.const_expr(1)
        kv_idx_pos = cutlass.const_expr(0)
    else:
        q_idx_pos = cutlass.const_expr(0)
        kv_idx_pos = cutlass.const_expr(1)
    n_vals = cutlass.const_expr(cute.size(grad_tensor.shape))
    grad_vec = cute.make_fragment(vec_size, qk_acc_dtype)
    score_vec = cute.make_fragment(vec_size, qk_acc_dtype)
    kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))
    q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)

    # For Pack-GQA with non-constant q_idx, we need per-element head indices
    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
        head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)

    # Process the tile in vec_size-wide chunks so score_mod_bwd sees SSA vectors.
    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_vec[j] = grad_tensor[i + j]
            # Scale score so joint graph sees same value as forward score_mod
            score_vec[j] = score_tensor[i + j] * softmax_scale

            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
                # Packed q index interleaves the query-head offset; recover
                # the logical q index and the head offset without a modulo.
                q_idx_packed = index_tensor[i + j][q_idx_pos]
                q_idx_logical = q_idx_packed // qhead_per_kvhead
                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead
                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset

            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):
                # Aux tensors are indexed by wrapped positions: use fast
                # divmod to keep q/kv indices within [0, seqlen).
                if cutlass.const_expr(constant_q_idx is None):
                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                    q_idx_floored = floor_if_packed(
                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead
                    )
                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)
                    q_idx_vec[j] = q_idx_wrapped
                else:
                    _, seqlen_k_divmod = fastdiv_mods

                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                # No bounds checking - direct indexing
                # NOTE: wrapped in cutlass.const_expr for consistency with every
                # other compile-time branch on constant_q_idx in this function.
                if cutlass.const_expr(constant_q_idx is None):
                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)
                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]

        grad_ssa = grad_vec.load()
        score_ssa = score_vec.load()
        kv_idx_ssa = kv_idx_vec.load()

        # q_idx is either a per-element vector or a broadcast constant.
        if cutlass.const_expr(constant_q_idx is None):
            q_idx_ssa = q_idx_vec.load()
        else:
            q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,))

        # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx,
        # constant otherwise.
        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
            head_idx_ssa = head_idx_vec.load()
        else:
            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))

        aux_args = []
        if cutlass.const_expr(aux_tensors is not None):
            aux_args = aux_tensors

        grad_out_ssa = score_mod_bwd(
            grad_ssa,
            score_ssa,
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_ssa,
            kv_idx=kv_idx_ssa,
            seqlen_info=seqlen_info,
            aux_tensors=aux_args,
        )

        # Write the modified gradients back in-place.
        grad_vec.store(grad_out_ssa)
        for j in cutlass.range(vec_size, unroll_full=True):
            grad_tensor[i + j] = grad_vec[j]
build/torch-cuda/testing.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from contextlib import nullcontext
3
+ from functools import wraps
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch._guards import active_fake_mode
10
+ from torch._subclasses.fake_tensor import FakeTensorMode
11
+
12
+
13
class IndexFirstAxis(torch.autograd.Function):
    """Gather rows along the first axis with a custom autograd backward.

    Forward is equivalent to ``input[indices]`` for a 1-D ``indices`` tensor;
    backward scatters the incoming gradient into a zero tensor with the
    original first-axis size (rows not selected get zero gradient).

    Uses native torch ``reshape`` / ``unsqueeze().expand()`` (zero-copy views)
    instead of einops ``rearrange`` / ``repeat``, removing an avoidable
    third-party dependency from this helper.
    """

    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # Flatten trailing dims so gather can work on a 2-D view, and
        # broadcast the 1-D indices across the flattened dim.
        flat_input = input.reshape(ctx.first_axis_dim, second_dim)
        gather_idx = indices.unsqueeze(-1).expand(indices.shape[0], second_dim)
        return torch.gather(flat_input, 0, gather_idx).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        flat_grad = grad_output.reshape(grad_output.shape[0], -1)
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_grad.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # Scatter gathered-row gradients back to their original rows.
        scatter_idx = indices.unsqueeze(-1).expand(indices.shape[0], flat_grad.shape[1])
        grad_input.scatter_(0, scatter_idx, flat_grad)
        # No gradient w.r.t. the integer indices.
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
class IndexPutFirstAxis(torch.autograd.Function):
    """Inverse of first-axis gathering: write ``values`` into the rows given
    by ``indices`` of a fresh zero tensor with ``first_axis_dim`` rows.
    Backward gathers the written rows' gradients back out."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        result = torch.zeros(
            first_axis_dim,
            *values.shape[1:],
            device=values.device,
            dtype=values.dtype,
        )
        result[indices] = values
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # Only the rows written in forward carry gradient; the index and
        # size arguments get none.
        return grad_output[indices], None, None


index_put_first_axis = IndexPutFirstAxis.apply
64
+
65
+
66
def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Strip padding tokens from a (batch, seqlen, ...) tensor.

    Returns a tuple of (unpadded hidden_states, flat kept-token indices,
    cumulative sequence lengths, max sequence length in batch, per-batch
    used sequence lengths). ``unused_mask`` marks extra positions to keep
    without counting them as used.
    """
    if unused_mask is not None:
        all_masks = attention_mask + unused_mask
    else:
        all_masks = attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    if active_fake_mode() is not None:
        # torch.nonzero and .item() are not supported in FakeTensorMode,
        # so pretend every position is kept.
        batch_size, seqlen = attention_mask.shape
        indices = torch.arange(batch_size * seqlen, device=hidden_states.device)
        max_seqlen_in_batch = seqlen
    else:
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    flat_states = rearrange(hidden_states, "b s ... -> (b s) ...")
    return (
        index_first_axis(flat_states, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
+ )
87
+
88
+
89
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter unpadded rows back into a zero-padded (batch, seqlen, ...) tensor."""
    flat = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(flat, "(b s) ... -> b s ...", b=batch)
92
+
93
+
94
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    """Build a (batch_size, max_seqlen) boolean padding mask.

    mode="full": every row fully valid; "random": lengths drawn from
    [max_seqlen - 20, max_seqlen]; "third": lengths drawn from
    [max_seqlen // 3, max_seqlen]. With zero_lengths, lengths may be 0 and
    every 5th row plus the last row are forced to length 0.
    """
    assert mode in ["full", "random", "third"]
    min_len = 0 if zero_lengths else 1
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    else:
        lower = max_seqlen - 20 if mode == "random" else max_seqlen // 3
        lengths = torch.randint(
            max(min_len, lower),
            max_seqlen + 1,
            (batch_size, 1),
            device=device,
        )

    if zero_lengths:
        # Force some rows to be completely empty.
        lengths[::5] = 0
        lengths[-1] = 0
    positions = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size)
    return positions < lengths
122
+
123
+
124
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """Build unpadded (varlen) views of q/k/v plus the metadata needed by
    varlen attention kernels, together with re-padding closures for outputs
    and gradients.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask / key_padding_mask: optional (batch, seqlen) bool masks
        qv: optional extra query tensor, unpadded alongside q
        kvpacked / qkvpacked: mutually exclusive packing modes that change the
            return tuple shape (see the three return branches below)
        query_unused_mask / key_unused_mask: forwarded to unpad_input; not
            allowed together with the packed modes

    Returns one of three tuples depending on packing mode; all returned
    tensors are detached and (where trainable) have requires_grad re-enabled.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    # --- Query side: unpad if a mask is given, otherwise just flatten and
    # synthesize uniform cu_seqlens. output_pad_fn re-pads kernel outputs.
    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    # --- Key/value side: same treatment, both driven by key_padding_mask.
    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Single packed tensor (b s) 3 h d; requires identical q/k masks
        # and matching head counts.
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        # k and v packed together as (b s) 2 h d; q stays separate.
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        # Fully separate q/k/v (plus optional qv) with all varlen metadata.
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
+ )
247
+
248
+
249
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(None, None),
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Build a boolean mask for sliding-window attention.

    True marks positions that are masked OUT (callers fill them with -inf).
    window_size is (left, right) with None meaning unbounded on that side.
    The window is aligned to the bottom-right diagonal via the per-batch
    offset sk - sq. Columns below sink_token_length are never left-masked.
    """
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Shift columns left by the per-batch left padding; padded columns
        # get a huge index (2**32) so they compare as out-of-window.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    # Effective per-batch key/query lengths (scalars when no padding mask).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] is None:
        # Unbounded left window: only the right boundary masks anything.
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        if window_size[1] is None:
            local_mask_left = col_idx > sk
        else:
            local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk)
        # Mask right-of-window OR (left-of-window AND not a sink token).
        return torch.logical_or(
            local_mask_left,
            torch.logical_and(
                col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
            ),
        )
289
+
290
+
291
def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Build a boolean mask for chunked attention.

    True marks positions masked OUT: each (diagonal-aligned) query row may
    only attend to keys inside the attention_chunk-sized chunk that contains
    its own diagonal position.
    """
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        # Shift columns by per-batch left padding; padded columns get a huge
        # index (2**32) so they land outside every chunk.
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    # Effective per-batch key/query lengths (scalars when no padding mask).
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Left edge of the chunk containing the diagonal position row + sk - sq.
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    return torch.logical_or(
        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
    )
321
+
322
+
323
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(None, None),
    attention_chunk=0,
    sink_token_length=0,
    learnable_sink: Optional[torch.Tensor] = None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """Pure-PyTorch reference attention used to validate the fused kernels.

    Supports padding masks, sliding windows, chunked attention, GQA/MQA
    (k/v heads are repeated to match q heads), per-head descales, softcap,
    learnable sink logits, and an externally supplied dropout mask.

    Returns:
        (output, attention): output is (batch, seqlen_q, nheads, d_v) and
        attention is the (pre-dropout) probability matrix, both cast back
        to q's original dtype.
    """
    if causal:
        # Causal is just a sliding window with right width 0.
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        # Compute the reference in fp32 for accuracy comparisons.
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # GQA/MQA: replicate k/v heads to one per query head.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    # reorder_ops moves the scale onto k to mimic a different op order
    # (changes rounding, not the mathematical result).
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        # Smoothly cap logits to (-softcap, softcap).
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    local_mask = None
    if window_size[0] is not None or window_size[1] is not None:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        # A position is masked if either the window or the chunk masks it.
        local_mask = (
            torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
        )
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if learnable_sink is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        # Softmax with an extra per-head "sink" logit that absorbs
        # probability mass but contributes nothing to the output.
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
            learnable_sink - logits_or_sinks_max
        )
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # Zero out probabilities for padded queries/keys and for rows whose
    # every key is masked (softmax over all -inf would yield NaNs).
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        # Round-trip through a lower-precision dtype to mimic kernels that
        # store P in reduced precision.
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
435
+
436
+
437
def maybe_fake_tensor_mode(fake: bool = True):
    """Decorator factory that optionally runs the wrapped function under
    torch's FakeTensorMode.

    Fake tensors keep shape/dtype metadata without allocating real GPU
    memory, which is one way to populate/pre-compile the cute.compile cache.
    With ``fake=False`` the function runs unchanged (nullcontext).
    """

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            ctx = FakeTensorMode() if fake else nullcontext()
            with ctx:
                return fn(*args, **kwargs)

        return wrapper

    return decorator
453
+
454
+
455
+ def is_fake_mode() -> bool:
456
+ return active_fake_mode() is not None
build/torch-cuda/tile_scheduler.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Optional, Tuple
4
+ from dataclasses import dataclass
5
+
6
+ try:
7
+ from typing import override
8
+ except ImportError: # Python < 3.12
9
+ from typing_extensions import override
10
+
11
+ import cutlass
12
+ from cutlass._mlir import ir
13
+ import cutlass.cute as cute
14
+ from cutlass import Int32, const_expr
15
+ from cutlass.cute import FastDivmodDivisor
16
+
17
+ from .quack.cute_dsl_utils import ParamsBase
18
+
19
+ from . import utils
20
+ from .fast_math import clz
21
+
22
+
23
class WorkTileInfo(cutlass.utils.WorkTileInfo):
    """WorkTileInfo variant whose tile index carries four axes:
    (block, head, batch, split)."""

    @override
    def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
        # Four tile-coordinate values plus one validity flag.
        assert len(values) == 5
        coord_vals, valid_vals = values[:-1], values[-1:]
        tile_idx = cutlass.new_from_mlir_values(self._tile_idx, coord_vals)
        is_valid = cutlass.new_from_mlir_values(self._is_valid_tile, valid_vals)
        return WorkTileInfo(tile_idx, is_valid)
32
+
33
+
34
@dataclass
class TileSchedulerArguments(ParamsBase):
    """Host-side argument bundle consumed by the tile schedulers'
    `to_underlying_arguments` / `Params.create` factories."""

    num_block: Int32  # number of tiles along the block (first grid) axis
    num_head: Int32
    num_batch: Int32
    num_splits: Int32  # split-KV factor (1 = no splitting)
    seqlen_k: Int32
    headdim: Int32
    headdim_v: Int32
    total_q: Int32
    tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
    mCuSeqlensQ: Optional[cute.Tensor] = None  # cumulative query seqlens (varlen) — assumed; confirm at call site
    mSeqUsedQ: Optional[cute.Tensor] = None  # per-batch used query lengths — assumed; confirm at call site
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # Pack-GQA replication factor
    element_size: cutlass.Constexpr[int] = 2  # bytes per K/V element; used to size the L2 working set in the LPT scheduler
    is_persistent: cutlass.Constexpr[bool] = False
    lpt: cutlass.Constexpr[bool] = False  # presumably "longest processing time" scheduling — verify
    is_split_kv: cutlass.Constexpr[bool] = False
    head_swizzle: cutlass.Constexpr[bool] = False
54
+
55
+
56
class SingleTileScheduler:
    """Non-persistent scheduler: each CTA is launched for exactly one
    (block, head, batch, split) work tile and exits after processing it.
    `advance_to_next_work` simply invalidates the CTA's single tile."""

    @dataclass
    class Params(ParamsBase):
        # Device-side copy of the scheduling parameters.
        num_block: Int32
        num_head: Int32
        num_batch: Int32
        num_splits: Int32
        num_splits_divmod: FastDivmodDivisor  # fast divmod by num_splits
        is_split_kv: cutlass.Constexpr[bool] = False
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileScheduler.Params":
            return SingleTileScheduler.Params(
                args.num_block,
                args.num_head,
                args.num_batch,
                args.num_splits,
                FastDivmodDivisor(args.num_splits),
                args.is_split_kv,
                args.cluster_shape_mn,
            )

    def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
        self.params = params
        self._blk_coord = blk_coord
        self._is_first_block = True  # a CTA owns exactly one tile; valid only once
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
        # if const_expr(cute.size(params.cluster_shape_mn) == 1):
        #     blk_coord = cute.arch.block_idx()
        # else:
        #     # All CTAs in a cluster must get the same block coordinate
        #     blk_coord = cute.arch.cluster_idx()
        # Temporary set to block_idx until we sort out the best way to handle cluster
        blk_coord = cute.arch.block_idx()
        return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        # Grid: (blocks rounded to cluster width, head * split, batch).
        # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
        assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
        return (
            cute.round_up(params.num_block, params.cluster_shape_mn[0]),
            params.num_head * params.num_splits,
            params.num_batch,
        )

    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        # The y grid axis packs (head, split) together; unpack when splitting.
        block_idx, head_idx, batch_idx = self._blk_coord
        if const_expr(self.params.is_split_kv):
            head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)
        else:
            split_idx = Int32(0)
        return WorkTileInfo(
            (block_idx, head_idx, batch_idx, split_idx),
            self._is_first_block,
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch: there is only one tile per CTA.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # DSL serialization hook; order must match __new_from_mlir_values__.
        values, self._values_pos = [], []
        for obj in [self.params, self._blk_coord]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)
153
+
154
+
155
class StaticPersistentTileScheduler:
    """Persistent scheduler: roughly one CTA (or cluster) per SM loops over a
    statically flattened (block, head, batch) tile space, striding by the
    grid size each iteration."""

    @dataclass
    class Params(ParamsBase):
        num_block_cluster_divmod: FastDivmodDivisor  # fast divmod by blocks-per-cluster
        num_head_divmod: FastDivmodDivisor
        total_blocks_cluster: Int32  # total number of cluster-tiles to process
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "StaticPersistentTileScheduler.Params":
            num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn))
            total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch
            return StaticPersistentTileScheduler.Params(
                FastDivmodDivisor(num_block_cluster),
                FastDivmodDivisor(args.num_head),
                total_blocks_cluster,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx  # current position in the flattened tile space
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
        # With clusters, all CTAs of a cluster share one tile index.
        if const_expr(cute.size(params.cluster_shape_m) == 1):
            tile_idx = cute.arch.block_idx()[0]
        else:
            tile_idx = cute.arch.cluster_idx()[0]
        return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        hardware_info = cutlass.utils.HardwareInfo()
        sm_count = hardware_info.get_device_multiprocessor_count()
        # Grid must be a multiple of cluster_shape_m for CUDA cluster launch.
        max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
        grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
        return (grid_x, Int32(1), Int32(1))

    # @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        # Unflatten tile_idx into (block, head, batch); split axis is always 0.
        hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
        batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
        is_valid = self._tile_idx < self.params.total_blocks_cluster
        # if cute.arch.thread_idx()[0] == 0:
        #     cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
        return WorkTileInfo(
            (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Persistent stride: advance by the number of concurrent CTAs/clusters.
        if const_expr(self.params.cluster_shape_m == 1):
            self._tile_idx += cute.arch.grid_dim()[0]
        else:
            self._tile_idx += cute.arch.cluster_dim()[0]

    def __extract_mlir_values__(self):
        # DSL serialization hook; order must match __new_from_mlir_values__.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)
249
+
250
+
251
class SingleTileLPTScheduler:
    """One-tile-per-CTA scheduler with longest-processing-time-first block order
    and an L2-friendly head swizzle.

    Blocks within a swizzle "section" share the same group of heads so their
    K/V reads hit L2; within a head, blocks are visited in reverse (LPT) order.
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_splits: Int32
        num_block: Int32
        l2_minor: Int32  # swizzle width: number of heads grouped per L2 section
        num_block_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor
        l2_major_divmod: FastDivmodDivisor
        l2_minor_residual_divmod: FastDivmodDivisor
        num_hb_quotient: Int32
        is_split_kv: cutlass.Constexpr[bool] = False

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTScheduler.Params":
            # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
            size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            size_one_head = size_one_kv_head
            size_l2 = 50 * 1024 * 1024  # 50 MB budget for K & V
            # Swizzle is the size of each "section". Round swizzle to a power of 2
            # Need to be careful about the case where only one head will fit
            # swizzle is how many heads can fit in L2
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # Seems faster if swizzle is a power of 2
            log2_floor = lambda n: 31 - clz(n)
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            return SingleTileLPTScheduler.Params(
                total_blocks=args.num_block * args.num_head * args.num_batch,
                num_block=args.num_block,
                l2_minor=Int32(swizzle),
                num_block_divmod=FastDivmodDivisor(args.num_block),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                num_splits=args.num_splits,
                is_split_kv=args.is_split_kv,
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
        # One tile per CTA: blockIdx.x is the tile, blockIdx.y the KV split.
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, params.num_splits, Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Map blockIdx.x through the L2 swizzle to a (block, head, batch, split) tile."""
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        # Longest-processing-time-first
        block = params.num_block - 1 - block
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo(
            (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch for a single-tile scheduler.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params and the two mutable counters into one MLIR value list,
        # remembering per-object counts for reconstruction.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Inverse of __extract_mlir_values__: consume values in the same order.
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
375
+
376
+
377
class SingleTileLPTBwdScheduler:
    """Backward-pass variant of the LPT scheduler.

    Like SingleTileLPTScheduler but cluster-aware along the block dimension:
    blocks are scheduled per cluster and then fanned out to the CTAs within
    the cluster. The reverse (LPT) block order is gated on ``spt``.
    """

    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_block: Int32  # number of block-clusters (ceil-divided by cluster width)
        l2_minor: Int32
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor
        l2_major_divmod: FastDivmodDivisor
        l2_minor_residual_divmod: FastDivmodDivisor
        num_hb_quotient: Int32
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
        spt: cutlass.Constexpr[bool] = True

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTBwdScheduler.Params":
            size_l2 = 50 * 1024 * 1024
            # Working set per head for the backward pass (Q and dO reads).
            size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
            size_one_dqaccum_head = 0
            size_one_head = size_one_qdo_head + size_one_dqaccum_head
            log2_floor = lambda n: 31 - clz(n)
            # How many heads fit in L2, rounded down to a power of 2.
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 8
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
            return SingleTileLPTBwdScheduler.Params(
                total_blocks=(num_block * args.cluster_shape_mn[0])
                * args.num_head
                * args.num_batch,
                num_block=num_block,
                l2_minor=Int32(swizzle),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                cluster_shape_mn=args.cluster_shape_mn,
                spt=args.lpt,
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler":
        tile_idx = cute.arch.block_idx()[0]
        return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, Int32(1), Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
        """Decode blockIdx.x (per-cluster) through the L2 swizzle into a work tile."""
        cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        if cutlass.const_expr(params.spt):
            # Longest-processing-time-first: visit blocks in reverse order.
            block = params.num_block - 1 - block
        if cutlass.const_expr(params.cluster_shape_mn[0] > 1):
            # Fan the cluster-level block index out to this CTA's slot.
            bidx_in_cluster = cute.arch.block_in_cluster_idx()
            block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch for a single-tile scheduler.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        # Flatten params and the tile counter, recording per-object counts.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Inverse of __extract_mlir_values__: consume values in the same order.
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)
499
+
500
+
501
class SingleTileVarlenScheduler:
    """One-tile-per-CTA scheduler for variable-length (varlen) batches.

    Per-batch sequence lengths come from mSeqUsedQ or mCuSeqlensQ, so the
    number of m-blocks differs per batch. A warp cooperatively scans 31
    batches at a time to locate which batch/head/block a given blockIdx.x
    corresponds to. Optional LPT/head-swizzle ordering improves K/V reuse
    in L2.
    """

    @dataclass
    class Params(ParamsBase):
        num_head: Int32
        num_batch: Int32
        total_q: Int32
        num_splits: Int32
        max_kvblock_in_l2: Int32  # how many KV tiles fit in the L2 budget
        tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
        mCuSeqlensQ: Optional[cute.Tensor] = None
        mSeqUsedQ: Optional[cute.Tensor] = None
        qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
        lpt: cutlass.Constexpr[bool] = False
        is_split_kv: cutlass.Constexpr[bool] = False
        head_swizzle: cutlass.Constexpr[bool] = False
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileVarlenScheduler.Params":
            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
            max_kvblock_in_l2 = size_l2 // (
                (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
            )
            assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
                "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
            )
            assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
            return SingleTileVarlenScheduler.Params(
                num_head=args.num_head,
                num_batch=args.num_batch,
                total_q=args.total_q,
                num_splits=args.num_splits,
                max_kvblock_in_l2=max_kvblock_in_l2,
                tile_shape_mn=args.tile_shape_mn,
                mCuSeqlensQ=args.mCuSeqlensQ,
                mSeqUsedQ=args.mSeqUsedQ,
                qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
                lpt=args.lpt,
                is_split_kv=args.is_split_kv,
                head_swizzle=args.head_swizzle,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        # NOTE(review): _is_first_block is not included in __extract_mlir_values__,
        # so a round-trip through MLIR values resets it to True — confirm intended.
        self._is_first_block = True
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        # Upper bound on m-blocks: each batch can waste at most one
        # (cluster * tile_m - 1)-row remainder tile.
        total_blocks_max = (
            params.total_q
            + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
        ) // params.tile_shape_mn[0]
        # round down to nearest multiple of cluster since odd excess is always padding
        total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
        return (total_blocks_max * params.num_head, params.num_splits, Int32(1))

    @cute.jit
    def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
        """Per-lane m-block count for batch (bidb_start + lane); 0 for lanes past
        the batch count or for the last lane (which only feeds the cu_seqlen diff)."""
        params = self.params
        batch_idx = lane + bidb_start
        if cutlass.const_expr(params.mSeqUsedQ is not None):
            seqlen = Int32(0)
            if batch_idx < params.num_batch:
                seqlen = params.mSeqUsedQ[batch_idx]
        else:
            assert params.mCuSeqlensQ is not None
            cur_cu_seqlen = Int32(0)
            if batch_idx <= params.num_batch:
                cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
            # seqlen of batch i is cu_seqlens[i+1] - cu_seqlens[i]; fetch the
            # next entry from the neighboring lane.
            next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
            seqlen = next_cu_seqlen - cur_cu_seqlen
        if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
            seqlen *= params.qhead_per_kvhead_packgqa
        return (
            cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)
            if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
            else Int32(0)
        )

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Locate this CTA's (block, head, batch, split) by warp-scanning the
        variable per-batch block counts until the group containing tile_idx is found."""
        params = self.params
        lane_idx = cute.arch.lane_idx()
        num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
        num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
        # Total number of blocks for the next 31 batches
        m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
        # Same for all lanes
        group_end_tile = m_blocks_in_group * params.num_head
        # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
        block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
        next_tile_idx = self._tile_idx // params.cluster_shape_m
        # Walk forward 31 batches at a time until the group containing the tile.
        while group_end_tile <= next_tile_idx:
            batch_idx += cute.arch.WARP_SIZE - 1
            if batch_idx >= params.num_batch:
                batch_idx = Int32(params.num_batch)
                group_end_tile = next_tile_idx + 1
            else:
                num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
                num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
                m_blocks_in_group = cute.arch.shuffle_sync(
                    num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
                )
                group_end_tile += m_blocks_in_group * params.num_head
        is_valid = False
        if batch_idx >= params.num_batch:
            block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
        else:
            group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
            # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
            # The next problem to process is the first one that does not have ending tile position
            # that is greater than or equal to tile index.
            batch_idx_in_group = cute.arch.popc(
                cute.arch.vote_ballot_sync(
                    group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
                )
            )
            batch_idx += batch_idx_in_group
            num_m_blocks_prev_lane = (
                0
                if batch_idx_in_group == 0
                else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
            )
            num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
            mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
            if cutlass.const_expr(params.lpt or params.head_swizzle):
                # This is a version of the SingleTileLPTScheduler, complicated by the fact that
                # the seqlen can vary per batch.
                # TODO: is there any case where num_m_blocks is 0?
                # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
                num_n_blocks = (
                    num_m_blocks
                    * params.tile_shape_mn[0]
                    // params.qhead_per_kvhead_packgqa
                    // params.tile_shape_mn[1]
                )
                # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
                # Seems faster to have this be a power of 2
                nheads_in_l2 = (
                    16
                    if num_n_blocks * 16 <= params.max_kvblock_in_l2
                    else (
                        8
                        if num_n_blocks * 8 <= params.max_kvblock_in_l2
                        else (
                            4
                            if num_n_blocks * 4 <= params.max_kvblock_in_l2
                            else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
                        )
                    )
                )
                nheads_in_l2 = min(nheads_in_l2, params.num_head)
                mh_in_l2 = nheads_in_l2 * num_m_blocks
                section_idx = mh_block // mh_in_l2
                l2_mod = mh_block - section_idx * mh_in_l2
                # Deal with tail section
                nheads_in_this_section = (
                    nheads_in_l2
                    if nheads_in_l2 * (section_idx + 1) <= params.num_head
                    else params.num_head - section_idx * nheads_in_l2
                )
                block = l2_mod // nheads_in_this_section
                head_idx_residual = l2_mod - block * nheads_in_this_section
                head_idx = section_idx * nheads_in_l2 + head_idx_residual
                if cutlass.const_expr(params.lpt):
                    block = num_m_blocks - 1 - block
            else:
                head_idx = mh_block // num_m_blocks
                block = mh_block - head_idx * num_m_blocks
            is_valid = self._is_first_block and batch_idx < params.num_batch
            if cutlass.const_expr(params.cluster_shape_m > 1):
                bidx_in_cluster = cute.arch.block_in_cluster_idx()
                block = block * params.cluster_shape_m + bidx_in_cluster[0]
            # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
        split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        # Nothing to prefetch for a single-tile scheduler.
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._is_first_block = False

    def __extract_mlir_values__(self):
        # Flatten params and the two counters, recording per-object counts.
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        # Inverse of __extract_mlir_values__: consume values in the same order.
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx, self._split_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)
build/torch-cuda/utils.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ import math
4
+ import hashlib
5
+ import inspect
6
+ from typing import Type, Callable, Optional, Tuple, overload
7
+
8
+ import cutlass
9
+ import cutlass.cute as cute
10
+
11
+ from cutlass import Float32, const_expr
12
+ from cutlass.cutlass_dsl import T, dsl_user_op
13
+ from cutlass._mlir.dialects import nvvm, llvm
14
+ from cutlass.cute.runtime import from_dlpack
15
+
16
+
17
+ from .quack import activation
18
+
19
# Attributes mixed into a callable's hash on top of its base (source) hash.
_MIXER_ATTRS = ("__vec_size__",)

# Obtained from sollya:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
# Minimax polynomial coefficients for 2**x on [0, 1], keyed by degree.
# NOTE(review): entry 0 is the scalar 1.0, not the 1-tuple (1.0,) — confirm
# consumers expect a bare float for degree 0.
POLY_EX2 = {
    0: (1.0),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}


def _compute_base_hash(func: Callable) -> str:
    """Compute hash from source code or bytecode and closure values."""
    try:
        data = inspect.getsource(func).encode()
    except (OSError, TypeError):
        # No retrievable source (e.g. exec'd or built-in): fall back to
        # bytecode, then to repr as a last resort.
        if hasattr(func, "__code__") and func.__code__ is not None:
            data = func.__code__.co_code
        else:
            data = repr(func).encode()

    hasher = hashlib.sha256(data)

    # Closure cell contents affect behavior, so mix them in too.
    if hasattr(func, "__closure__") and func.__closure__ is not None:
        for cell in func.__closure__:
            hasher.update(repr(cell.cell_contents).encode())

    return hasher.hexdigest()


def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on the source code or bytecode and closure values.

    Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
    attribute, that value is returned immediately as the base hash, then
    metadata dunders are mixed in to produce the final dict-key hash.

    :param func: callable to hash.
    :param mixer_attrs: attribute names whose values are mixed into the hash
        on top of the base hash (defaults to ``_MIXER_ATTRS``).
    :param set_cute_hash: whether or not to set ``func.__cute_hash__``.
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)

        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)

            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)

    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())

    # BUG FIX: previously iterated the module-level _MIXER_ATTRS here, so a
    # caller-supplied `mixer_attrs` had its values mixed under the wrong
    # attribute names (distinct attrs with equal values collided).
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())

    return hasher.hexdigest()
113
+
114
+
115
def create_softcap_scoremod(softcap_val):
    """Build a pre-mask score-mod callback that applies tanh softcapping.

    The returned jitted function rescales attention scores by 1/softcap_val
    and passes them through a fast-math tanh, bounding their magnitude.
    """
    scale = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        capped = acc_S_SSA * scale
        return capped * cute.math.tanh(capped, fastmath=True)

    return scoremod_premask_fn
124
+
125
+
126
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap a DLPack tensor as a cute.Tensor with a dynamic layout.

    Marks the layout dynamic along ``leading_dim`` and makes that mode's
    compact shape dynamic with the given divisibility constraint.
    """
    tensor = from_dlpack(x, assumed_align=alignment)
    tensor = tensor.mark_layout_dynamic(leading_dim=leading_dim)
    return tensor.mark_compact_shape_dynamic(
        mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
    )
134
+
135
+
136
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap a DLPack tensor, keeping the leading dim (and any ``static_modes``)
    static while marking every other mode's compact shape dynamic."""
    order = x.dim_order() if stride_order is None else stride_order
    tensor = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        if mode == leading_dim:
            continue
        if static_modes is not None and mode in static_modes:
            continue
        tensor = tensor.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return tensor
146
+
147
+
148
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Make the tiled copy for operand A, honoring the swapAB convention.

    When swapAB is set the kernel treats A as B, so the B-side tiled copy
    is built instead.
    """
    maker = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return maker(copy_atom, tiled_mma)
155
+
156
+
157
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Make the tiled copy for operand B, honoring the swapAB convention.

    When swapAB is set the kernel treats B as A, so the A-side tiled copy
    is built instead.
    """
    maker = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return maker(copy_atom, tiled_mma)
164
+
165
+
166
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition shared memory and make the MMA fragment for operand A.

    With swapAB set, the roles of A and B are exchanged and the B-side
    fragment is produced instead.
    """
    if const_expr(not swapAB):
        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))
    return mma_make_fragment_B(smem, thr_mma)
173
+
174
+
175
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition shared memory and make the MMA fragment for operand B.

    With swapAB set, the roles of A and B are exchanged and the A-side
    fragment is produced instead.
    """
    if const_expr(not swapAB):
        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))
    return mma_make_fragment_A(smem, thr_mma)
182
+
183
+
184
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom used to store accumulator values to shared memory.

    On SM90+ with 16-bit elements, stmatrix (4 matrices) is used; everywhere
    else a universal copy moving two elements per operation is returned.
    """
    if const_expr(arch >= 90 and element_type.width == 16):
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=2 * element_type.width,
    )
198
+
199
+
200
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` across ``width`` lanes of a warp with binary ``op``.

    Tensor inputs are reduced elementwise: each element is independently
    combined across lanes. ``width`` must be a power of two (log2 rounds of
    butterfly shuffles are emitted).
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Spill the SSA tensor to a fragment, reduce each element, reload.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # log2(width) rounds of xor-butterfly shuffles combine all lanes.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
216
+
217
+
218
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Emit an NVVM f32 max of two operands, or a 3-input max when ``c`` is given.

    Dispatches between the two nvvm.fmax binding signatures based on the
    CUDA toolkit version.
    """
    from cutlass import CUDA_VERSION

    # Select the NVVM binding signature based on the CUDA version.
    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
        # Old API: requires explicit result type as first positional argument
        return Float32(
            nvvm.fmax(
                T.f32(),
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
    else:
        # New API: infers result type automatically
        return Float32(
            nvvm.fmax(
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
248
+
249
+
250
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Max-reduce the elements of ``x``, optionally seeded with ``init_val``.

    On SM100+ with an element count divisible by 8, a hand-unrolled tree of
    3-input fmax instructions is used; otherwise a 4-wide unrolled 2-input
    tree is used.
    NOTE(review): the generic branch indexes res[0..3] and strides by 4 —
    appears to assume cute.size(x.shape) is a multiple of 4 and >= 4; confirm
    callers guarantee this.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # if const_expr(init_val is None):
        #     init_val = -cutlass.Float32.inf
        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # local_max = [res[0], res[1]]
        # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):
        #     local_max[0] = fmax(local_max[0], res[i + 0])
        #     local_max[1] = fmax(local_max[1], res[i + 1])
        # local_max[0] = fmax(local_max[0], local_max[1])
        # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
        # Four independent accumulators keep the fmax chain ILP-friendly.
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
299
+
300
+
301
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Sum-reduce the elements of ``x``, optionally seeded with ``init_val``.

    On SM100+ with an element count divisible by 8, packed f32x2 adds are
    used for a 4-accumulator unrolled tree; otherwise the generic
    ``TensorSSA.reduce`` path is taken.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Seed lane 0 of the first packed pair with init_val (other lane 0.0).
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        # Four packed accumulators, each folding a pair of elements per step.
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        # Final horizontal add of the two packed lanes.
        return local_sum[0][0] + local_sum[0][1]
340
+
341
+
342
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` to the fp32 value at ``gmem_ptr``.

    The result of the atomic is discarded (reduction-style use). Implemented
    via NVVM ``atomicrmw`` with FADD; a hand-written PTX ``red.global.add``
    alternative (including an L2 cache-hint variant) is kept below for
    reference.
    """
    # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # # cache_hint = cutlass.Int64(0x12F0000000000000)
    # llvm.inline_asm(
    #     None,
    #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],
    #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
    #     "red.global.add.f32 [$0], $1;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
    #     "l,f",
    #     # "l,f,l",
    #     has_side_effects=True,
    #     is_align_stack=False,
    #     asm_dialect=llvm.AsmDialect.AD_ATT,
    # )
    nvvm.atomicrmw(
        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
    )
362
+
363
+
364
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` addressed by ``coord``."""
    linear_idx = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return x.iterator + linear_idx
367
+
368
+
369
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment for the k dimension of ``tAcA``.

    Only the k dimension is predicated against ``limit``; the mn dimension is
    expected to be guarded by an ``if`` at the use site. The middle mode has
    stride 0 so the single computed predicate per (rest_v, rest_k) is shared
    across it.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Compare the second coordinate component (the k coordinate)
            # against the limit.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
383
+
384
+
385
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return this thread's warp-group index (128 threads per warp group).

    When ``sync`` is set, the index is additionally made warp-uniform.
    """
    tidx = cute.arch.thread_idx()[0]
    wg_idx = tidx // 128
    if const_expr(sync):
        wg_idx = cute.arch.make_warp_uniform(wg_idx)
    return wg_idx
390
+
391
+
392
+ # @dsl_user_op
393
+ # def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
394
+ # mask = cutlass.Int32(-1)
395
+ # return cutlass.Boolean(
396
+ # llvm.inline_asm(
397
+ # T.i32(),
398
+ # [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
399
+ # ".pred p1, p2;\n"
400
+ # "setp.lt.f32 p1, $1, $2;\n"
401
+ # "vote.sync.any.pred p2, p1, $3;\n"
402
+ # "selp.u32 $0, 1, 0, p2;",
403
+ # # "selp.u32 $0, 1, 0, p1;",
404
+ # "=r,f,f,r",
405
+ # has_side_effects=False,
406
+ # is_align_stack=False,
407
+ # asm_dialect=llvm.AsmDialect.AD_ATT,
408
+ # )
409
+ # )
410
+
411
+
412
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp-shuffle an arbitrary (multiple-of-32-bit) value word by word.

    ``value`` is spilled to a register tensor, reinterpreted as Int32 words,
    and each word is shuffled with ``cute.arch.shuffle_sync`` using the same
    ``offset``. ``width`` restricts the shuffle to sub-segments of the warp
    via the PTX segmask/clamp encoding (segmask in bits [12:8], clamp in bits
    [4:0]).

    NOTE(review): the exact lane-addressing mode (idx/bfly/up/down) is
    whatever ``cute.arch.shuffle_sync`` emits — confirm against its docs.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
430
+
431
+
432
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Right-shift ``val`` by ``shift`` via inline PTX.

    NOTE(review): despite the ``_u32`` name and Uint32 operands, the PTX
    opcode is ``shr.s32`` — an *arithmetic* (sign-propagating) shift, which
    differs from a logical shift for inputs with the top bit set. Confirm
    with callers whether this is intentional before changing to ``shr.u32``.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.s32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
448
+
449
+
450
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of ``val`` across the warp (Hillis-Steele scan).

    Uses log2(WARP_SIZE) shuffle-up steps; at step ``i`` each lane adds the
    value from ``2**i`` lanes below it, guarded so lanes below the offset
    keep their partial sum.

    Args:
        val: Per-lane value to scan.
        lane: Caller-supplied lane index; computed from ``lane_idx()`` if None.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val
463
+
464
+
465
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two Float32 values to a packed f16x2 / bf16x2 in one Int32.

    The asm passes ``$2`` (b) as the first source, so per PTX ``cvt.*x2.f32``
    semantics ``b`` lands in the upper 16 bits and ``a`` in the lower 16 bits
    — i.e. little-endian pair order (a first in memory).
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
481
+
482
+
483
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Conversion is done two elements at a time with ``cvt_f16x2_f32``, writing
    packed 32-bit words through an Int32 reinterpretation of ``dst`` — hence
    the even-size requirement on ``src``.

    Args:
        src: Source tensor with Float32 element type
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as Int32 words: each word holds one converted f16/bf16 pair.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
521
+
522
+
523
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate a polynomial at ``x`` with Horner's scheme.

    ``poly[k]`` is the coefficient of ``x**k``; the loop walks coefficients
    from highest degree down to the constant term.
    """
    degree = len(poly) - 1
    acc = poly[degree]
    for k in cutlass.range_constexpr(degree - 1, -1, -1):
        acc = acc * x + poly[k]
    return acc
531
+
532
+
533
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` simultaneously.

    Horner's scheme over packed f32x2 FMAs: lane 0 evaluates at ``x`` and
    lane 1 at ``y``, halving instruction count versus two scalar evaluations.
    ``poly[k]`` is the coefficient of the degree-k term.
    """
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out
543
+
544
+
545
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` rounded toward minus infinity (``add.rm.ftz.f32``).

    The ``rm`` rounding mode (round-to-minus-infinity, with flush-to-zero)
    is the whole point — the default IEEE add rounds to nearest even.
    """
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
            "add.rm.ftz.f32 $0, $1, $2;",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
559
+
560
+
561
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Scale ``frac_ex2`` by 2**n, where n sits in the low bits of ``x_rounded``.

    ``x_rounded`` is assumed to carry the integer part of the exponent in its
    low mantissa bits (from the round-to-int magic-constant trick in
    ``ex2_emulation``). Shifting its raw bits left by 23 places that integer
    in the IEEE-754 exponent field; integer-adding it to the raw bits of
    ``frac_ex2`` multiplies the float by 2**n without an FP multiply.
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
586
+
587
+
588
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate 2**x without the MUFU ``ex2`` instruction.

    Splits x into integer and fractional parts using the classic
    add-magic-constant trick (2**23 + 2**22 forces the integer part into the
    low mantissa bits), evaluates a fitted polynomial of ``POLY_EX2`` on the
    fraction in [0, 1), then merges the integer power of two via exponent-bit
    arithmetic in ``combine_int_frac_ex2``.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    # Clamp below -127 so the biased exponent stays representable.
    x_clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of x_rounded
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
602
+
603
+
604
# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Approximate (2**x, 2**y) in one pass using packed f32x2 arithmetic.

    Same algorithm as ``ex2_emulation`` (magic-constant integer/fraction
    split, polynomial on the fraction, exponent-bit merge), but every FP step
    up to the final merge processes both inputs as one f32x2 pair.
    """
    # We assume x <= 127.0 and y <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of xy_rounded
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    xy_rounded_back = activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    # The exponent merge is scalar per element.
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out
624
+
625
+
626
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Hand-written PTX version of ``ex2_emulation_2`` for two inputs at once.

    The asm clamps both inputs to -127 (0fC2FE0000), applies the
    round-toward-minus-infinity magic-constant split (0f4B400000 =
    2**23 + 2**22), evaluates a degree-3 fitted polynomial on the fractional
    parts via packed f32x2 FMAs, and merges the integer exponent with
    shl-by-23 + integer add — mirroring ``combine_int_frac_ex2``.

    NOTE(review): the output struct is typed f32x2 but the constraints are
    "=r,=r" (integer regs) and the results are written with mov.b32 — this
    relies on bit-level reinterpretation; confirm it lowers correctly.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
671
+
672
+
673
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return ``tensor`` shifted by ``coord`` while keeping its assumed alignment.

    Rebuilds the pointer from the element address at ``coord``, explicitly
    reusing ``tensor.iterator.alignment`` as the assumed alignment of the new
    pointer; the layout is unchanged.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    new_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord).toint(),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
686
+
687
+
688
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap scalar ``a`` into a single-element TensorSSA of the given dtype."""
    frag = cute.make_fragment(1, dtype)
    frag[0] = a
    return frag.load()
694
+
695
+
696
def ssa_to_scalar(val):
    """Extract the single scalar held at index 0 (mirrors `scalar_to_ssa`)."""
    scalar = val[0]
    return scalar