kernels-bot commited on
Commit
2b537bb
·
verified ·
1 Parent(s): 1890a47

Uploaded using `kernel-builder`.

Browse files
build/torch-cuda/__init__.py CHANGED
@@ -1,19 +1,15 @@
1
  """Flash Attention CUTE (CUDA Template Engine) implementation."""
2
 
3
- from importlib.metadata import PackageNotFoundError, version
4
-
5
- # Update when syncing again.
6
- __version__ = "4.0.0.beta4"
7
 
8
  import cutlass.cute as cute
9
 
 
10
  from .interface import (
11
  flash_attn_func,
12
  flash_attn_varlen_func,
13
  )
14
 
15
- from .cute_dsl_utils import cute_compile_patched
16
-
17
  # Patch cute.compile to optionally dump SASS
18
  cute.compile = cute_compile_patched
19
 
 
1
  """Flash Attention CUTE (CUDA Template Engine) implementation."""
2
 
3
+ __version__ = "4.0.0.beta8"
 
 
 
4
 
5
  import cutlass.cute as cute
6
 
7
+ from .cute_dsl_utils import cute_compile_patched
8
  from .interface import (
9
  flash_attn_func,
10
  flash_attn_varlen_func,
11
  )
12
 
 
 
13
  # Patch cute.compile to optionally dump SASS
14
  cute.compile = cute_compile_patched
15
 
build/torch-cuda/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._flash_attn4_525b056
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_flash_attn4_525b056::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._flash_attn4_c9a1374
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_flash_attn4_c9a1374::{op_name}"
build/torch-cuda/bench_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation."""
2
+
3
+ import math
4
+ import torch
5
+
6
+ try:
7
+ import cudnn
8
+ except ImportError:
9
+ cudnn = None
10
+
11
+
12
+ # ── FLOPS calculation ────────────────────────────────────────────────────────
13
+
14
+
15
+ def flops(
16
+ batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
17
+ ):
18
+ if causal:
19
+ avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
20
+ else:
21
+ if window_size == (None, None):
22
+ avg_seqlen = seqlen_k
23
+ else:
24
+ row_idx = torch.arange(seqlen_q, device="cuda")
25
+ col_left = (
26
+ torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
27
+ if window_size[0] is not None
28
+ else torch.zeros_like(row_idx)
29
+ )
30
+ col_right = (
31
+ torch.minimum(
32
+ row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)
33
+ )
34
+ if window_size[1] is not None
35
+ else torch.full_like(row_idx, seqlen_k - 1)
36
+ )
37
+ avg_seqlen = (col_right - col_left + 1).float().mean().item()
38
+ return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
39
+
40
+
41
+ # ── Reference attention ─────────────────────────────────────────────────────
42
+
43
+ _attention_ref_mask_cache = {}
44
+
45
+
46
+ def attention_ref(q, k, v, causal=False):
47
+ """Standard attention reference implementation.
48
+
49
+ Args:
50
+ q, k, v: (batch, seqlen, nheads, headdim) tensors.
51
+ causal: whether to apply causal mask.
52
+ """
53
+ softmax_scale = 1.0 / math.sqrt(q.shape[-1])
54
+ scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
55
+ if causal:
56
+ if scores.shape[-2] not in _attention_ref_mask_cache:
57
+ mask = torch.tril(
58
+ torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0
59
+ )
60
+ _attention_ref_mask_cache[scores.shape[-2]] = mask
61
+ else:
62
+ mask = _attention_ref_mask_cache[scores.shape[-2]]
63
+ scores = scores.masked_fill(mask, float("-inf"))
64
+ attn = torch.softmax(scores, dim=-1)
65
+ return torch.einsum("bhts,bshd->bthd", attn, v)
66
+
67
+
68
+ # ── cuDNN graph helpers ─────────────────────────────────────────────────────
69
+
70
+ _TORCH_TO_CUDNN_DTYPE = {
71
+ torch.float16: "HALF",
72
+ torch.bfloat16: "BFLOAT16",
73
+ torch.float32: "FLOAT",
74
+ torch.int32: "INT32",
75
+ torch.int64: "INT64",
76
+ }
77
+
78
+
79
+ def _build_cudnn_graph(io_dtype, tensors, build_fn):
80
+ """Build a cuDNN graph. Returns (graph, variant_pack, workspace)."""
81
+ assert cudnn is not None, "cuDNN is not available"
82
+ cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype])
83
+ graph = cudnn.pygraph(
84
+ io_data_type=cudnn_dtype,
85
+ intermediate_data_type=cudnn.data_type.FLOAT,
86
+ compute_data_type=cudnn.data_type.FLOAT,
87
+ )
88
+ graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()}
89
+ variant_pack = build_fn(graph, graph_tensors)
90
+ graph.validate()
91
+ graph.build_operation_graph()
92
+ graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
93
+ graph.check_support()
94
+ graph.build_plans()
95
+ workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
96
+ return graph, variant_pack, workspace
97
+
98
+
99
+ def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):
100
+ """Build a cuDNN forward SDPA graph.
101
+
102
+ Args:
103
+ q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout).
104
+ causal: whether to apply causal mask.
105
+ window_size_left: sliding window size (None for no window).
106
+
107
+ Returns:
108
+ (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable.
109
+ """
110
+ b, nheads, seqlen_q, headdim = q.shape
111
+ headdim_v = v.shape[-1]
112
+ o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device)
113
+ stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
114
+
115
+ def build(graph, gt):
116
+ o, stats = graph.sdpa(
117
+ name="sdpa",
118
+ q=gt["q"],
119
+ k=gt["k"],
120
+ v=gt["v"],
121
+ is_inference=False,
122
+ attn_scale=1.0 / math.sqrt(headdim),
123
+ use_causal_mask=causal or window_size_left is not None,
124
+ sliding_window_length=window_size_left
125
+ if window_size_left is not None and not causal
126
+ else None,
127
+ )
128
+ o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
129
+ stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
130
+ return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu}
131
+
132
+ graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build)
133
+
134
+ def fwd_fn():
135
+ graph.execute(variant_pack, workspace)
136
+ return o_gpu
137
+
138
+ return fwd_fn, o_gpu, stats_gpu
139
+
140
+
141
+ def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):
142
+ """Build a cuDNN backward SDPA graph.
143
+
144
+ Args:
145
+ q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout).
146
+ causal: whether to apply causal mask.
147
+ window_size_left: sliding window size (None for no window).
148
+
149
+ Returns:
150
+ bwd_fn: zero-arg callable that returns (dq, dk, dv).
151
+ """
152
+ headdim = q.shape[-1]
153
+ dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
154
+
155
+ def build(graph, gt):
156
+ dq, dk, dv = graph.sdpa_backward(
157
+ name="sdpa_backward",
158
+ q=gt["q"],
159
+ k=gt["k"],
160
+ v=gt["v"],
161
+ o=gt["o"],
162
+ dO=gt["g"],
163
+ stats=gt["lse"],
164
+ attn_scale=1.0 / math.sqrt(headdim),
165
+ use_causal_mask=causal or window_size_left is not None,
166
+ sliding_window_length=window_size_left
167
+ if window_size_left is not None and not causal
168
+ else None,
169
+ use_deterministic_algorithm=False,
170
+ )
171
+ dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
172
+ dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
173
+ dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
174
+ return {
175
+ gt["q"]: q,
176
+ gt["k"]: k,
177
+ gt["v"]: v,
178
+ gt["o"]: o,
179
+ gt["g"]: g,
180
+ gt["lse"]: lse,
181
+ dq: dq_gpu,
182
+ dk: dk_gpu,
183
+ dv: dv_gpu,
184
+ }
185
+
186
+ graph, variant_pack, workspace = _build_cudnn_graph(
187
+ q.dtype,
188
+ {"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse},
189
+ build,
190
+ )
191
+
192
+ def bwd_fn():
193
+ graph.execute(variant_pack, workspace)
194
+ return dq_gpu, dk_gpu, dv_gpu
195
+
196
+ return bwd_fn
build/torch-cuda/block_info.py CHANGED
@@ -6,7 +6,7 @@ import cutlass
6
  import cutlass.cute as cute
7
  from cutlass import Int32, const_expr
8
 
9
- from .seqlen_info import SeqlenInfoQK
10
 
11
 
12
  @dataclass(frozen=True)
@@ -25,8 +25,8 @@ class BlockInfo:
25
  self,
26
  seqlen_info: SeqlenInfoQK,
27
  m_block: Int32,
28
- split_idx: cutlass.Int32 = 0,
29
- num_splits: cutlass.Int32 = 1,
30
  ) -> Tuple[Int32, Int32]:
31
  n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
32
  if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
@@ -46,7 +46,7 @@ class BlockInfo:
46
  n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
47
  if cutlass.const_expr(self.is_split_kv):
48
  num_n_blocks_per_split = (
49
- cutlass.Int32(0)
50
  if n_block_max <= n_block_min
51
  else (n_block_max - n_block_min + num_splits - 1) // num_splits
52
  )
@@ -70,6 +70,37 @@ class BlockInfo:
70
  m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
71
  return m_block_min, m_block_max
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  @cute.jit
74
  def get_n_block_min_causal_local_mask(
75
  self,
 
6
  import cutlass.cute as cute
7
  from cutlass import Int32, const_expr
8
 
9
+ from .seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK
10
 
11
 
12
  @dataclass(frozen=True)
 
25
  self,
26
  seqlen_info: SeqlenInfoQK,
27
  m_block: Int32,
28
+ split_idx: Int32 = 0,
29
+ num_splits: Int32 = 1,
30
  ) -> Tuple[Int32, Int32]:
31
  n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
32
  if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
 
46
  n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
47
  if cutlass.const_expr(self.is_split_kv):
48
  num_n_blocks_per_split = (
49
+ Int32(0)
50
  if n_block_max <= n_block_min
51
  else (n_block_max - n_block_min + num_splits - 1) // num_splits
52
  )
 
70
  m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))
71
  return m_block_min, m_block_max
72
 
73
+ @cute.jit
74
+ def get_n_block_k_new_min_max(
75
+ self,
76
+ seqlen_info: SeqlenInfoQKNewK,
77
+ m_block: Int32,
78
+ split_idx: Int32 = 0,
79
+ num_splits: Int32 = 1,
80
+ ) -> Tuple[Int32, Int32]:
81
+ """Get the block range for new K tokens (append KV).
82
+
83
+ First computes the full n_block range via get_n_block_min_max, then maps
84
+ those blocks into the new-K index space by subtracting seqlen_k_og.
85
+ """
86
+ n_block_min, n_block_max = self.get_n_block_min_max(
87
+ seqlen_info,
88
+ m_block,
89
+ split_idx,
90
+ num_splits,
91
+ )
92
+ idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0)
93
+ idx_k_new_max = cutlass.min(
94
+ n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new
95
+ )
96
+ n_block_new_min = idx_k_new_min // self.tile_n
97
+ n_block_new_max = (
98
+ cute.ceil_div(idx_k_new_max, self.tile_n)
99
+ if idx_k_new_max > idx_k_new_min
100
+ else n_block_new_min
101
+ )
102
+ return n_block_new_min, n_block_new_max
103
+
104
  @cute.jit
105
  def get_n_block_min_causal_local_mask(
106
  self,
build/torch-cuda/block_sparse_utils.py CHANGED
@@ -72,24 +72,22 @@ from .named_barrier import NamedBarrierBwd
72
  def load_block_list(
73
  block_indices: cute.Tensor,
74
  block_count,
75
- load_q_with_first: cutlass.Constexpr,
76
  first_block_preloaded: cutlass.Constexpr,
77
  kv_producer_state,
78
- load_Q,
79
  load_K,
80
  load_V,
81
  pipeline_k,
82
  pipeline_v,
83
- use_tma_q: cutlass.Constexpr,
84
- tma_q_bytes: cutlass.Constexpr,
85
  intra_wg_overlap: cutlass.Constexpr,
86
  ):
87
- """Iterate over the sparse blocks and load K, V (and Q) into the pipeline.
88
- for the intra_wg_overlap case, we overlap the loads of K and V. And this
89
  means we need to pipeline the last V load from the partial block case,
90
  with the loads for the full blocks. Set first_block_preloaded when the
91
  caller has already issued the first K load for the list.
92
 
 
 
93
  Note:
94
  we iterate along the block_n indices in reverse.
95
 
@@ -99,21 +97,7 @@ def load_block_list(
99
  """
100
  if block_count > 0:
101
  if const_expr(not intra_wg_overlap):
102
- # Peel first iteration: the first block may need to load Q alongside K,
103
- # Parameters are already Constexpr, so no need to wrap in const_expr()
104
- n_block_first = block_indices[block_count - 1]
105
- extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
106
- pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
107
-
108
- if const_expr(load_q_with_first and use_tma_q):
109
- load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
110
-
111
- load_K(src_idx=n_block_first, producer_state=kv_producer_state)
112
- pipeline_v.producer_acquire(kv_producer_state)
113
- load_V(src_idx=n_block_first, producer_state=kv_producer_state)
114
- kv_producer_state.advance()
115
-
116
- for offset in cutlass.range(1, block_count):
117
  n_block = block_indices[block_count - 1 - offset]
118
  pipeline_k.producer_acquire(kv_producer_state)
119
  load_K(src_idx=n_block, producer_state=kv_producer_state)
@@ -123,14 +107,7 @@ def load_block_list(
123
  else:
124
  n_block_first = block_indices[block_count - 1]
125
  if const_expr(not first_block_preloaded):
126
- extra_tx = (
127
- tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
128
- )
129
- pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
130
-
131
- if const_expr(load_q_with_first and use_tma_q):
132
- load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
133
-
134
  load_K(src_idx=n_block_first, producer_state=kv_producer_state)
135
 
136
  for idx in cutlass.range(block_count - 1, unroll=1):
@@ -186,19 +163,18 @@ def produce_block_sparse_loads(
186
  head_idx,
187
  m_block,
188
  kv_producer_state,
189
- load_Q,
190
  load_K,
191
  load_V,
192
  pipeline_k,
193
  pipeline_v,
194
- use_tma_q: cutlass.Constexpr,
195
- tma_q_bytes: cutlass.Constexpr,
196
  intra_wg_overlap: cutlass.Constexpr,
197
  qhead_per_kvhead: cutlass.Constexpr[int] = 1,
198
  q_subtile_factor: cutlass.Constexpr[int] = 1,
199
  ):
200
  """Iterate over the mask and full block lists for a single tile.
201
 
 
 
202
  The masked (partial) list may leave the last V load pending when intra-warp-group
203
  overlap is enabled. The first full block must consume that pending V while
204
  issuing its own K load on the next pipeline stage.
@@ -230,20 +206,16 @@ def produce_block_sparse_loads(
230
  full_empty = curr_full_block_cnt == 0
231
 
232
  if mask_empty:
233
- # No masked blocks: the full list owns the initial Q+K load.
234
  kv_producer_state = load_block_list(
235
  curr_full_block_idx,
236
  curr_full_block_cnt,
237
- load_q_with_first=True,
238
  first_block_preloaded=False,
239
  kv_producer_state=kv_producer_state,
240
- load_Q=load_Q,
241
  load_K=load_K,
242
  load_V=load_V,
243
  pipeline_k=pipeline_k,
244
  pipeline_v=pipeline_v,
245
- use_tma_q=use_tma_q,
246
- tma_q_bytes=tma_q_bytes,
247
  intra_wg_overlap=intra_wg_overlap,
248
  )
249
 
@@ -256,21 +228,16 @@ def produce_block_sparse_loads(
256
  kv_producer_state,
257
  )
258
  else:
259
- # Masked blocks present: load Q together with the first masked K so consumers can
260
- # start immediately. When overlap is disabled this fully drains the list.
261
  kv_producer_state = load_block_list(
262
  curr_mask_block_idx,
263
  curr_mask_block_cnt,
264
- load_q_with_first=True,
265
  first_block_preloaded=False,
266
  kv_producer_state=kv_producer_state,
267
- load_Q=load_Q,
268
  load_K=load_K,
269
  load_V=load_V,
270
  pipeline_k=pipeline_k,
271
  pipeline_v=pipeline_v,
272
- use_tma_q=use_tma_q,
273
- tma_q_bytes=tma_q_bytes,
274
  intra_wg_overlap=intra_wg_overlap,
275
  )
276
 
@@ -299,16 +266,12 @@ def produce_block_sparse_loads(
299
  kv_producer_state = load_block_list(
300
  curr_full_block_idx,
301
  curr_full_block_cnt,
302
- load_q_with_first=False,
303
  first_block_preloaded=True,
304
  kv_producer_state=kv_producer_state,
305
- load_Q=load_Q,
306
  load_K=load_K,
307
  load_V=load_V,
308
  pipeline_k=pipeline_k,
309
  pipeline_v=pipeline_v,
310
- use_tma_q=use_tma_q,
311
- tma_q_bytes=tma_q_bytes,
312
  intra_wg_overlap=intra_wg_overlap,
313
  )
314
 
@@ -320,21 +283,16 @@ def produce_block_sparse_loads(
320
  kv_producer_state,
321
  )
322
  else:
323
- # Non-overlap path with both lists: run the full list normally (skipping the Q
324
- # reload because the masked list already issued it).
325
  kv_producer_state = load_block_list(
326
  curr_full_block_idx,
327
  curr_full_block_cnt,
328
- load_q_with_first=False,
329
  first_block_preloaded=False,
330
  kv_producer_state=kv_producer_state,
331
- load_Q=load_Q,
332
  load_K=load_K,
333
  load_V=load_V,
334
  pipeline_k=pipeline_k,
335
  pipeline_v=pipeline_v,
336
- use_tma_q=use_tma_q,
337
- tma_q_bytes=tma_q_bytes,
338
  intra_wg_overlap=intra_wg_overlap,
339
  )
340
 
@@ -1390,18 +1348,18 @@ def _store_one_dQaccum_sm90(
1390
  m_block,
1391
  sdQaccum: cute.Tensor,
1392
  gdQaccum: cute.Tensor,
1393
- num_mma_warp_groups: cutlass.Constexpr,
1394
  num_threads_per_warp_group: cutlass.Constexpr,
1395
  tma_copy_bytes_dQ,
1396
  ):
1397
  """Store dQaccum for a single m_block."""
1398
- for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1399
- cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
1400
  cute.arch.barrier_arrive(
1401
  barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1402
  number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1403
  )
1404
- for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
1405
  cute.arch.barrier(
1406
  barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1407
  number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
@@ -1409,7 +1367,7 @@ def _store_one_dQaccum_sm90(
1409
  with cute.arch.elect_one():
1410
  copy_utils.cpasync_reduce_bulk_add_f32(
1411
  sdQaccum[None, warp_group_idx].iterator,
1412
- gdQaccum[None, warp_group_idx, m_block].iterator,
1413
  tma_copy_bytes_dQ,
1414
  )
1415
  cute.arch.cp_async_bulk_commit_group()
@@ -1425,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
1425
  gdQaccum: cute.Tensor,
1426
  subtile_factor: cutlass.Constexpr,
1427
  m_block_max: int,
1428
- num_mma_warp_groups: cutlass.Constexpr,
1429
  num_threads_per_warp_group: cutlass.Constexpr,
1430
  tma_copy_bytes_dQ,
1431
  ):
@@ -1454,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
1454
  m_block,
1455
  sdQaccum,
1456
  gdQaccum,
1457
- num_mma_warp_groups,
1458
  num_threads_per_warp_group,
1459
  tma_copy_bytes_dQ,
1460
  )
@@ -1470,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
1470
  m_block,
1471
  sdQaccum,
1472
  gdQaccum,
1473
- num_mma_warp_groups,
1474
  num_threads_per_warp_group,
1475
  tma_copy_bytes_dQ,
1476
  )
 
72
  def load_block_list(
73
  block_indices: cute.Tensor,
74
  block_count,
 
75
  first_block_preloaded: cutlass.Constexpr,
76
  kv_producer_state,
 
77
  load_K,
78
  load_V,
79
  pipeline_k,
80
  pipeline_v,
 
 
81
  intra_wg_overlap: cutlass.Constexpr,
82
  ):
83
+ """Iterate over the sparse blocks and load K, V into the pipeline.
84
+ For the intra_wg_overlap case, we overlap the loads of K and V. And this
85
  means we need to pipeline the last V load from the partial block case,
86
  with the loads for the full blocks. Set first_block_preloaded when the
87
  caller has already issued the first K load for the list.
88
 
89
+ Q is loaded separately on its own mbarrier before this function is called.
90
+
91
  Note:
92
  we iterate along the block_n indices in reverse.
93
 
 
97
  """
98
  if block_count > 0:
99
  if const_expr(not intra_wg_overlap):
100
+ for offset in cutlass.range(block_count):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  n_block = block_indices[block_count - 1 - offset]
102
  pipeline_k.producer_acquire(kv_producer_state)
103
  load_K(src_idx=n_block, producer_state=kv_producer_state)
 
107
  else:
108
  n_block_first = block_indices[block_count - 1]
109
  if const_expr(not first_block_preloaded):
110
+ pipeline_k.producer_acquire(kv_producer_state)
 
 
 
 
 
 
 
111
  load_K(src_idx=n_block_first, producer_state=kv_producer_state)
112
 
113
  for idx in cutlass.range(block_count - 1, unroll=1):
 
163
  head_idx,
164
  m_block,
165
  kv_producer_state,
 
166
  load_K,
167
  load_V,
168
  pipeline_k,
169
  pipeline_v,
 
 
170
  intra_wg_overlap: cutlass.Constexpr,
171
  qhead_per_kvhead: cutlass.Constexpr[int] = 1,
172
  q_subtile_factor: cutlass.Constexpr[int] = 1,
173
  ):
174
  """Iterate over the mask and full block lists for a single tile.
175
 
176
+ Q is loaded separately on its own mbarrier before this function is called.
177
+
178
  The masked (partial) list may leave the last V load pending when intra-warp-group
179
  overlap is enabled. The first full block must consume that pending V while
180
  issuing its own K load on the next pipeline stage.
 
206
  full_empty = curr_full_block_cnt == 0
207
 
208
  if mask_empty:
209
+ # No masked blocks: the full list owns the initial K load.
210
  kv_producer_state = load_block_list(
211
  curr_full_block_idx,
212
  curr_full_block_cnt,
 
213
  first_block_preloaded=False,
214
  kv_producer_state=kv_producer_state,
 
215
  load_K=load_K,
216
  load_V=load_V,
217
  pipeline_k=pipeline_k,
218
  pipeline_v=pipeline_v,
 
 
219
  intra_wg_overlap=intra_wg_overlap,
220
  )
221
 
 
228
  kv_producer_state,
229
  )
230
  else:
231
+ # Masked blocks present. When overlap is disabled this fully drains the list.
 
232
  kv_producer_state = load_block_list(
233
  curr_mask_block_idx,
234
  curr_mask_block_cnt,
 
235
  first_block_preloaded=False,
236
  kv_producer_state=kv_producer_state,
 
237
  load_K=load_K,
238
  load_V=load_V,
239
  pipeline_k=pipeline_k,
240
  pipeline_v=pipeline_v,
 
 
241
  intra_wg_overlap=intra_wg_overlap,
242
  )
243
 
 
266
  kv_producer_state = load_block_list(
267
  curr_full_block_idx,
268
  curr_full_block_cnt,
 
269
  first_block_preloaded=True,
270
  kv_producer_state=kv_producer_state,
 
271
  load_K=load_K,
272
  load_V=load_V,
273
  pipeline_k=pipeline_k,
274
  pipeline_v=pipeline_v,
 
 
275
  intra_wg_overlap=intra_wg_overlap,
276
  )
277
 
 
283
  kv_producer_state,
284
  )
285
  else:
286
+ # Non-overlap path with both lists: run the full list normally.
 
287
  kv_producer_state = load_block_list(
288
  curr_full_block_idx,
289
  curr_full_block_cnt,
 
290
  first_block_preloaded=False,
291
  kv_producer_state=kv_producer_state,
 
292
  load_K=load_K,
293
  load_V=load_V,
294
  pipeline_k=pipeline_k,
295
  pipeline_v=pipeline_v,
 
 
296
  intra_wg_overlap=intra_wg_overlap,
297
  )
298
 
 
1348
  m_block,
1349
  sdQaccum: cute.Tensor,
1350
  gdQaccum: cute.Tensor,
1351
+ num_dQ_warp_groups: cutlass.Constexpr,
1352
  num_threads_per_warp_group: cutlass.Constexpr,
1353
  tma_copy_bytes_dQ,
1354
  ):
1355
  """Store dQaccum for a single m_block."""
1356
+ for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
1357
+ cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True)
1358
  cute.arch.barrier_arrive(
1359
  barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1360
  number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
1361
  )
1362
+ for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
1363
  cute.arch.barrier(
1364
  barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1365
  number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
 
1367
  with cute.arch.elect_one():
1368
  copy_utils.cpasync_reduce_bulk_add_f32(
1369
  sdQaccum[None, warp_group_idx].iterator,
1370
+ gdQaccum[(None, warp_group_idx), m_block].iterator,
1371
  tma_copy_bytes_dQ,
1372
  )
1373
  cute.arch.cp_async_bulk_commit_group()
 
1383
  gdQaccum: cute.Tensor,
1384
  subtile_factor: cutlass.Constexpr,
1385
  m_block_max: int,
1386
+ num_dQ_warp_groups: cutlass.Constexpr,
1387
  num_threads_per_warp_group: cutlass.Constexpr,
1388
  tma_copy_bytes_dQ,
1389
  ):
 
1412
  m_block,
1413
  sdQaccum,
1414
  gdQaccum,
1415
+ num_dQ_warp_groups,
1416
  num_threads_per_warp_group,
1417
  tma_copy_bytes_dQ,
1418
  )
 
1428
  m_block,
1429
  sdQaccum,
1430
  gdQaccum,
1431
+ num_dQ_warp_groups,
1432
  num_threads_per_warp_group,
1433
  tma_copy_bytes_dQ,
1434
  )
build/torch-cuda/block_sparsity.py CHANGED
@@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple):
34
  block_size: tuple[int, int] | None = None
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def _expand_sparsity_tensor(
38
  tensor: torch.Tensor,
39
  expected_shape: Tuple[int, ...],
@@ -81,6 +98,12 @@ def _check_and_expand_block(
81
  expanded_cnt = _expand_sparsity_tensor(
82
  cnt, expected_count_shape, f"{name}_block_cnt", context, hint
83
  )
 
 
 
 
 
 
84
  expanded_idx = _expand_sparsity_tensor(
85
  idx, expected_index_shape, f"{name}_block_idx", context, hint
86
  )
@@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes(
140
  num_m_blocks = tensors.mask_block_idx.shape[2]
141
 
142
  if sparse_block_size_q is None:
143
- min_block_size = ceildiv(seqlen_q, num_m_blocks)
144
- if num_m_blocks == 1:
145
- max_block_size = seqlen_q
146
- else:
147
- max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
148
- if max_block_size != min_block_size and base_m_block != 1:
149
  raise ValueError(
150
  f"Block sparse tensors{context} require explicit sparse_block_size[0] "
151
  f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
152
  )
153
- sparse_block_size_q = min_block_size
 
154
 
155
  if sparse_block_size_q % base_m_block != 0:
156
  raise ValueError(
@@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes(
186
  raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
187
  if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
188
  raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
189
- if mask_block_idx.shape[3] != expected_n_blocks:
 
 
190
  raise ValueError(
191
- f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
192
  )
193
  if expected_m_blocks != num_m_blocks:
194
  raise ValueError(
@@ -314,7 +336,7 @@ def normalize_block_sparse_config(
314
  ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
315
  m_block_size, n_block_size = block_size
316
  if tensors.block_size is None:
317
- sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size
318
  else:
319
  sparse_block_size_q, sparse_block_size_kv = tensors.block_size
320
  if sparse_block_size_kv != n_block_size:
@@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors(
401
  """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
402
  if not is_block_sparsity_enabled(tensors):
403
  return None
 
404
  (
405
  mask_block_cnt,
406
  mask_block_idx,
 
34
  block_size: tuple[int, int] | None = None
35
 
36
 
37
+ def get_sparse_q_block_size(
38
+ tensors: BlockSparseTensorsTorch | None,
39
+ seqlen_q: int,
40
+ ) -> int | None:
41
+ """Return the Q sparse block size, or None when sparsity is unset or ambiguous."""
42
+ if tensors is None:
43
+ return None
44
+ if tensors.block_size is not None:
45
+ return tensors.block_size[0]
46
+ num_m_blocks = tensors.mask_block_idx.shape[2]
47
+ min_block_size = ceildiv(seqlen_q, num_m_blocks)
48
+ max_block_size = seqlen_q if num_m_blocks == 1 else (seqlen_q - 1) // (num_m_blocks - 1)
49
+ if min_block_size != max_block_size:
50
+ return None
51
+ return min_block_size
52
+
53
+
54
  def _expand_sparsity_tensor(
55
  tensor: torch.Tensor,
56
  expected_shape: Tuple[int, ...],
 
98
  expanded_cnt = _expand_sparsity_tensor(
99
  cnt, expected_count_shape, f"{name}_block_cnt", context, hint
100
  )
101
+ # [Note] Allow Compact block sparse indices
102
+ # Allow the last dimension (n_blocks) of idx to be <= expected, since
103
+ # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
104
+ # index tensors that avoid O(N^2) memory at long sequence lengths.
105
+ if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
106
+ expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
107
  expanded_idx = _expand_sparsity_tensor(
108
  idx, expected_index_shape, f"{name}_block_idx", context, hint
109
  )
 
163
  num_m_blocks = tensors.mask_block_idx.shape[2]
164
 
165
  if sparse_block_size_q is None:
166
+ sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q)
167
+ if sparse_block_size_q is None and base_m_block != 1:
 
 
 
 
168
  raise ValueError(
169
  f"Block sparse tensors{context} require explicit sparse_block_size[0] "
170
  f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
171
  )
172
+ if sparse_block_size_q is None:
173
+ sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks)
174
 
175
  if sparse_block_size_q % base_m_block != 0:
176
  raise ValueError(
 
206
  raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
207
  if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
208
  raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
209
+ # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
210
+ # per query tile, so idx.shape[3] can be <= expected_n_blocks.
211
+ if mask_block_idx.shape[3] > expected_n_blocks:
212
  raise ValueError(
213
+ f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
214
  )
215
  if expected_m_blocks != num_m_blocks:
216
  raise ValueError(
 
336
  ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
337
  m_block_size, n_block_size = block_size
338
  if tensors.block_size is None:
339
+ sparse_block_size_q, sparse_block_size_kv = None, n_block_size
340
  else:
341
  sparse_block_size_q, sparse_block_size_kv = tensors.block_size
342
  if sparse_block_size_kv != n_block_size:
 
423
  """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
424
  if not is_block_sparsity_enabled(tensors):
425
  return None
426
+
427
  (
428
  mask_block_cnt,
429
  mask_block_idx,
build/torch-cuda/cache_utils.py CHANGED
@@ -1,7 +1,6 @@
1
  # Manage Ahead-of-Time (AOT) compiled kernels
2
  import fcntl
3
  import hashlib
4
- import logging
5
  import os
6
  import pickle
7
  import sys
@@ -18,6 +17,7 @@ import cutlass
18
  import cutlass.cute as cute
19
  import tvm_ffi
20
  from cutlass.cutlass_dsl import JitCompiledFunction
 
21
 
22
  # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
23
  # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
@@ -30,12 +30,6 @@ for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):
30
  CompileKeyType: TypeAlias = tuple[Hashable, ...]
31
  CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
32
 
33
- logger = logging.getLogger(__name__)
34
- _handler = logging.StreamHandler()
35
- _handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
36
- logger.addHandler(_handler)
37
- logger.setLevel(logging.DEBUG)
38
-
39
 
40
  # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
41
  CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
@@ -222,13 +216,13 @@ class JITPersistentCache(JITCache):
222
  label=sha256_hex,
223
  ):
224
  if obj_path.exists():
225
- logger.debug("Loading compiled function from disk: %s", obj_path)
226
  m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
227
  fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
228
  JITCache.__setitem__(self, key, fn)
229
  return True
230
  else:
231
- logger.debug("Cache miss on disk for key hash %s", sha256_hex)
232
  return False
233
 
234
  def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
@@ -243,14 +237,14 @@ class JITPersistentCache(JITCache):
243
  obj_path = self.cache_path / f"{sha256_hex}.o"
244
  if obj_path.exists():
245
  # Another process already exported.
246
- logger.debug("Skipping export, already on disk: %s", obj_path)
247
  return
248
- logger.debug("Exporting compiled function to disk: %s", obj_path)
249
  fn.export_to_c(
250
  object_file_path=str(obj_path),
251
  function_name=self.EXPORT_FUNCTION_PREFIX,
252
  )
253
- logger.debug("Successfully exported compiled function to disk: %s", obj_path)
254
 
255
  def _key_to_hash(self, key: CompileKeyType) -> str:
256
  return hashlib.sha256(pickle.dumps(key)).hexdigest()
@@ -262,7 +256,7 @@ class JITPersistentCache(JITCache):
262
  """
263
  Not only clear the in-memory cache. Also purge persistent compilation cache.
264
  """
265
- logger.debug("Clearing persistent cache at %s", self.cache_path)
266
  super().clear()
267
  for child in self.cache_path.iterdir():
268
  child.unlink()
@@ -281,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache:
281
  path = get_cache_path() / _compute_source_fingerprint()
282
  if name:
283
  path = path / name
284
- logger.debug("Creating persistent JIT cache at %s", path)
285
  return JITPersistentCache(path)
286
  else:
287
- logger.debug("Persistent cache disabled, using in-memory JIT cache")
288
  return JITCache()
 
1
  # Manage Ahead-of-Time (AOT) compiled kernels
2
  import fcntl
3
  import hashlib
 
4
  import os
5
  import pickle
6
  import sys
 
17
  import cutlass.cute as cute
18
  import tvm_ffi
19
  from cutlass.cutlass_dsl import JitCompiledFunction
20
+ from .fa_logging import fa_log
21
 
22
  # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
23
  # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
 
30
  CompileKeyType: TypeAlias = tuple[Hashable, ...]
31
  CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function
32
 
 
 
 
 
 
 
33
 
34
  # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
35
  CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
 
216
  label=sha256_hex,
217
  ):
218
  if obj_path.exists():
219
+ fa_log(1, f"Loading compiled function from disk: {obj_path}")
220
  m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
221
  fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
222
  JITCache.__setitem__(self, key, fn)
223
  return True
224
  else:
225
+ fa_log(1, f"Cache miss on disk for key hash {sha256_hex}")
226
  return False
227
 
228
  def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
 
237
  obj_path = self.cache_path / f"{sha256_hex}.o"
238
  if obj_path.exists():
239
  # Another process already exported.
240
+ fa_log(1, f"Skipping export, already on disk: {obj_path}")
241
  return
242
+ fa_log(1, f"Exporting compiled function to disk: {obj_path}")
243
  fn.export_to_c(
244
  object_file_path=str(obj_path),
245
  function_name=self.EXPORT_FUNCTION_PREFIX,
246
  )
247
+ fa_log(1, f"Successfully exported compiled function to disk: {obj_path}")
248
 
249
  def _key_to_hash(self, key: CompileKeyType) -> str:
250
  return hashlib.sha256(pickle.dumps(key)).hexdigest()
 
256
  """
257
  Not only clear the in-memory cache. Also purge persistent compilation cache.
258
  """
259
+ fa_log(1, f"Clearing persistent cache at {self.cache_path}")
260
  super().clear()
261
  for child in self.cache_path.iterdir():
262
  child.unlink()
 
275
  path = get_cache_path() / _compute_source_fingerprint()
276
  if name:
277
  path = path / name
278
+ fa_log(1, f"Creating persistent JIT cache at {path}")
279
  return JITPersistentCache(path)
280
  else:
281
+ fa_log(1, "Persistent cache disabled, using in-memory JIT cache")
282
  return JITCache()
build/torch-cuda/cute_dsl_utils.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  import pathlib
5
  from typing import Tuple
6
  from functools import partial, lru_cache
7
- from dataclasses import dataclass, fields
8
 
9
  import torch
10
 
@@ -15,7 +14,6 @@ except ImportError:
15
 
16
  import cutlass
17
  import cutlass.cute as cute
18
- from cutlass.base_dsl.typing import JitArgument
19
  from cutlass.cutlass_dsl import NumericMeta
20
  from cutlass.cute.runtime import from_dlpack
21
 
@@ -43,42 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
43
  return torch.cuda.get_device_capability(device)
44
 
45
 
46
- @dataclass
47
- class ArgumentsBase(JitArgument):
48
- def __c_pointers__(self):
49
- all_fields = [getattr(self, field.name) for field in fields(self)]
50
- non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
51
- c_ptrs = []
52
- for obj in non_constexpr_fields:
53
- if hasattr(obj, "__c_pointers__"):
54
- c_ptrs.extend(obj.__c_pointers__())
55
- return c_ptrs
56
-
57
- def __get_mlir_types__(self):
58
- all_fields = [getattr(self, field.name) for field in fields(self)]
59
- non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
60
- types, self._values_pos = [], []
61
- for obj in non_constexpr_fields:
62
- if hasattr(obj, "__get_mlir_types__"):
63
- obj_types = obj.__get_mlir_types__()
64
- types.extend(obj_types)
65
- self._values_pos.append(len(obj_types))
66
- else:
67
- self._values_pos.append(0)
68
- return types
69
-
70
- def __new_from_mlir_values__(self, values):
71
- all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
72
- constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
73
- non_constexpr_fields = {
74
- n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
75
- }
76
- for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
77
- non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
78
- values = values[n_items:]
79
- return self.__class__(**non_constexpr_fields, **constexpr_fields)
80
-
81
-
82
  def load_cubin_module_data_patched(cubin_data, filepath):
83
  pathlib.Path(filepath).write_bytes(cubin_data)
84
  return load_cubin_module_data_og(cubin_data)
 
4
  import pathlib
5
  from typing import Tuple
6
  from functools import partial, lru_cache
 
7
 
8
  import torch
9
 
 
14
 
15
  import cutlass
16
  import cutlass.cute as cute
 
17
  from cutlass.cutlass_dsl import NumericMeta
18
  from cutlass.cute.runtime import from_dlpack
19
 
 
41
  return torch.cuda.get_device_capability(device)
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def load_cubin_module_data_patched(cubin_data, filepath):
45
  pathlib.Path(filepath).write_bytes(cubin_data)
46
  return load_cubin_module_data_og(cubin_data)
build/torch-cuda/fa_logging.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ """Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var.
4
+
5
+ Host-side messages go through Python ``logging`` (logger name ``flash_attn``).
6
+ A default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1``
7
+ so that standalone scripts get output without extra setup; applications that
8
+ configure their own logging can remove or replace it via the standard API.
9
+
10
+ FA_LOG_LEVEL mapping::
11
+
12
+ 0 off nothing logged
13
+ 1 host host-side summaries only (no kernel printf)
14
+ 2 kernel host + curated kernel traces
15
+ 3 max host + all kernel traces (noisy, perf hit)
16
+
17
+ Set via environment variable::
18
+
19
+ FA_LOG_LEVEL=1 python train.py
20
+
21
+ Device-side ``cute.printf`` calls are compile-time eliminated via
22
+ ``cutlass.const_expr`` when the log level is below the callsite threshold,
23
+ so there is zero performance cost when device logging is off.
24
+ Changing the log level after kernel compilation requires a recompile
25
+ (the level participates in the forward compile key).
26
+ """
27
+
28
+ import logging
29
+ import os
30
+ import sys
31
+
32
+ import cutlass.cute as cute
33
+ from cutlass import const_expr
34
+
35
+ _LOG_LEVEL_NAMES = {"off": 0, "host": 1, "kernel": 2, "max": 3}
36
+
37
+
38
+ def _parse_log_level(raw: str) -> int:
39
+ if raw in _LOG_LEVEL_NAMES:
40
+ return _LOG_LEVEL_NAMES[raw]
41
+ try:
42
+ level = int(raw)
43
+ except ValueError:
44
+ return 0
45
+ return max(0, min(level, 3))
46
+
47
+
48
+ _fa_log_level: int = _parse_log_level(os.environ.get("FA_LOG_LEVEL", "0"))
49
+
50
+ _logger = logging.getLogger("flash_attn")
51
+ _logger.addHandler(logging.NullHandler())
52
+ _default_handler: logging.Handler | None = None
53
+
54
+
55
+ def _configure_default_handler() -> None:
56
+ global _default_handler
57
+ if _fa_log_level >= 1:
58
+ if _default_handler is None:
59
+ _default_handler = logging.StreamHandler(sys.stdout)
60
+ _default_handler.setFormatter(logging.Formatter("[FA] %(message)s"))
61
+ _logger.addHandler(_default_handler)
62
+ _logger.setLevel(logging.DEBUG)
63
+ else:
64
+ if _default_handler is not None:
65
+ _logger.removeHandler(_default_handler)
66
+ _default_handler = None
67
+ _logger.setLevel(logging.WARNING)
68
+
69
+
70
+ _configure_default_handler()
71
+
72
+
73
+ def get_fa_log_level() -> int:
74
+ return _fa_log_level
75
+
76
+
77
+ def set_fa_log_level(level: int | str) -> None:
78
+ """Set the FA log level programmatically.
79
+
80
+ Host logging takes effect immediately. Device logging changes only
81
+ affect kernels compiled after this call (new compile-key selection).
82
+ """
83
+ global _fa_log_level
84
+ if isinstance(level, str):
85
+ level = _parse_log_level(level)
86
+ _fa_log_level = max(0, min(int(level), 3))
87
+ _configure_default_handler()
88
+
89
+
90
+ def fa_log(level: int, msg: str):
91
+ if _fa_log_level >= level:
92
+ _logger.info(msg)
93
+
94
+
95
+ def fa_printf(level: int, fmt, *args):
96
+ if const_expr(_fa_log_level >= level):
97
+ cute.printf(fmt, *args)
build/torch-cuda/flash_bwd.py CHANGED
@@ -22,6 +22,7 @@ from .mask import AttentionMask
22
  from .seqlen_info import SeqlenInfoQK
23
  from .quack.cute_dsl_utils import ParamsBase
24
  from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
 
25
 
26
 
27
  class FlashAttentionBackwardSm80:
@@ -372,7 +373,6 @@ class FlashAttentionBackwardSm80:
372
  mdK: cute.Tensor,
373
  mdV: cute.Tensor,
374
  softmax_scale: cutlass.Float32,
375
- stream: cuda.CUstream,
376
  mCuSeqlensQ: Optional[cute.Tensor] = None,
377
  mCuSeqlensK: Optional[cute.Tensor] = None,
378
  mSeqUsedQ: Optional[cute.Tensor] = None,
@@ -381,8 +381,16 @@ class FlashAttentionBackwardSm80:
381
  window_size_left: Int32 | int | None = None,
382
  window_size_right: Int32 | int | None = None,
383
  mdQ_semaphore: Optional[cute.Tensor] = None,
 
 
 
 
 
 
384
  ):
385
- assert mdQ_semaphore is None, "semaphore not supported yet"
 
 
386
  # Get the data type and check if it is fp16 or bf16
387
  self._check_type(*(t.element_type if t is not None else None
388
  for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
@@ -512,7 +520,17 @@ class FlashAttentionBackwardSm80:
512
  n_block, head_idx, batch_idx, _ = work_tile.tile_idx
513
 
514
  if work_tile.is_valid_tile:
515
- seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK)
 
 
 
 
 
 
 
 
 
 
516
 
517
  m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
518
  m_block_min = 0
@@ -538,7 +556,7 @@ class FlashAttentionBackwardSm80:
538
  mdPsum_cur = mdPsum[batch_idx, head_idx, None]
539
  mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
540
  else:
541
- padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
542
  mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
543
  mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
544
  mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
@@ -794,9 +812,10 @@ class FlashAttentionBackwardSm80:
794
  # Mainloop
795
  # ///////////////////////////////////////////////////////////////////////////////
796
  # Start processing of the first n-block.
797
- mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k)
798
  mask_fn = partial(
799
  mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
 
800
  mask_seqlen=True, mask_causal=self.is_causal
801
  )
802
  smem_pipe_read_q = cutlass.Int32(0)
@@ -968,7 +987,7 @@ class FlashAttentionBackwardSm80:
968
 
969
  # MMA dK
970
  if cutlass.const_expr(self.Mma_dKV_is_RS):
971
- tdVrP = layout_utils.reshape_acc_to_frgA(rdS)
972
  else:
973
  tdKrdS = mma_params.tdKrdS
974
  sm80_utils.gemm(
 
22
  from .seqlen_info import SeqlenInfoQK
23
  from .quack.cute_dsl_utils import ParamsBase
24
  from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
25
+ from .block_sparsity import BlockSparseTensors
26
 
27
 
28
  class FlashAttentionBackwardSm80:
 
373
  mdK: cute.Tensor,
374
  mdV: cute.Tensor,
375
  softmax_scale: cutlass.Float32,
 
376
  mCuSeqlensQ: Optional[cute.Tensor] = None,
377
  mCuSeqlensK: Optional[cute.Tensor] = None,
378
  mSeqUsedQ: Optional[cute.Tensor] = None,
 
381
  window_size_left: Int32 | int | None = None,
382
  window_size_right: Int32 | int | None = None,
383
  mdQ_semaphore: Optional[cute.Tensor] = None,
384
+ mdK_semaphore: Optional[cute.Tensor] = None,
385
+ mdV_semaphore: Optional[cute.Tensor] = None,
386
+ aux_tensors: Optional[list] = None,
387
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
388
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
389
+ stream: cuda.CUstream = None,
390
  ):
391
+ assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
392
+ "determinism not supported yet for Sm80"
393
+ )
394
  # Get the data type and check if it is fp16 or bf16
395
  self._check_type(*(t.element_type if t is not None else None
396
  for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))
 
520
  n_block, head_idx, batch_idx, _ = work_tile.tile_idx
521
 
522
  if work_tile.is_valid_tile:
523
+ seqlen = SeqlenInfoQK.create(
524
+ batch_idx,
525
+ mQ.shape[1],
526
+ mK.shape[1],
527
+ mCuSeqlensQ=mCuSeqlensQ,
528
+ mCuSeqlensK=mCuSeqlensK,
529
+ mSeqUsedQ=mSeqUsedQ,
530
+ mSeqUsedK=mSeqUsedK,
531
+ tile_m=self.m_block_size,
532
+ tile_n=self.n_block_size,
533
+ )
534
 
535
  m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)
536
  m_block_min = 0
 
556
  mdPsum_cur = mdPsum[batch_idx, head_idx, None]
557
  mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
558
  else:
559
+ padded_offset_q = seqlen.padded_offset_q
560
  mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
561
  mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
562
  mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
 
812
  # Mainloop
813
  # ///////////////////////////////////////////////////////////////////////////////
814
  # Start processing of the first n-block.
815
+ mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen)
816
  mask_fn = partial(
817
  mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
818
+ batch_idx=batch_idx, head_idx=head_idx,
819
  mask_seqlen=True, mask_causal=self.is_causal
820
  )
821
  smem_pipe_read_q = cutlass.Int32(0)
 
987
 
988
  # MMA dK
989
  if cutlass.const_expr(self.Mma_dKV_is_RS):
990
+ tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)
991
  else:
992
  tdKrdS = mma_params.tdKrdS
993
  sm80_utils.gemm(
build/torch-cuda/flash_bwd_postprocess.py CHANGED
@@ -2,7 +2,7 @@
2
  # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
3
  # from Cutlass C++ to Cute-DSL.
4
  import math
5
- from typing import Callable, Optional, Type, Literal
6
 
7
  import cuda.bindings.driver as cuda
8
 
@@ -36,7 +36,7 @@ class FlashAttentionBackwardPostprocess:
36
  self,
37
  dtype: Type[cutlass.Numeric],
38
  head_dim: int,
39
- arch: Literal[80, 90, 100],
40
  tile_m: int = 128,
41
  num_threads: int = 256,
42
  AtomLayoutMdQ: int = 1,
@@ -52,8 +52,8 @@ class FlashAttentionBackwardPostprocess:
52
  """
53
  self.dtype = dtype
54
  self.tile_m = tile_m
55
- assert arch // 10 in [8, 9, 10, 11], (
56
- "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x) are supported"
57
  )
58
  self.arch = arch
59
  # padding head_dim to a multiple of 32 as k_block_size
@@ -63,7 +63,7 @@ class FlashAttentionBackwardPostprocess:
63
  self.num_threads = num_threads
64
  self.AtomLayoutMdQ = AtomLayoutMdQ
65
  self.dQ_swapAB = dQ_swapAB
66
- self.use_2cta_instrs = use_2cta_instrs and arch == 100 and head_dim != 64
67
  self.cluster_size = cluster_size
68
 
69
  @staticmethod
@@ -89,7 +89,7 @@ class FlashAttentionBackwardPostprocess:
89
  return True
90
 
91
  def _get_tiled_mma(self):
92
- if const_expr(self.arch == 80):
93
  num_mma_warps = self.num_threads // 32
94
  atom_layout_dQ = (
95
  (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
@@ -101,9 +101,9 @@ class FlashAttentionBackwardPostprocess:
101
  atom_layout_dQ,
102
  permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
103
  )
104
- elif const_expr(self.arch == 90):
105
- num_mma_warp_groups = self.num_threads // 128
106
- atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ)
107
  tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
108
  tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
109
  self.dtype,
@@ -125,7 +125,7 @@ class FlashAttentionBackwardPostprocess:
125
  cta_group,
126
  (self.tile_m, self.tile_hdim),
127
  )
128
- if const_expr(self.arch in [80, 90]):
129
  assert self.num_threads == tiled_mma.size
130
  return tiled_mma
131
 
@@ -148,22 +148,22 @@ class FlashAttentionBackwardPostprocess:
148
  cute.make_layout(self.num_threads),
149
  cute.make_layout(async_copy_elems_accum),
150
  )
151
- num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4
152
- if const_expr(self.arch == 80):
153
  self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
154
  Float32, self.num_threads, num_s2r_copy_elems
155
  )
156
  self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
157
- elif const_expr(self.arch == 90):
158
  num_threads_per_warp_group = 128
159
- num_mma_warp_groups = self.num_threads // 128
160
  self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
161
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
162
- cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout
163
  cute.make_layout(128 // Float32.width), # val_layout
164
  )
165
  self.sdQaccum_layout = cute.make_layout(
166
- (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups)
167
  )
168
  else:
169
  self.dQ_reduce_ncol = 32
@@ -188,14 +188,18 @@ class FlashAttentionBackwardPostprocess:
188
  # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
189
  # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
190
  mma_shape_n = self.tiled_mma.get_tile_size(1)
191
- if const_expr(self.arch == 80):
192
  sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
193
  self.sdQ_layout = cute.tile_to_shape(
194
  sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
195
  )
196
- elif const_expr(self.arch == 90):
 
197
  self.sdQ_layout = sm90_utils.make_smem_layout(
198
- self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim)
 
 
 
199
  )
200
  else:
201
  # TODO: this is hard-coded for hdim 128
@@ -211,7 +215,8 @@ class FlashAttentionBackwardPostprocess:
211
  scale: cutlass.Float32,
212
  mCuSeqlensQ: Optional[cute.Tensor],
213
  mSeqUsedQ: Optional[cute.Tensor],
214
- stream: cuda.CUstream,
 
215
  ):
216
  # Get the data type and check if it is fp16 or bf16
217
  if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
@@ -305,7 +310,7 @@ class FlashAttentionBackwardPostprocess:
305
  smem = cutlass.utils.SmemAllocator()
306
  sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
307
  sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
308
- if const_expr(self.arch in [80, 90]):
309
  sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
310
  else:
311
  # extra stage dimension
@@ -343,10 +348,7 @@ class FlashAttentionBackwardPostprocess:
343
  mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
344
  head_dim = mdQ.shape[3]
345
  else:
346
- if cutlass.const_expr(self.arch >= 90):
347
- padded_offset_q = seqlen.padded_offset_q
348
- else:
349
- padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
350
  mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
351
  mdQaccum_cur = cute.domain_offset(
352
  (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
@@ -371,7 +373,7 @@ class FlashAttentionBackwardPostprocess:
371
  seqlen_q = seqlen.seqlen_q
372
  seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
373
 
374
- if const_expr(self.arch == 100 and self.use_2cta_instrs):
375
  # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
376
  num_reduce_threads = self.num_threads
377
  thr_mma_dsk = tiled_mma.get_slice(tidx)
@@ -502,7 +504,7 @@ class FlashAttentionBackwardPostprocess:
502
  tile_shape = (self.tile_m, self.tile_hdim)
503
  acc = None
504
  tiled_copy_t2r = None
505
- if const_expr(self.arch in [80, 90]):
506
  acc_shape = tiled_mma.partition_shape_C(
507
  tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
508
  )
@@ -531,7 +533,7 @@ class FlashAttentionBackwardPostprocess:
531
 
532
  # Step 3: Copy dQ from register to smem
533
  cute.arch.barrier() # make sure all threads have finished loading dQaccum
534
- if const_expr(self.arch in [80, 90]):
535
  copy_atom_r2s_dQ = utils.get_smem_store_atom(
536
  self.arch, self.dtype, transpose=self.dQ_swapAB
537
  )
@@ -553,7 +555,7 @@ class FlashAttentionBackwardPostprocess:
553
  )
554
  thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
555
  cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
556
- if const_expr(self.arch in [80, 90]):
557
  taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
558
  else:
559
  taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
 
2
  # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
3
  # from Cutlass C++ to Cute-DSL.
4
  import math
5
+ from typing import Callable, Optional, Type
6
 
7
  import cuda.bindings.driver as cuda
8
 
 
36
  self,
37
  dtype: Type[cutlass.Numeric],
38
  head_dim: int,
39
+ arch: int,
40
  tile_m: int = 128,
41
  num_threads: int = 256,
42
  AtomLayoutMdQ: int = 1,
 
52
  """
53
  self.dtype = dtype
54
  self.tile_m = tile_m
55
+ assert arch // 10 in [8, 9, 10, 11, 12], (
56
+ "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported"
57
  )
58
  self.arch = arch
59
  # padding head_dim to a multiple of 32 as k_block_size
 
63
  self.num_threads = num_threads
64
  self.AtomLayoutMdQ = AtomLayoutMdQ
65
  self.dQ_swapAB = dQ_swapAB
66
+ self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64
67
  self.cluster_size = cluster_size
68
 
69
  @staticmethod
 
89
  return True
90
 
91
  def _get_tiled_mma(self):
92
+ if const_expr(self.arch // 10 in [8, 12]):
93
  num_mma_warps = self.num_threads // 32
94
  atom_layout_dQ = (
95
  (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)
 
101
  atom_layout_dQ,
102
  permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),
103
  )
104
+ elif const_expr(self.arch // 10 == 9):
105
+ num_wg_mma = self.num_threads // 128
106
+ atom_layout_dQ = (self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ)
107
  tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
108
  tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(
109
  self.dtype,
 
125
  cta_group,
126
  (self.tile_m, self.tile_hdim),
127
  )
128
+ if const_expr(self.arch // 10 in [8, 9, 12]):
129
  assert self.num_threads == tiled_mma.size
130
  return tiled_mma
131
 
 
148
  cute.make_layout(self.num_threads),
149
  cute.make_layout(async_copy_elems_accum),
150
  )
151
+ num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4
152
+ if const_expr(self.arch // 10 in [8, 12]):
153
  self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
154
  Float32, self.num_threads, num_s2r_copy_elems
155
  )
156
  self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)
157
+ elif const_expr(self.arch // 10 == 9):
158
  num_threads_per_warp_group = 128
159
+ num_wg_mma = self.num_threads // 128
160
  self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
161
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
162
+ cute.make_layout((num_threads_per_warp_group, num_wg_mma)), # thr_layout
163
  cute.make_layout(128 // Float32.width), # val_layout
164
  )
165
  self.sdQaccum_layout = cute.make_layout(
166
+ (self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma)
167
  )
168
  else:
169
  self.dQ_reduce_ncol = 32
 
188
  # then setting kBlockKSmem to 32 will cause "Static shape_div failure".
189
  # We want to treat it as 64 x 48, so kBlockKSmem should be 16.
190
  mma_shape_n = self.tiled_mma.get_tile_size(1)
191
+ if const_expr(self.arch // 10 in [8, 12]):
192
  sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)
193
  self.sdQ_layout = cute.tile_to_shape(
194
  sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)
195
  )
196
+ elif const_expr(self.arch // 10 == 9):
197
+ wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ
198
  self.sdQ_layout = sm90_utils.make_smem_layout(
199
+ self.dtype,
200
+ LayoutEnum.ROW_MAJOR,
201
+ (self.tile_m, self.tile_hdim),
202
+ major_mode_size=self.tile_hdim // wg_d_dQ,
203
  )
204
  else:
205
  # TODO: this is hard-coded for hdim 128
 
215
  scale: cutlass.Float32,
216
  mCuSeqlensQ: Optional[cute.Tensor],
217
  mSeqUsedQ: Optional[cute.Tensor],
218
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
219
+ stream: cuda.CUstream = None,
220
  ):
221
  # Get the data type and check if it is fp16 or bf16
222
  if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):
 
310
  smem = cutlass.utils.SmemAllocator()
311
  sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)
312
  sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))
313
+ if const_expr(self.arch // 10 in [8, 9, 12]):
314
  sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)
315
  else:
316
  # extra stage dimension
 
348
  mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
349
  head_dim = mdQ.shape[3]
350
  else:
351
+ padded_offset_q = seqlen.padded_offset_q
 
 
 
352
  mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
353
  mdQaccum_cur = cute.domain_offset(
354
  (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
 
373
  seqlen_q = seqlen.seqlen_q
374
  seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
375
 
376
+ if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs):
377
  # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
378
  num_reduce_threads = self.num_threads
379
  thr_mma_dsk = tiled_mma.get_slice(tidx)
 
504
  tile_shape = (self.tile_m, self.tile_hdim)
505
  acc = None
506
  tiled_copy_t2r = None
507
+ if const_expr(self.arch // 10 in [8, 9, 12]):
508
  acc_shape = tiled_mma.partition_shape_C(
509
  tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]
510
  )
 
533
 
534
  # Step 3: Copy dQ from register to smem
535
  cute.arch.barrier() # make sure all threads have finished loading dQaccum
536
+ if const_expr(self.arch // 10 in [8, 9, 12]):
537
  copy_atom_r2s_dQ = utils.get_smem_store_atom(
538
  self.arch, self.dtype, transpose=self.dQ_swapAB
539
  )
 
555
  )
556
  thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)
557
  cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))
558
+ if const_expr(self.arch // 10 in [8, 9, 12]):
559
  taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)
560
  else:
561
  taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape
build/torch-cuda/flash_bwd_preprocess.py CHANGED
@@ -1,21 +1,32 @@
1
  # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
  # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
3
  # from Cutlass C++ to Cute-DSL.
 
 
 
 
 
 
 
 
 
 
4
  import math
5
  import operator
6
- from typing import Callable, Type, Optional, Literal
 
7
 
8
  import cuda.bindings.driver as cuda
9
 
10
  import cutlass
11
  import cutlass.cute as cute
12
- from cutlass import Float32
 
13
 
14
- from .quack import copy_utils
15
 
16
  from . import utils
17
- from .cute_dsl_utils import assume_tensor_aligned
18
- from .seqlen_info import SeqlenInfoQK
19
  from .quack.cute_dsl_utils import ParamsBase
20
  from .tile_scheduler import (
21
  SingleTileScheduler,
@@ -30,9 +41,8 @@ class FlashAttentionBackwardPreprocess:
30
  dtype: Type[cutlass.Numeric],
31
  head_dim: int,
32
  head_dim_v: int,
33
- arch: Literal[80, 90, 100],
34
- m_block_size: int = 128,
35
- num_threads: int = 128,
36
  ):
37
  """
38
  All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
@@ -40,14 +50,14 @@ class FlashAttentionBackwardPreprocess:
40
 
41
  :param head_dim: head dimension
42
  :type head_dim: int
43
- :param m_block_size: m block size
44
- :type m_block_size: int
45
  :param num_threads: number of threads
46
  :type num_threads: int
47
  """
 
48
  self.dtype = dtype
49
- self.m_block_size = m_block_size
50
- self.arch = arch
51
  # padding head_dim to a multiple of 32 as k_block_size
52
  hdim_multiple_of = 32
53
  self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
@@ -56,15 +66,15 @@ class FlashAttentionBackwardPreprocess:
56
  self.num_threads = num_threads
57
 
58
  @staticmethod
59
- def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
60
  """Check if the kernel can be implemented with the given parameters.
61
 
62
  :param dtype: data type
63
  :type dtype: cutlass.Numeric
64
  :param head_dim: head dimension
65
  :type head_dim: int
66
- :param m_block_size: m block size
67
- :type m_block_size: int
68
  :param num_threads: number of threads
69
  :type num_threads: int
70
 
@@ -77,7 +87,7 @@ class FlashAttentionBackwardPreprocess:
77
  return False
78
  if num_threads % 32 != 0:
79
  return False
80
- if num_threads < m_block_size: # For multiplying lse with log2
81
  return False
82
  return True
83
 
@@ -105,7 +115,7 @@ class FlashAttentionBackwardPreprocess:
105
  universal_copy_bits = 128
106
  num_copy_elems_dQaccum = universal_copy_bits // Float32.width
107
  assert (
108
- self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum
109
  ) % self.num_threads == 0
110
  self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
111
  Float32, self.num_threads, num_copy_elems_dQaccum
@@ -114,38 +124,53 @@ class FlashAttentionBackwardPreprocess:
114
  @cute.jit
115
  def __call__(
116
  self,
117
- mO: cute.Tensor,
118
- mdO: cute.Tensor,
119
- mdPsum: cute.Tensor,
120
- mLSE: Optional[cute.Tensor],
121
- mLSElog2: Optional[cute.Tensor],
 
122
  mdQaccum: Optional[cute.Tensor],
123
- mCuSeqlensQ: Optional[cute.Tensor],
124
- mSeqUsedQ: Optional[cute.Tensor],
125
- stream: cuda.CUstream,
 
 
126
  ):
127
  # Get the data type and check if it is fp16 or bf16
128
- if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
129
  raise TypeError("All tensors must have the same data type")
130
- if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
131
  raise TypeError("Only Float16 or BFloat16 is supported")
132
- if cutlass.const_expr(mdPsum.element_type not in [Float32]):
133
- raise TypeError("dPsum tensor must be Float32")
134
- if cutlass.const_expr(mdQaccum is not None):
135
- if cutlass.const_expr(mdQaccum.element_type not in [Float32]):
136
  raise TypeError("dQaccum tensor must be Float32")
137
- if cutlass.const_expr(mLSE is not None):
138
  assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
139
- if cutlass.const_expr(mLSE.element_type not in [Float32]):
140
  raise TypeError("LSE tensor must be Float32")
141
- if cutlass.const_expr(mLSElog2.element_type not in [Float32]):
142
  raise TypeError("LSElog2 tensor must be Float32")
143
-
144
- mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)]
 
145
 
146
  self._setup_attributes()
147
 
148
- if cutlass.const_expr(mCuSeqlensQ is not None):
 
 
 
 
 
 
 
 
 
 
 
149
  TileScheduler = SingleTileVarlenScheduler
150
  num_head = mO.shape[1]
151
  num_batch = mCuSeqlensQ.shape[0] - 1
@@ -155,7 +180,7 @@ class FlashAttentionBackwardPreprocess:
155
  num_batch = mO.shape[0]
156
 
157
  tile_sched_args = TileSchedulerArguments(
158
- num_block=cute.ceil_div(mO.shape[1], self.m_block_size),
159
  num_head=num_head,
160
  num_batch=num_batch,
161
  num_splits=1,
@@ -163,7 +188,7 @@ class FlashAttentionBackwardPreprocess:
163
  headdim=0,
164
  headdim_v=mO.shape[2],
165
  total_q=mO.shape[0],
166
- tile_shape_mn=(self.m_block_size, 1),
167
  mCuSeqlensQ=mCuSeqlensQ,
168
  mSeqUsedQ=mSeqUsedQ,
169
  )
@@ -174,12 +199,13 @@ class FlashAttentionBackwardPreprocess:
174
  self.kernel(
175
  mO,
176
  mdO,
177
- mdPsum,
178
  mLSE,
179
  mLSElog2,
180
  mdQaccum,
181
  mCuSeqlensQ,
182
  mSeqUsedQ,
 
183
  self.gmem_tiled_copy_O,
184
  self.gmem_tiled_copy_dQaccum,
185
  tile_sched_params,
@@ -188,6 +214,7 @@ class FlashAttentionBackwardPreprocess:
188
  grid=grid_dim,
189
  block=[self.num_threads, 1, 1],
190
  stream=stream,
 
191
  )
192
 
193
  @cute.kernel
@@ -195,12 +222,13 @@ class FlashAttentionBackwardPreprocess:
195
  self,
196
  mO: cute.Tensor,
197
  mdO: cute.Tensor,
198
- mdPsum: cute.Tensor,
199
  mLSE: Optional[cute.Tensor],
200
  mLSElog2: Optional[cute.Tensor],
201
  mdQaccum: Optional[cute.Tensor],
202
  mCuSeqlensQ: Optional[cute.Tensor],
203
  mSeqUsedQ: Optional[cute.Tensor],
 
204
  gmem_tiled_copy_O: cute.TiledCopy,
205
  gmem_tiled_copy_dQaccum: cute.TiledCopy,
206
  tile_sched_params: ParamsBase,
@@ -217,145 +245,106 @@ class FlashAttentionBackwardPreprocess:
217
  # ///////////////////////////////////////////////////////////////////////////////
218
  # Get the appropriate tiles for this thread block.
219
  # ///////////////////////////////////////////////////////////////////////////////
220
- seqlen = SeqlenInfoQK.create(
221
- batch_idx,
222
- mO.shape[1],
223
- 0,
224
- mCuSeqlensQ=mCuSeqlensQ,
225
- mCuSeqlensK=None,
226
- mSeqUsedQ=mSeqUsedQ,
227
- mSeqUsedK=None,
228
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
231
- mO_cur = mO[batch_idx, None, head_idx, None]
232
- mdO_cur = mdO[batch_idx, None, head_idx, None]
233
- mdPsum_cur = mdPsum[batch_idx, head_idx, None]
234
- headdim_v = mO.shape[3]
235
- else:
236
- mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
237
- mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
238
-
239
- padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
240
- if cutlass.const_expr(self.arch >= 90):
241
- padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
242
- mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
243
- headdim_v = mO.shape[2]
244
-
245
- blkOdO_shape = (self.m_block_size, self.head_dim_v_padded)
246
- # (m_block_size, head_dim_v)
247
- gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0))
248
- gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0))
249
-
250
  gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
251
  # (CPY_Atom, CPY_M, CPY_K)
252
  tOgO = gmem_thr_copy_O.partition_S(gO)
253
  tOgdO = gmem_thr_copy_O.partition_S(gdO)
254
-
255
- # ///////////////////////////////////////////////////////////////////////////////
256
- # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
257
- # of tile_shape
258
- # ///////////////////////////////////////////////////////////////////////////////
259
- # Construct identity layout for KV
260
- cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
261
  tOcO = gmem_thr_copy_O.partition_S(cO)
262
  t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
263
- tOpO = utils.predicate_k(tOcO, limit=headdim_v)
264
- tOpdO = utils.predicate_k(tOcO, limit=headdim_v)
265
-
266
- seqlen_q = seqlen.seqlen_q
267
- seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)
268
-
269
- if cutlass.const_expr(mLSE is not None):
270
- if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
271
- mLSE_cur = mLSE[batch_idx, head_idx, None]
272
- else:
273
- mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])
274
-
275
- gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
276
- lse = Float32.inf
277
- if tidx < seqlen_q - m_block * self.m_block_size:
278
- lse = gLSE[tidx]
279
-
280
- tOrO = cute.make_fragment_like(tOgO)
281
- tOrdO = cute.make_fragment_like(tOgdO)
282
- assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
283
- assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
284
- assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
285
  for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
286
- # Instead of using tOcO, we using t0OcO and subtract the offset from the limit
287
- # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
288
- if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
289
- cute.copy(
290
- gmem_thr_copy_O,
291
- tOgO[None, m, None],
292
- tOrO[None, m, None],
293
- pred=tOpO[None, m, None]
294
- if cutlass.const_expr(self.check_hdim_v_oob)
295
- else None,
296
- )
297
- cute.copy(
298
- gmem_thr_copy_O,
299
- tOgdO[None, m, None],
300
- tOrdO[None, m, None],
301
- pred=tOpdO[None, m, None]
302
- if cutlass.const_expr(self.check_hdim_v_oob)
303
- else None,
304
- )
305
  # Sum across the "k" dimension
306
- dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
307
  cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
308
  )
309
  threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
310
  assert cute.arch.WARP_SIZE % threads_per_row == 0
311
- dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row)
312
- dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32)
313
- dP_sum.store(dpsum)
314
-
315
- # Write dPsum from rmem -> gmem
316
- gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,))
317
- # Only the thread corresponding to column 0 writes out the dPsum to gmem
 
 
 
 
 
 
318
  if tOcO[0, 0, 0][1] == 0:
319
- for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
320
  row = tOcO[0, m, 0][0]
321
- gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0
 
 
 
 
 
322
 
323
  # Clear dQaccum
324
- if cutlass.const_expr(mdQaccum is not None):
325
- if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
326
- mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
327
- else:
328
- mdQaccum_cur = cute.domain_offset(
329
- (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
330
- )
331
-
332
- # HACK: Compiler doesn't seem to recognize that padding
333
- # by padded_offset_q * self.head_dim_padded keeps alignment
334
- # since statically divisible by 4
335
-
336
- mdQaccum_cur_ptr = cute.make_ptr(
337
- dtype=mdQaccum_cur.element_type,
338
- value=mdQaccum_cur.iterator.toint(),
339
- mem_space=mdQaccum_cur.iterator.memspace,
340
- assumed_align=mdQaccum.iterator.alignment,
341
- )
342
- mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)
343
-
344
- blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
345
  gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
346
  gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
347
  tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
348
- zero = cute.make_fragment_like(tdQgdQaccum)
349
  zero.fill(0.0)
350
  cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)
351
 
352
- if cutlass.const_expr(mLSE is not None):
353
- if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
354
- mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
355
- else:
356
- mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])
357
-
358
- gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
359
  LOG2_E = math.log2(math.e)
360
- if tidx < seqlen_q_rounded - m_block * self.m_block_size:
361
  gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
 
1
  # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
  # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
3
  # from Cutlass C++ to Cute-DSL.
4
+ #
5
+ # Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient:
6
+ # D'_i = D_i - dLSE_i
7
+ # This works because in the backward pass:
8
+ # dS_ij = P_ij * (dP_ij - D_i) [standard]
9
+ # When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij
10
+ # (since d(LSE_i)/d(S_ij) = P_ij), giving:
11
+ # dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij
12
+ # = P_ij * (dP_ij - (D_i - dLSE_i))
13
+ # So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here.
14
  import math
15
  import operator
16
+ from functools import partial
17
+ from typing import Callable, Type, Optional
18
 
19
  import cuda.bindings.driver as cuda
20
 
21
  import cutlass
22
  import cutlass.cute as cute
23
+ from cutlass import Float32, const_expr
24
+ from cutlass.cutlass_dsl import Arch, BaseDSL
25
 
26
+ from .quack import copy_utils, layout_utils
27
 
28
  from . import utils
29
+ from .seqlen_info import SeqlenInfo
 
30
  from .quack.cute_dsl_utils import ParamsBase
31
  from .tile_scheduler import (
32
  SingleTileScheduler,
 
41
  dtype: Type[cutlass.Numeric],
42
  head_dim: int,
43
  head_dim_v: int,
44
+ tile_m: int = 128,
45
+ num_threads: int = 256,
 
46
  ):
47
  """
48
  All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
 
50
 
51
  :param head_dim: head dimension
52
  :type head_dim: int
53
+ :param tile_m: m block size
54
+ :type tile_m: int
55
  :param num_threads: number of threads
56
  :type num_threads: int
57
  """
58
+ self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a
59
  self.dtype = dtype
60
+ self.tile_m = tile_m
 
61
  # padding head_dim to a multiple of 32 as k_block_size
62
  hdim_multiple_of = 32
63
  self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
 
66
  self.num_threads = num_threads
67
 
68
  @staticmethod
69
+ def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:
70
  """Check if the kernel can be implemented with the given parameters.
71
 
72
  :param dtype: data type
73
  :type dtype: cutlass.Numeric
74
  :param head_dim: head dimension
75
  :type head_dim: int
76
+ :param tile_m: m block size
77
+ :type tile_m: int
78
  :param num_threads: number of threads
79
  :type num_threads: int
80
 
 
87
  return False
88
  if num_threads % 32 != 0:
89
  return False
90
+ if num_threads < tile_m: # For multiplying lse with log2
91
  return False
92
  return True
93
 
 
115
  universal_copy_bits = 128
116
  num_copy_elems_dQaccum = universal_copy_bits // Float32.width
117
  assert (
118
+ self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum
119
  ) % self.num_threads == 0
120
  self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(
121
  Float32, self.num_threads, num_copy_elems_dQaccum
 
124
  @cute.jit
125
  def __call__(
126
  self,
127
+ mO: cute.Tensor, # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v)
128
+ mdO: cute.Tensor, # same shape as mO
129
+ mPdPsum: cute.Tensor, # (batch, nheads, seqlen_padded) or (nheads, total_q_padded)
130
+ mLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q)
131
+ mLSElog2: Optional[cute.Tensor], # same shape as mPdPsum
132
+ # (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v)
133
  mdQaccum: Optional[cute.Tensor],
134
+ mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,)
135
+ mSeqUsedQ: Optional[cute.Tensor], # (batch,)
136
+ mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q)
137
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
138
+ stream: cuda.CUstream = None,
139
  ):
140
  # Get the data type and check if it is fp16 or bf16
141
+ if const_expr(not (mO.element_type == mdO.element_type)):
142
  raise TypeError("All tensors must have the same data type")
143
+ if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):
144
  raise TypeError("Only Float16 or BFloat16 is supported")
145
+ if const_expr(mPdPsum.element_type not in [Float32]):
146
+ raise TypeError("PdPsum tensor must be Float32")
147
+ if const_expr(mdQaccum is not None):
148
+ if const_expr(mdQaccum.element_type not in [Float32]):
149
  raise TypeError("dQaccum tensor must be Float32")
150
+ if const_expr(mLSE is not None):
151
  assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
152
+ if const_expr(mLSE.element_type not in [Float32]):
153
  raise TypeError("LSE tensor must be Float32")
154
+ if const_expr(mLSElog2.element_type not in [Float32]):
155
  raise TypeError("LSElog2 tensor must be Float32")
156
+ if const_expr(mdLSE is not None):
157
+ if const_expr(mdLSE.element_type not in [Float32]):
158
+ raise TypeError("dLSE tensor must be Float32")
159
 
160
  self._setup_attributes()
161
 
162
+ # (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (total_q, nheads) -> (nheads, total_q)
163
+ transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
164
+ mPdPsum = layout_utils.select(mPdPsum, transpose)
165
+ if const_expr(mLSE is not None):
166
+ mLSE = layout_utils.select(mLSE, transpose)
167
+ mLSElog2 = layout_utils.select(mLSElog2, transpose)
168
+ if const_expr(mdLSE is not None):
169
+ mdLSE = layout_utils.select(mdLSE, transpose)
170
+ if const_expr(mdQaccum is not None):
171
+ mdQaccum = layout_utils.select(mdQaccum, transpose)
172
+
173
+ if const_expr(mCuSeqlensQ is not None):
174
  TileScheduler = SingleTileVarlenScheduler
175
  num_head = mO.shape[1]
176
  num_batch = mCuSeqlensQ.shape[0] - 1
 
180
  num_batch = mO.shape[0]
181
 
182
  tile_sched_args = TileSchedulerArguments(
183
+ num_block=cute.ceil_div(mO.shape[1], self.tile_m),
184
  num_head=num_head,
185
  num_batch=num_batch,
186
  num_splits=1,
 
188
  headdim=0,
189
  headdim_v=mO.shape[2],
190
  total_q=mO.shape[0],
191
+ tile_shape_mn=(self.tile_m, 1),
192
  mCuSeqlensQ=mCuSeqlensQ,
193
  mSeqUsedQ=mSeqUsedQ,
194
  )
 
199
  self.kernel(
200
  mO,
201
  mdO,
202
+ mPdPsum,
203
  mLSE,
204
  mLSElog2,
205
  mdQaccum,
206
  mCuSeqlensQ,
207
  mSeqUsedQ,
208
+ mdLSE,
209
  self.gmem_tiled_copy_O,
210
  self.gmem_tiled_copy_dQaccum,
211
  tile_sched_params,
 
214
  grid=grid_dim,
215
  block=[self.num_threads, 1, 1],
216
  stream=stream,
217
+ use_pdl=self.use_pdl,
218
  )
219
 
220
  @cute.kernel
 
222
  self,
223
  mO: cute.Tensor,
224
  mdO: cute.Tensor,
225
+ mPdPsum: cute.Tensor,
226
  mLSE: Optional[cute.Tensor],
227
  mLSElog2: Optional[cute.Tensor],
228
  mdQaccum: Optional[cute.Tensor],
229
  mCuSeqlensQ: Optional[cute.Tensor],
230
  mSeqUsedQ: Optional[cute.Tensor],
231
+ mdLSE: Optional[cute.Tensor],
232
  gmem_tiled_copy_O: cute.TiledCopy,
233
  gmem_tiled_copy_dQaccum: cute.TiledCopy,
234
  tile_sched_params: ParamsBase,
 
245
  # ///////////////////////////////////////////////////////////////////////////////
246
  # Get the appropriate tiles for this thread block.
247
  # ///////////////////////////////////////////////////////////////////////////////
248
+ seqlen = SeqlenInfo.create(
249
+ batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m
 
 
 
 
 
 
250
  )
251
+ mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None]
252
+ mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None]
253
+ mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[
254
+ None, head_idx
255
+ ]
256
+ headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1]
257
+ seqlen_q = seqlen.seqlen
258
+ seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)
259
+ seqlen_limit = seqlen_q - m_block * self.tile_m
260
+
261
+ lse = None
262
+ if const_expr(mLSE is not None):
263
+ mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx]
264
+ gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
265
+ lse = Float32.inf
266
+ if tidx < seqlen_limit:
267
+ lse = gLSE[tidx]
268
 
269
+ blk_shape = (self.tile_m, self.head_dim_v_padded)
270
+ gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0))
271
+ gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
273
  # (CPY_Atom, CPY_M, CPY_K)
274
  tOgO = gmem_thr_copy_O.partition_S(gO)
275
  tOgdO = gmem_thr_copy_O.partition_S(gdO)
276
+ cO = cute.make_identity_tensor(blk_shape)
 
 
 
 
 
 
277
  tOcO = gmem_thr_copy_O.partition_S(cO)
278
  t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)
279
+ tOpO = None
280
+ if const_expr(self.check_hdim_v_oob):
281
+ tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v)
282
+ # Each copy will use the same predicate
283
+ copy = partial(copy_utils.copy, pred=tOpO)
284
+
285
+ tOrO = cute.make_rmem_tensor_like(tOgO)
286
+ tOrdO = cute.make_rmem_tensor_like(tOgdO)
287
+ if const_expr(self.check_hdim_v_oob):
288
+ tOrO.fill(0.0)
289
+ tOrdO.fill(0.0)
290
+ assert tOgO.shape == tOgdO.shape
 
 
 
 
 
 
 
 
 
 
291
  for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
292
+ # Instead of using tOcO, we using t0OcO and subtract the offset from the limit.
293
+ # This is bc the entries of t0OcO are known at compile time.
294
+ if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]:
295
+ copy(tOgO[None, m, None], tOrO[None, m, None])
296
+ copy(tOgdO[None, m, None], tOrdO[None, m, None])
297
+ # O and dO loads are done; signal that the next kernel can start.
298
+ # Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs.
299
+ if const_expr(self.use_pdl):
300
+ cute.arch.griddepcontrol_launch_dependents()
 
 
 
 
 
 
 
 
 
 
301
  # Sum across the "k" dimension
302
+ pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(
303
  cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
304
  )
305
  threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]
306
  assert cute.arch.WARP_SIZE % threads_per_row == 0
307
+ pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row)
308
+ PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32)
309
+ PdP_sum.store(pdpsum)
310
+
311
+ # If dLSE is provided, compute D' = D - dLSE (see module docstring for derivation).
312
+ gdLSE = None
313
+ if const_expr(mdLSE is not None):
314
+ mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx]
315
+ gdLSE = cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,))
316
+
317
+ # Write PdPsum from rmem -> gmem
318
+ gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,))
319
+ # Only the thread corresponding to column 0 writes out the PdPsum to gmem
320
  if tOcO[0, 0, 0][1] == 0:
321
+ for m in cutlass.range(cute.size(PdP_sum), unroll_full=True):
322
  row = tOcO[0, m, 0][0]
323
+ PdPsum_val = 0.0
324
+ if row < seqlen_limit:
325
+ PdPsum_val = PdP_sum[m]
326
+ if const_expr(mdLSE is not None):
327
+ PdPsum_val -= gdLSE[row]
328
+ gPdPsum[row] = PdPsum_val
329
 
330
  # Clear dQaccum
331
+ if const_expr(mdQaccum is not None):
332
+ mdQaccum_cur = seqlen.offset_batch(
333
+ mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded
334
+ )[None, head_idx]
335
+ blkdQaccum_shape = (self.tile_m * self.head_dim_padded,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))
337
  gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
338
  tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
339
+ zero = cute.make_rmem_tensor_like(tdQgdQaccum)
340
  zero.fill(0.0)
341
  cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)
342
 
343
+ if const_expr(mLSE is not None):
344
+ mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[
345
+ None, head_idx
346
+ ]
347
+ gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,))
 
 
348
  LOG2_E = math.log2(math.e)
349
+ if tidx < seqlen_q_rounded - m_block * self.tile_m:
350
  gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0
build/torch-cuda/flash_bwd_sm100.py CHANGED
@@ -84,7 +84,6 @@ class FlashAttentionBackwardSm100:
84
  self.use_2cta_instrs = bool(
85
  use_2cta_instrs
86
  and cluster_size == 2
87
- and not is_local
88
  and score_mod is None
89
  and score_mod_bwd is None
90
  and mask_mod is None
@@ -453,7 +452,6 @@ class FlashAttentionBackwardSm100:
453
  mdK: cute.Tensor,
454
  mdV: cute.Tensor,
455
  softmax_scale: Float32,
456
- stream: cuda.CUstream,
457
  mCuSeqlensQ: Optional[cute.Tensor] = None,
458
  mCuSeqlensK: Optional[cute.Tensor] = None,
459
  mSeqUsedQ: Optional[cute.Tensor] = None,
@@ -467,6 +465,8 @@ class FlashAttentionBackwardSm100:
467
  aux_tensors: Optional[list] = None,
468
  # Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
469
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
 
 
470
  ):
471
  self.q_dtype = mQ.element_type
472
  self.k_dtype = mK.element_type
@@ -927,10 +927,6 @@ class FlashAttentionBackwardSm100:
927
  "2-CTA mode does not support block sparsity. "
928
  "Please create kernel with use_2cta_instrs=False for block sparse attention."
929
  )
930
- assert window_size_left is None and window_size_right is None, (
931
- "2-CTA mode does not support window attention. "
932
- "Please create kernel with use_2cta_instrs=False for window attention."
933
- )
934
  # 2-CTA: 231424 and 1-CTA: 232448
935
  # print("SMEM: ", self.shared_storage.size_in_bytes())
936
  if const_expr(self.use_block_sparsity or aux_tensors is not None):
@@ -3143,6 +3139,8 @@ class FlashAttentionBackwardSm100:
3143
  with cute.arch.elect_one():
3144
  pipeline_S_P.consumer_release(consumer_state_S_P_dP)
3145
  # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
 
 
3146
  pipeline_LSE.consumer_release(consumer_state_LSE)
3147
  consumer_state_LSE.advance()
3148
  # ---------------------------------------------
@@ -3253,6 +3251,8 @@ class FlashAttentionBackwardSm100:
3253
 
3254
  cute.arch.fence_view_async_shared()
3255
  self.compute_sync_barrier.arrive_and_wait()
 
 
3256
  pipeline_dPsum.consumer_release(consumer_state_dPsum)
3257
  consumer_state_dPsum.advance()
3258
  # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred
@@ -3650,6 +3650,9 @@ class FlashAttentionBackwardSm100:
3650
  tile_scheduler.advance_to_next_work()
3651
  work_tile = tile_scheduler.get_current_work()
3652
 
 
 
 
3653
  @cute.jit
3654
  def epilogue_dKV(
3655
  self,
 
84
  self.use_2cta_instrs = bool(
85
  use_2cta_instrs
86
  and cluster_size == 2
 
87
  and score_mod is None
88
  and score_mod_bwd is None
89
  and mask_mod is None
 
452
  mdK: cute.Tensor,
453
  mdV: cute.Tensor,
454
  softmax_scale: Float32,
 
455
  mCuSeqlensQ: Optional[cute.Tensor] = None,
456
  mCuSeqlensK: Optional[cute.Tensor] = None,
457
  mSeqUsedQ: Optional[cute.Tensor] = None,
 
465
  aux_tensors: Optional[list] = None,
466
  # Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
467
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
468
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
469
+ stream: cuda.CUstream = None,
470
  ):
471
  self.q_dtype = mQ.element_type
472
  self.k_dtype = mK.element_type
 
927
  "2-CTA mode does not support block sparsity. "
928
  "Please create kernel with use_2cta_instrs=False for block sparse attention."
929
  )
 
 
 
 
930
  # 2-CTA: 231424 and 1-CTA: 232448
931
  # print("SMEM: ", self.shared_storage.size_in_bytes())
932
  if const_expr(self.use_block_sparsity or aux_tensors is not None):
 
3139
  with cute.arch.elect_one():
3140
  pipeline_S_P.consumer_release(consumer_state_S_P_dP)
3141
  # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
3142
+ # Normally we'd need syncwarp here since only 1 thread will signal in
3143
+ # consumer_release, but we already have the self.compute_sync_barrier before this
3144
  pipeline_LSE.consumer_release(consumer_state_LSE)
3145
  consumer_state_LSE.advance()
3146
  # ---------------------------------------------
 
3251
 
3252
  cute.arch.fence_view_async_shared()
3253
  self.compute_sync_barrier.arrive_and_wait()
3254
+ # Normally we'd need syncwarp here since only 1 thread will signal in
3255
+ # consumer_release, but we already have the self.compute_sync_barrier before this
3256
  pipeline_dPsum.consumer_release(consumer_state_dPsum)
3257
  consumer_state_dPsum.advance()
3258
  # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred
 
3650
  tile_scheduler.advance_to_next_work()
3651
  work_tile = tile_scheduler.get_current_work()
3652
 
3653
+ if const_expr(not self.deterministic):
3654
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
3655
+
3656
  @cute.jit
3657
  def epilogue_dKV(
3658
  self,
build/torch-cuda/flash_bwd_sm120.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # SM120 (Blackwell GeForce / DGX Spark) backward pass.
3
+ #
4
+ # SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
5
+ # a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
6
+ # FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly.
7
+
8
+ import cutlass
9
+ import cutlass.utils as utils_basic
10
+
11
+ from .flash_bwd import FlashAttentionBackwardSm80
12
+
13
+
14
+ class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):
15
+ @staticmethod
16
+ def can_implement(
17
+ dtype,
18
+ head_dim,
19
+ head_dim_v,
20
+ m_block_size,
21
+ n_block_size,
22
+ num_stages_Q,
23
+ num_stages_dO,
24
+ num_threads,
25
+ is_causal,
26
+ V_in_regs=False,
27
+ ) -> bool:
28
+ """Check if the kernel can be implemented on SM120.
29
+
30
+ Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
31
+ """
32
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
33
+ return False
34
+ if head_dim % 8 != 0:
35
+ return False
36
+ if head_dim_v % 8 != 0:
37
+ return False
38
+ if n_block_size % 16 != 0:
39
+ return False
40
+ if num_threads % 32 != 0:
41
+ return False
42
+ # Shared memory usage: Q tile + dO tile + K tile + V tile
43
+ smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
44
+ smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
45
+ smem_usage_K = n_block_size * head_dim * 2
46
+ smem_usage_V = n_block_size * head_dim_v * 2
47
+ smem_usage_QV = (
48
+ (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
49
+ )
50
+ smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
51
+ # SM120 has 99 KB shared memory (vs 163 KB on SM80)
52
+ smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
53
+ if smem_usage > smem_capacity:
54
+ return False
55
+ return True
build/torch-cuda/flash_bwd_sm90.py CHANGED
@@ -24,7 +24,13 @@ from .seqlen_info import SeqlenInfoQK
24
  from .block_info import BlockInfo
25
  from . import pipeline
26
  from .quack.cute_dsl_utils import ParamsBase
27
- from .tile_scheduler import TileSchedulerArguments, SingleTileScheduler
 
 
 
 
 
 
28
  from .named_barrier import NamedBarrierBwd
29
  from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
30
  from .block_sparsity import BlockSparseTensors
@@ -46,6 +52,8 @@ class FlashAttentionBackwardSm90:
46
  head_dim_v: Optional[int] = None,
47
  qhead_per_kvhead: int = 1,
48
  is_causal: bool = False,
 
 
49
  tile_m: int = 64,
50
  tile_n: int = 128,
51
  Q_stage: int = 2,
@@ -64,6 +72,7 @@ class FlashAttentionBackwardSm90:
64
  mask_mod: cutlass.Constexpr | None = None,
65
  has_aux_tensors: cutlass.Constexpr = False,
66
  subtile_factor: cutlass.Constexpr[int] = 1,
 
67
  ):
68
  self.dtype = dtype
69
  # padding head_dim to a multiple of 16 as k_block_size
@@ -77,7 +86,8 @@ class FlashAttentionBackwardSm90:
77
  self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
78
  self.qhead_per_kvhead = qhead_per_kvhead
79
  self.is_causal = is_causal
80
- self.is_local = False
 
81
  self.tile_m = tile_m
82
  self.tile_n = tile_n
83
  self.num_threads = num_threads
@@ -92,23 +102,23 @@ class FlashAttentionBackwardSm90:
92
  self.AtomLayoutMSdP = AtomLayoutMSdP
93
  self.AtomLayoutNdKV = AtomLayoutNdKV
94
  self.AtomLayoutMdQ = AtomLayoutMdQ
95
- self.num_mma_warp_groups = (self.num_threads // 128) - 1
96
  self.mma_dkv_is_rs = (
97
  AtomLayoutMSdP == 1
98
- and AtomLayoutNdKV == self.num_mma_warp_groups
99
  and SdP_swapAB
100
  and not dKV_swapAB
101
  )
102
  self.V_in_regs = V_in_regs
 
103
  if qhead_per_kvhead > 1:
104
  assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
105
- assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups"
106
  # These are tuned for speed
107
  # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
108
  # them and then shuffle to get the value whenever we need? This can reduce register
109
  # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
110
  # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
111
- # TODO: impl these for hdim 64
112
  self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
113
  self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
114
 
@@ -124,6 +134,12 @@ class FlashAttentionBackwardSm90:
124
  else:
125
  self.vec_size: cutlass.Constexpr = 4
126
  self.qk_acc_dtype = Float32
 
 
 
 
 
 
127
 
128
  @staticmethod
129
  def can_implement(
@@ -182,32 +198,58 @@ class FlashAttentionBackwardSm90:
182
  assert mQ_type == self.dtype
183
 
184
  def _setup_attributes(self):
185
- self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [
186
- sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage)
187
- for shape, stage in [
188
- ((self.tile_m, self.tile_hdim), self.Q_stage),
189
- ((self.tile_n, self.tile_hdim), None),
190
- ((self.tile_n, self.tile_hdimv), None),
191
- ((self.tile_m, self.tile_hdimv), self.dO_stage),
192
- ((self.tile_m, self.tile_n), self.PdS_stage),
 
 
193
  ]
194
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  self.sdQaccum_layout = cute.make_layout(
196
- (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups)
197
  )
198
  # dQaccum R->S
199
  self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
200
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
201
  # thr_layout
202
- cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
203
  cute.make_layout(128 // Float32.width), # val_layout
204
  )
205
  # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
206
  # TODO: assert that sVaccum and sKaccum don't overflow smem
207
 
208
  def _get_tiled_mma(self):
 
209
  # S = Q @ K.T, dP = dO @ V.T
210
- atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP)
211
  tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
212
  tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
213
  self.dtype,
@@ -215,12 +257,11 @@ class FlashAttentionBackwardSm90:
215
  warpgroup.OperandMajorMode.K,
216
  warpgroup.OperandMajorMode.K,
217
  Float32,
218
- atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1])
219
- + (1,),
220
- tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1],
221
  )
222
  # dV = P.T @ dO, dK = dS.T @ Q
223
- atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV)
224
  tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
225
  tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
226
  tiled_mma_dK, tiled_mma_dV = [
@@ -232,9 +273,8 @@ class FlashAttentionBackwardSm90:
232
  else warpgroup.OperandMajorMode.K,
233
  warpgroup.OperandMajorMode.MN,
234
  Float32,
235
- atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1])
236
- + (1,),
237
- tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1],
238
  a_source=warpgroup.OperandSource.RMEM
239
  if self.mma_dkv_is_rs
240
  else warpgroup.OperandSource.SMEM,
@@ -242,7 +282,8 @@ class FlashAttentionBackwardSm90:
242
  for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
243
  ]
244
  # dQ = dS @ K
245
- atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ)
 
246
  tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
247
  tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
248
  self.dtype,
@@ -250,8 +291,8 @@ class FlashAttentionBackwardSm90:
250
  warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
251
  warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
252
  Float32,
253
- atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,),
254
- tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],
255
  )
256
  return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
257
 
@@ -305,7 +346,6 @@ class FlashAttentionBackwardSm90:
305
  mdK: cute.Tensor,
306
  mdV: cute.Tensor,
307
  softmax_scale: Float32,
308
- stream: cuda.CUstream,
309
  mCuSeqlensQ: Optional[cute.Tensor] = None,
310
  mCuSeqlensK: Optional[cute.Tensor] = None,
311
  mSeqUsedQ: Optional[cute.Tensor] = None,
@@ -318,10 +358,13 @@ class FlashAttentionBackwardSm90:
318
  mdV_semaphore: Optional[cute.Tensor] = None,
319
  aux_tensors: Optional[list] = None,
320
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
 
 
321
  ):
322
- assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (
323
- "determinism not supported yet for Sm90"
324
- )
 
325
 
326
  self._check_type(
327
  *(
@@ -330,23 +373,36 @@ class FlashAttentionBackwardSm90:
330
  )
331
  )
332
 
 
 
333
  mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
334
  assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
335
  ]
336
 
337
- layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
338
- mQ, mK, mV, mdO = [layout_utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)]
 
 
 
 
 
339
  if const_expr(self.qhead_per_kvhead == 1):
340
- mdK, mdV = [layout_utils.select(t, layout_transpose) for t in (mdK, mdV)]
341
  else:
342
- accum_transpose = [2, 1, 0] # (b, n, s*h) -> (s*h, n, b)
 
343
  mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
344
- LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b)
 
345
  mLSE, mdPsum, mdQaccum = [
346
  layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
347
  ]
348
 
349
  tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()
 
 
 
 
350
 
351
  self.num_mma_threads = tiled_mma_SdP.size
352
  assert self.num_mma_threads + 128 == self.num_threads
@@ -354,10 +410,25 @@ class FlashAttentionBackwardSm90:
354
  self.num_threads_per_warp_group = 128
355
  self.num_producer_threads = 32
356
 
357
- self.num_mma_regs = 240
358
- self.num_producer_regs = 24
359
- # self.num_mma_regs = 232
360
- # self.num_producer_regs = 40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  self._setup_attributes()
363
  SharedStorage = self._get_shared_storage_cls()
@@ -374,7 +445,7 @@ class FlashAttentionBackwardSm90:
374
  self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
375
  self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
376
  self.tma_copy_bytes["dQ"] = (
377
- self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups
378
  )
379
  self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
380
  self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8
@@ -404,38 +475,59 @@ class FlashAttentionBackwardSm90:
404
  (self.tile_m, self.tile_hdimv),
405
  )
406
  if const_expr(self.qhead_per_kvhead == 1):
 
 
 
 
 
 
 
 
 
 
407
  tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
408
  cpasync.CopyBulkTensorTileS2GOp(),
409
- mdK,
410
  cute.select(self.sK_layout, mode=[0, 1]),
411
  (self.tile_n, self.tile_hdim),
412
  )
413
  tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
414
  cpasync.CopyBulkTensorTileS2GOp(),
415
- mdV,
416
  cute.select(self.sV_layout, mode=[0, 1]),
417
  (self.tile_n, self.tile_hdimv),
418
  )
419
  else:
420
  tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None
421
 
422
- TileScheduler = SingleTileScheduler
 
 
 
 
 
 
423
  tile_sched_args = TileSchedulerArguments(
424
  cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
425
  cute.size(mQ.shape[2]),
426
- cute.size(mQ.shape[3]),
 
 
427
  1, # num_splits
428
- cute.size(mK.shape[0]),
429
- mQ.shape[1],
430
- mV.shape[1],
431
- total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
432
- tile_shape_mn=(self.tile_m, self.tile_n),
433
- mCuSeqlensQ=None,
434
- mSeqUsedQ=None,
 
 
435
  qhead_per_kvhead_packgqa=1,
436
  element_size=self.dtype.width // 8,
437
  is_persistent=False,
438
- lpt=False,
 
439
  )
440
 
441
  tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
@@ -461,6 +553,11 @@ class FlashAttentionBackwardSm90:
461
 
462
  self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
463
 
 
 
 
 
 
464
  self.kernel(
465
  tma_tensor_Q,
466
  tma_tensor_K,
@@ -477,6 +574,10 @@ class FlashAttentionBackwardSm90:
477
  mLSE,
478
  mdPsum,
479
  mdQaccum,
 
 
 
 
480
  self.sQ_layout,
481
  self.sK_layout,
482
  self.sV_layout,
@@ -497,11 +598,15 @@ class FlashAttentionBackwardSm90:
497
  fastdiv_mods,
498
  blocksparse_tensors,
499
  qhead_per_kvhead_divmod,
 
 
 
500
  ).launch(
501
  grid=grid_dim,
502
  block=[self.num_threads, 1, 1],
503
  stream=stream,
504
  min_blocks_per_mp=1,
 
505
  )
506
 
507
  @cute.kernel
@@ -522,6 +627,10 @@ class FlashAttentionBackwardSm90:
522
  mLSE: cute.Tensor,
523
  mdPsum: cute.Tensor,
524
  mdQaccum: cute.Tensor,
 
 
 
 
525
  sQ_layout: cute.ComposedLayout,
526
  sK_layout: cute.ComposedLayout,
527
  sV_layout: cute.ComposedLayout,
@@ -542,15 +651,17 @@ class FlashAttentionBackwardSm90:
542
  fastdiv_mods=(None, None),
543
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
544
  qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
 
 
 
545
  ):
546
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
547
 
548
  # prefetch TMA descriptors
549
  if warp_idx == 0:
550
- cpasync.prefetch_descriptor(tma_atom_Q)
551
- cpasync.prefetch_descriptor(tma_atom_K)
552
- cpasync.prefetch_descriptor(tma_atom_V)
553
- cpasync.prefetch_descriptor(tma_atom_dO)
554
 
555
  smem = cutlass.utils.SmemAllocator()
556
  storage = smem.allocate(SharedStorage)
@@ -604,25 +715,27 @@ class FlashAttentionBackwardSm90:
604
  self.is_causal,
605
  self.is_local,
606
  False, # is_split_kv
607
- None,
608
- None,
609
  qhead_per_kvhead_packgqa=1,
610
  )
611
  SeqlenInfoCls = partial(
612
  SeqlenInfoQK.create,
613
  seqlen_q_static=mQ.shape[0],
614
  seqlen_k_static=mK.shape[0],
615
- mCuSeqlensQ=None,
616
- mCuSeqlensK=None,
617
- mSeqUsedQ=None,
618
- mSeqUsedK=None,
 
 
619
  )
620
  AttentionMaskCls = partial(
621
  AttentionMask,
622
  self.tile_m,
623
  self.tile_n,
624
- window_size_left=None,
625
- window_size_right=None,
626
  swap_AB=self.SdP_swapAB,
627
  )
628
  TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
@@ -663,12 +776,12 @@ class FlashAttentionBackwardSm90:
663
  TileSchedulerCls,
664
  SeqlenInfoCls,
665
  blocksparse_tensors,
 
666
  )
667
  else:
668
- cute.arch.setmaxregister_increase(self.num_mma_regs)
669
  tidx, _, _ = cute.arch.thread_idx()
670
  tidx = tidx - 128
671
- self.mma(
672
  tiled_mma_SdP,
673
  tiled_mma_dK,
674
  tiled_mma_dV,
@@ -702,6 +815,19 @@ class FlashAttentionBackwardSm90:
702
  blocksparse_tensors,
703
  qhead_per_kvhead_divmod,
704
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
705
 
706
  @cute.jit
707
  def load(
@@ -749,18 +875,22 @@ class FlashAttentionBackwardSm90:
749
  if const_expr(self.qhead_per_kvhead == 1)
750
  else head_idx // qhead_per_kvhead_divmod
751
  )
752
- mK_cur = mK[None, None, head_idx_kv, batch_idx]
 
753
  gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
754
- mV_cur = mV[None, None, head_idx_kv, batch_idx]
755
  gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
756
 
757
- mQ_cur = mQ[None, None, head_idx, batch_idx]
 
 
 
 
 
 
 
758
  gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
759
- mdO_cur = mdO[None, None, head_idx, batch_idx]
760
  gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
761
- mLSE_cur = mLSE[None, head_idx, batch_idx]
762
  gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
763
- mdPsum_cur = mdPsum[None, head_idx, batch_idx]
764
  gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
765
 
766
  load_K, _, _ = copy_utils.tma_get_copy_fn(
@@ -786,7 +916,10 @@ class FlashAttentionBackwardSm90:
786
 
787
  if const_expr(not self.use_block_sparsity):
788
  total_m_block_cnt = m_block_max - m_block_min
789
- process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
 
 
 
790
  else:
791
  total_m_block_cnt = get_total_q_block_count_bwd(
792
  blocksparse_tensors,
@@ -806,6 +939,8 @@ class FlashAttentionBackwardSm90:
806
  )
807
  load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
808
  load_Q(first_m_block, producer_state=producer_state_Q)
 
 
809
  load_LSE(first_m_block, producer_state=producer_state_Q)
810
  producer_state_dO_cur = (
811
  producer_state_dO
@@ -984,16 +1119,20 @@ class FlashAttentionBackwardSm90:
984
  fastdiv_mods=(None, None),
985
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
986
  qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
 
987
  ):
988
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
989
  warp_group_thread_layout = cute.make_layout(
990
- self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
991
  )
992
  thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
993
  wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
994
  wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
995
  wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
996
- wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx))
 
 
 
997
  # S = Q @ K.T
998
  shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
999
  _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
@@ -1039,23 +1178,43 @@ class FlashAttentionBackwardSm90:
1039
  # dQ = dS @ K
1040
  sKt = layout_utils.transpose_view(sK)
1041
  shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
1042
- _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
1043
- wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
1044
- )
1045
- mma_dsk_fn = partial(
1046
- gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB
1047
- )
 
 
 
 
 
 
 
1048
 
1049
- # Smem copy atom tiling
1050
  copy_P_r2s = None
 
1051
  if const_expr(sP is not None):
1052
  sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
1053
  copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
1054
- tiled_mma_SdP, sP_cpy, tidx, self.arch, transpose=self.SdP_swapAB
 
 
 
 
 
 
1055
  )
1056
  sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
1057
  copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
1058
- tiled_mma_SdP, sdS_cpy, tidx, self.arch, transpose=self.SdP_swapAB
 
 
 
 
 
 
1059
  )
1060
 
1061
  tLSEsLSE = layout_utils.mma_partition_C_vec(
@@ -1064,9 +1223,21 @@ class FlashAttentionBackwardSm90:
1064
  tLSEsdPsum = layout_utils.mma_partition_C_vec(
1065
  sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
1066
  )
1067
-
1068
- smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
1069
- tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
 
 
 
 
 
 
 
 
 
 
 
 
1070
 
1071
  PdS_barrier = cutlass.pipeline.NamedBarrier(
1072
  barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
@@ -1105,6 +1276,7 @@ class FlashAttentionBackwardSm90:
1105
  PdS_barrier=PdS_barrier,
1106
  # acc_dV=acc_dV,
1107
  # acc_dK=acc_dK,
 
1108
  )
1109
 
1110
  consumer_state_Q = cutlass.pipeline.make_pipeline_state(
@@ -1136,7 +1308,10 @@ class FlashAttentionBackwardSm90:
1136
  m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
1137
 
1138
  if const_expr(not self.use_block_sparsity):
1139
- process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
 
 
 
1140
  else:
1141
  total_m_block_cnt = get_total_q_block_count_bwd(
1142
  blocksparse_tensors,
@@ -1218,8 +1393,8 @@ class FlashAttentionBackwardSm90:
1218
  qhead_per_kvhead_divmod,
1219
  )
1220
  else:
1221
- # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros.
1222
- if const_expr(self.use_block_sparsity):
1223
  acc_dK.fill(0.0)
1224
  acc_dV.fill(0.0)
1225
  self.epilogue_dKV(
@@ -1248,6 +1423,22 @@ class FlashAttentionBackwardSm90:
1248
  if warp_idx == 4:
1249
  cute.arch.cp_async_bulk_wait_group(0, read=True)
1250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1251
  @cute.jit
1252
  def mma_one_m_block(
1253
  self,
@@ -1266,16 +1457,17 @@ class FlashAttentionBackwardSm90:
1266
  pipeline_dO: cutlass.pipeline.PipelineAsync,
1267
  tLSEsLSE: cute.Tensor,
1268
  tLSEsdPsum: cute.Tensor,
1269
- tdQsdQaccum: cute.Tensor,
1270
  softmax_scale_log2: Float32,
1271
  PdS_barrier: cutlass.pipeline.NamedBarrier,
 
1272
  mask_fn: Optional[Callable] = None,
1273
  score_mod_fn: Optional[Callable] = None,
1274
  score_mod_bwd_fn: Optional[Callable] = None,
1275
  dKV_accumulate: Boolean = True,
1276
  ):
1277
  consumer_state_dO_cur = (
1278
- consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
1279
  )
1280
  smem_idx_Q = consumer_state_Q.index
1281
  smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
@@ -1283,6 +1475,7 @@ class FlashAttentionBackwardSm90:
1283
  # (1) [GEMM 1] S = Q @ K^T
1284
  pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
1285
  acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
 
1286
  tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
1287
  # (2) [GEMM 2] dP = dO @ V.T
1288
  pipeline_dO.consumer_wait(
@@ -1301,10 +1494,12 @@ class FlashAttentionBackwardSm90:
1301
  if cutlass.const_expr(mask_fn is not None):
1302
  mask_fn(acc_S, m_block=m_block)
1303
  acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
 
1304
  for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
 
1305
  for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
1306
  acc_S_mn[r, c] = cute.math.exp2(
1307
- acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True
1308
  )
1309
  tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
1310
 
@@ -1321,8 +1516,9 @@ class FlashAttentionBackwardSm90:
1321
  warpgroup.wait_group(0)
1322
  acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
1323
  for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
 
1324
  for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
1325
- acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
1326
 
1327
  if const_expr(self.score_mod_bwd is not None):
1328
  score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
@@ -1354,36 +1550,50 @@ class FlashAttentionBackwardSm90:
1354
  # smem fence to make sure sdS is written before it's read by WGMMA
1355
  cute.arch.fence_view_async_shared()
1356
  PdS_barrier.arrive_and_wait()
1357
- # (6) [GEMM 4] dQ = dS @ K
1358
- acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
1359
- # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
1360
- pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done
1361
 
1362
- # (7) [GEMM 5] dK += dS.T @ Q
1363
- if const_expr(not self.mma_dkv_is_rs):
1364
- mma_dsq_fn(
1365
- A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
1366
- )
1367
- else:
1368
- mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
1369
- # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ)
1370
 
1371
- cute.arch.barrier(
1372
- barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1373
- number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
1374
- )
1375
- tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
1376
- cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
1377
- cute.arch.fence_view_async_shared()
1378
- cute.arch.barrier_arrive(
1379
- barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1380
- number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
1381
- )
1382
 
1383
- warpgroup.wait_group(0)
1384
- # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK)
1385
- pipeline_Q.consumer_release(consumer_state_Q)
1386
- # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1387
 
1388
  consumer_state_Q.advance()
1389
  consumer_state_dO.advance()
@@ -1415,8 +1625,12 @@ class FlashAttentionBackwardSm90:
1415
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1416
 
1417
  if const_expr(self.qhead_per_kvhead == 1):
1418
- mdV_cur = mdV[None, None, head_idx, batch_idx]
1419
- mdK_cur = mdK[None, None, head_idx, batch_idx]
 
 
 
 
1420
  gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
1421
  gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
1422
  store_dK, _, _ = copy_utils.tma_get_copy_fn(
@@ -1428,10 +1642,20 @@ class FlashAttentionBackwardSm90:
1428
  sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
1429
  sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
1430
  copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
1431
- tiled_mma_dV, sdV, tidx, self.arch, transpose=self.dKV_swapAB
 
 
 
 
 
1432
  )
1433
  copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
1434
- tiled_mma_dK, sdK, tidx, self.arch, transpose=self.dKV_swapAB
 
 
 
 
 
1435
  )
1436
  cute.arch.cp_async_bulk_wait_group(1, read=True)
1437
  epi_barrier.arrive_and_wait()
@@ -1450,15 +1674,19 @@ class FlashAttentionBackwardSm90:
1450
  store_dK()
1451
  cute.arch.cp_async_bulk_commit_group()
1452
  else:
1453
- sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_mma_warp_groups
1454
- sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_mma_warp_groups
1455
- sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_mma_warp_groups))
1456
- sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_mma_warp_groups))
1457
  head_idx_kv = head_idx // qhead_per_kvhead_divmod
1458
- mdKaccum_cur = mdK[None, head_idx_kv, batch_idx]
 
 
 
 
 
1459
  gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
1460
  gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
1461
- mdVaccum_cur = mdV[None, head_idx_kv, batch_idx]
1462
  gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
1463
  gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
1464
  # These two overlap each other
@@ -1467,7 +1695,7 @@ class FlashAttentionBackwardSm90:
1467
  sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
1468
  tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
1469
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
1470
- cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)),
1471
  cute.make_layout(128 // Float32.width),
1472
  )
1473
  thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
@@ -1482,11 +1710,11 @@ class FlashAttentionBackwardSm90:
1482
  epi_barrier.arrive_and_wait()
1483
  if warp_idx == 4:
1484
  with cute.arch.elect_one():
1485
- for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
1486
  copy_utils.cpasync_reduce_bulk_add_f32(
1487
  sdKaccum[None, wg_idx].iterator,
1488
  gdKaccum[None, wg_idx].iterator,
1489
- self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups,
1490
  )
1491
  cute.arch.cp_async_bulk_commit_group()
1492
 
@@ -1498,11 +1726,11 @@ class FlashAttentionBackwardSm90:
1498
  epi_barrier.arrive_and_wait()
1499
  if warp_idx == 4:
1500
  with cute.arch.elect_one():
1501
- for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
1502
  copy_utils.cpasync_reduce_bulk_add_f32(
1503
  sdVaccum[None, wg_idx].iterator,
1504
  gdVaccum[None, wg_idx].iterator,
1505
- self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups,
1506
  )
1507
  cute.arch.cp_async_bulk_commit_group()
1508
 
@@ -1515,21 +1743,45 @@ class FlashAttentionBackwardSm90:
1515
  TileSchedulerCls: cutlass.Constexpr[Callable],
1516
  SeqlenInfoCls: cutlass.Constexpr[Callable],
1517
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
 
1518
  ):
 
 
 
 
 
1519
  tile_scheduler = TileSchedulerCls()
1520
  work_tile = tile_scheduler.initial_work_tile_info()
1521
  while work_tile.is_valid_tile:
1522
  n_block, head_idx, batch_idx, _ = work_tile.tile_idx
1523
  seqlen = SeqlenInfoCls(batch_idx)
1524
- mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
1525
- gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
1526
- # (M * K / WG, WG, _)
1527
- gdQaccum = cute.flat_divide(
1528
- gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,)
 
 
 
 
 
 
 
 
 
 
1529
  )
 
 
 
 
 
1530
  m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
1531
  if const_expr(not self.use_block_sparsity):
1532
- process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
 
 
 
1533
  loop_count = m_block_max - m_block_min
1534
  else:
1535
  total_block_cnt = get_total_q_block_count_bwd(
@@ -1548,17 +1800,36 @@ class FlashAttentionBackwardSm90:
1548
  m_block = m_block_min + iter_idx
1549
  m_block_safe = m_block
1550
 
1551
- for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
1552
- cute.arch.cp_async_bulk_wait_group(
1553
- self.num_mma_warp_groups - 1 - warp_group_idx, read=True
1554
- )
 
 
 
1555
  cute.arch.barrier_arrive(
1556
  barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1557
  number_of_threads=self.num_threads_per_warp_group
1558
  + cute.arch.WARP_SIZE,
1559
  )
1560
 
1561
- for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1562
  cute.arch.barrier(
1563
  barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1564
  number_of_threads=self.num_threads_per_warp_group
@@ -1567,11 +1838,24 @@ class FlashAttentionBackwardSm90:
1567
  with cute.arch.elect_one():
1568
  copy_utils.cpasync_reduce_bulk_add_f32(
1569
  sdQaccum[None, warp_group_idx].iterator,
1570
- gdQaccum[None, warp_group_idx, m_block_safe].iterator,
1571
  self.tma_copy_bytes["dQ"],
1572
  )
1573
  cute.arch.cp_async_bulk_commit_group()
 
 
 
 
 
 
 
 
 
 
1574
  else:
 
 
 
1575
  dQaccum_store_block_sparse_bwd_sm90(
1576
  blocksparse_tensors,
1577
  batch_idx,
@@ -1581,11 +1865,27 @@ class FlashAttentionBackwardSm90:
1581
  gdQaccum,
1582
  subtile_factor=self.subtile_factor,
1583
  m_block_max=m_block_max,
1584
- num_mma_warp_groups=self.num_mma_warp_groups,
1585
  num_threads_per_warp_group=self.num_threads_per_warp_group,
1586
  tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
1587
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1588
  tile_scheduler.advance_to_next_work()
1589
  work_tile = tile_scheduler.get_current_work()
1590
 
1591
- cute.arch.cp_async_bulk_wait_group(0, read=True)
 
 
24
  from .block_info import BlockInfo
25
  from . import pipeline
26
  from .quack.cute_dsl_utils import ParamsBase
27
+ from .tile_scheduler import (
28
+ TileSchedulerArguments,
29
+ SingleTileScheduler,
30
+ SingleTileLPTBwdScheduler,
31
+ SingleTileVarlenScheduler,
32
+ )
33
+ from . import barrier
34
  from .named_barrier import NamedBarrierBwd
35
  from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
36
  from .block_sparsity import BlockSparseTensors
 
52
  head_dim_v: Optional[int] = None,
53
  qhead_per_kvhead: int = 1,
54
  is_causal: bool = False,
55
+ is_local: bool = False,
56
+ deterministic: bool = False,
57
  tile_m: int = 64,
58
  tile_n: int = 128,
59
  Q_stage: int = 2,
 
72
  mask_mod: cutlass.Constexpr | None = None,
73
  has_aux_tensors: cutlass.Constexpr = False,
74
  subtile_factor: cutlass.Constexpr[int] = 1,
75
+ dQ_single_wg: bool = False,
76
  ):
77
  self.dtype = dtype
78
  # padding head_dim to a multiple of 16 as k_block_size
 
86
  self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
87
  self.qhead_per_kvhead = qhead_per_kvhead
88
  self.is_causal = is_causal
89
+ self.is_local = is_local
90
+ self.deterministic = deterministic
91
  self.tile_m = tile_m
92
  self.tile_n = tile_n
93
  self.num_threads = num_threads
 
102
  self.AtomLayoutMSdP = AtomLayoutMSdP
103
  self.AtomLayoutNdKV = AtomLayoutNdKV
104
  self.AtomLayoutMdQ = AtomLayoutMdQ
105
+ self.num_wg_mma = (self.num_threads // 128) - 1
106
  self.mma_dkv_is_rs = (
107
  AtomLayoutMSdP == 1
108
+ and AtomLayoutNdKV == self.num_wg_mma
109
  and SdP_swapAB
110
  and not dKV_swapAB
111
  )
112
  self.V_in_regs = V_in_regs
113
+ # May be overridden in __call__ for varlen inputs.
114
  if qhead_per_kvhead > 1:
115
  assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v"
116
+ assert self.num_wg_mma == 2, "GQA backward assumes 2 warp groups"
117
  # These are tuned for speed
118
  # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share
119
  # them and then shuffle to get the value whenever we need? This can reduce register
120
  # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)
121
  # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.
 
122
  self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
123
  self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64
124
 
 
134
  else:
135
  self.vec_size: cutlass.Constexpr = 4
136
  self.qk_acc_dtype = Float32
137
+ # dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it.
138
+ # Only valid for 2 MMA warp groups.
139
+ # Credit: Ben Spector
140
+ if dQ_single_wg:
141
+ assert self.num_wg_mma == 2, "dQ_single_wg only supports 2 warp groups"
142
+ self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma
143
 
144
  @staticmethod
145
  def can_implement(
 
198
  assert mQ_type == self.dtype
199
 
200
  def _setup_attributes(self):
201
+ # We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
202
+ # Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
203
+ # The M dimension (tile_m) doesn't matter for the layout, only the K dimension
204
+ wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV
205
+ self.sQ_layout, self.sdO_layout = [
206
+ # Need to set major_mode_size (mms) to accommodate Q and Q.T
207
+ sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms)
208
+ for shape, stage, mms in [
209
+ ((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV),
210
+ ((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV),
211
  ]
212
  ]
213
+ wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ
214
+ # Accomodate both K and K.T
215
+ self.sK_layout = sm90_utils.make_smem_layout(
216
+ self.dtype,
217
+ LayoutEnum.ROW_MAJOR,
218
+ (self.tile_n, self.tile_hdim),
219
+ stage=None,
220
+ major_mode_size=self.tile_hdim // wg_d_dQ,
221
+ )
222
+ # There's only V, no V.T, so layout is normal
223
+ self.sV_layout = sm90_utils.make_smem_layout(
224
+ self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None
225
+ )
226
+ # Accomodate both S and S.T
227
+ wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP
228
+ wg_n_dKV = self.AtomLayoutNdKV
229
+ self.sPdS_layout = sm90_utils.make_smem_layout(
230
+ self.dtype,
231
+ LayoutEnum.ROW_MAJOR,
232
+ (self.tile_m, self.tile_n),
233
+ stage=self.PdS_stage,
234
+ major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV),
235
+ )
236
  self.sdQaccum_layout = cute.make_layout(
237
+ (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)
238
  )
239
  # dQaccum R->S
240
  self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
241
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
242
  # thr_layout
243
+ cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)),
244
  cute.make_layout(128 // Float32.width), # val_layout
245
  )
246
  # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
247
  # TODO: assert that sVaccum and sKaccum don't overflow smem
248
 
249
  def _get_tiled_mma(self):
250
+ maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape
251
  # S = Q @ K.T, dP = dO @ V.T
252
+ atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1)
253
  tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
254
  tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
255
  self.dtype,
 
257
  warpgroup.OperandMajorMode.K,
258
  warpgroup.OperandMajorMode.K,
259
  Float32,
260
+ atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB),
261
+ tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]),
 
262
  )
263
  # dV = P.T @ dO, dK = dS.T @ Q
264
+ atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1)
265
  tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
266
  tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
267
  tiled_mma_dK, tiled_mma_dV = [
 
273
  else warpgroup.OperandMajorMode.K,
274
  warpgroup.OperandMajorMode.MN,
275
  Float32,
276
+ atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB),
277
+ tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]),
 
278
  a_source=warpgroup.OperandSource.RMEM
279
  if self.mma_dkv_is_rs
280
  else warpgroup.OperandSource.SMEM,
 
282
  for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
283
  ]
284
  # dQ = dS @ K
285
+ assert self.num_wg_dQ % self.AtomLayoutMdQ == 0
286
+ atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1)
287
  tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
288
  tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
289
  self.dtype,
 
291
  warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
292
  warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
293
  Float32,
294
+ atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB),
295
+ tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]),
296
  )
297
  return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
298
 
 
346
  mdK: cute.Tensor,
347
  mdV: cute.Tensor,
348
  softmax_scale: Float32,
 
349
  mCuSeqlensQ: Optional[cute.Tensor] = None,
350
  mCuSeqlensK: Optional[cute.Tensor] = None,
351
  mSeqUsedQ: Optional[cute.Tensor] = None,
 
358
  mdV_semaphore: Optional[cute.Tensor] = None,
359
  aux_tensors: Optional[list] = None,
360
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
361
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
362
+ stream: cuda.CUstream = None,
363
  ):
364
+ # For GQA (qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV,
365
+ # so we need the float32 accum path + postprocess.
366
+ # For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors.
367
+ self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
368
 
369
  self._check_type(
370
  *(
 
373
  )
374
  )
375
 
376
+ self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
377
+
378
  mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
379
  assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
380
  ]
381
 
382
+ # Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h).
383
+ # We convert both to a seqlen-major view with head-dim second.
384
+ # Each tensor may have different rank when Q is padded (seqused_q) but K/V are unpadded (cu_seqlens_k).
385
+ def _qkv_transpose(t):
386
+ return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1])
387
+
388
+ mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)]
389
  if const_expr(self.qhead_per_kvhead == 1):
390
+ mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)]
391
  else:
392
+ # Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen.
393
+ accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0]
394
  mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
395
+ # Non-varlen stats are (b, n, s), varlen stats are (n, s).
396
+ LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0]
397
  mLSE, mdPsum, mdQaccum = [
398
  layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
399
  ]
400
 
401
  tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()
402
+ # (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch)
403
+ if const_expr(self.deterministic):
404
+ assert mdQ_semaphore is not None
405
+ mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0])
406
 
407
  self.num_mma_threads = tiled_mma_SdP.size
408
  assert self.num_mma_threads + 128 == self.num_threads
 
410
  self.num_threads_per_warp_group = 128
411
  self.num_producer_threads = 32
412
 
413
+ REG_LIMIT = 504 if self.num_wg_mma == 2 else 512
414
+ if const_expr(self.num_wg_mma == 2):
415
+ if const_expr(self.num_wg_dQ == 1):
416
+ self.num_mma_regs_wg0 = 256
417
+ self.num_mma_regs_wg1 = 224
418
+ else:
419
+ self.num_mma_regs_wg0 = 240
420
+ self.num_mma_regs_wg1 = 240
421
+ self.num_mma_regs = self.num_mma_regs_wg0 # for backward compat
422
+ self.num_producer_regs = 24
423
+ assert (
424
+ self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT
425
+ )
426
+ else: # 3 warp groups
427
+ self.num_mma_regs_wg0 = 160
428
+ self.num_mma_regs_wg1 = 160
429
+ self.num_mma_regs = 160
430
+ self.num_producer_regs = 32
431
+ assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT
432
 
433
  self._setup_attributes()
434
  SharedStorage = self._get_shared_storage_cls()
 
445
  self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
446
  self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
447
  self.tma_copy_bytes["dQ"] = (
448
+ self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ
449
  )
450
  self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
451
  self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8
 
475
  (self.tile_m, self.tile_hdimv),
476
  )
477
  if const_expr(self.qhead_per_kvhead == 1):
478
+ mdK_tma = (
479
+ copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True)
480
+ if self.varlen_k
481
+ else mdK
482
+ )
483
+ mdV_tma = (
484
+ copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True)
485
+ if self.varlen_k
486
+ else mdV
487
+ )
488
  tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
489
  cpasync.CopyBulkTensorTileS2GOp(),
490
+ mdK_tma,
491
  cute.select(self.sK_layout, mode=[0, 1]),
492
  (self.tile_n, self.tile_hdim),
493
  )
494
  tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
495
  cpasync.CopyBulkTensorTileS2GOp(),
496
+ mdV_tma,
497
  cute.select(self.sV_layout, mode=[0, 1]),
498
  (self.tile_n, self.tile_hdimv),
499
  )
500
  else:
501
  tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None
502
 
503
+ if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None):
504
+ TileScheduler = SingleTileVarlenScheduler
505
+ elif const_expr(self.deterministic):
506
+ TileScheduler = SingleTileLPTBwdScheduler
507
+ else:
508
+ TileScheduler = SingleTileScheduler
509
+ self.spt = (self.is_causal or self.is_local) and self.deterministic
510
  tile_sched_args = TileSchedulerArguments(
511
  cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
512
  cute.size(mQ.shape[2]),
513
+ cute.size(mK.shape[3])
514
+ if const_expr(mCuSeqlensK is None)
515
+ else cute.size(mCuSeqlensK.shape[0] - 1), # num_batch
516
  1, # num_splits
517
+ cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
518
+ mQ.shape[1], # headdim
519
+ mV.shape[1], # headdim_v
520
+ total_q=cute.size(mK.shape[0])
521
+ if const_expr(mCuSeqlensK is not None)
522
+ else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
523
+ tile_shape_mn=(self.tile_n, self.tile_m), # Swapping the role of Q & K
524
+ mCuSeqlensQ=mCuSeqlensK,
525
+ mSeqUsedQ=mSeqUsedK,
526
  qhead_per_kvhead_packgqa=1,
527
  element_size=self.dtype.width // 8,
528
  is_persistent=False,
529
+ lpt=self.spt,
530
+ head_swizzle=self.deterministic,
531
  )
532
 
533
  tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
 
553
 
554
  self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
555
 
556
+ if const_expr(window_size_left is not None):
557
+ window_size_left = Int32(window_size_left)
558
+ if const_expr(window_size_right is not None):
559
+ window_size_right = Int32(window_size_right)
560
+
561
  self.kernel(
562
  tma_tensor_Q,
563
  tma_tensor_K,
 
574
  mLSE,
575
  mdPsum,
576
  mdQaccum,
577
+ mCuSeqlensQ,
578
+ mCuSeqlensK,
579
+ mSeqUsedQ,
580
+ mSeqUsedK,
581
  self.sQ_layout,
582
  self.sK_layout,
583
  self.sV_layout,
 
598
  fastdiv_mods,
599
  blocksparse_tensors,
600
  qhead_per_kvhead_divmod,
601
+ mdQ_semaphore,
602
+ window_size_left,
603
+ window_size_right,
604
  ).launch(
605
  grid=grid_dim,
606
  block=[self.num_threads, 1, 1],
607
  stream=stream,
608
  min_blocks_per_mp=1,
609
+ use_pdl=True,
610
  )
611
 
612
  @cute.kernel
 
627
  mLSE: cute.Tensor,
628
  mdPsum: cute.Tensor,
629
  mdQaccum: cute.Tensor,
630
+ mCuSeqlensQ: Optional[cute.Tensor],
631
+ mCuSeqlensK: Optional[cute.Tensor],
632
+ mSeqUsedQ: Optional[cute.Tensor],
633
+ mSeqUsedK: Optional[cute.Tensor],
634
  sQ_layout: cute.ComposedLayout,
635
  sK_layout: cute.ComposedLayout,
636
  sV_layout: cute.ComposedLayout,
 
651
  fastdiv_mods=(None, None),
652
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
653
  qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
654
+ mdQ_semaphore: Optional[cute.Tensor] = None,
655
+ window_size_left: Optional[Int32] = None,
656
+ window_size_right: Optional[Int32] = None,
657
  ):
658
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
659
 
660
  # prefetch TMA descriptors
661
  if warp_idx == 0:
662
+ for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]:
663
+ if const_expr(atom is not None):
664
+ cpasync.prefetch_descriptor(atom)
 
665
 
666
  smem = cutlass.utils.SmemAllocator()
667
  storage = smem.allocate(SharedStorage)
 
715
  self.is_causal,
716
  self.is_local,
717
  False, # is_split_kv
718
+ window_size_left,
719
+ window_size_right,
720
  qhead_per_kvhead_packgqa=1,
721
  )
722
  SeqlenInfoCls = partial(
723
  SeqlenInfoQK.create,
724
  seqlen_q_static=mQ.shape[0],
725
  seqlen_k_static=mK.shape[0],
726
+ mCuSeqlensQ=mCuSeqlensQ,
727
+ mCuSeqlensK=mCuSeqlensK,
728
+ mSeqUsedQ=mSeqUsedQ,
729
+ mSeqUsedK=mSeqUsedK,
730
+ tile_m=self.tile_m,
731
+ tile_n=self.tile_n,
732
  )
733
  AttentionMaskCls = partial(
734
  AttentionMask,
735
  self.tile_m,
736
  self.tile_n,
737
+ window_size_left=window_size_left,
738
+ window_size_right=window_size_right,
739
  swap_AB=self.SdP_swapAB,
740
  )
741
  TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
 
776
  TileSchedulerCls,
777
  SeqlenInfoCls,
778
  blocksparse_tensors,
779
+ mdQ_semaphore,
780
  )
781
  else:
 
782
  tidx, _, _ = cute.arch.thread_idx()
783
  tidx = tidx - 128
784
+ mma_args = (
785
  tiled_mma_SdP,
786
  tiled_mma_dK,
787
  tiled_mma_dV,
 
815
  blocksparse_tensors,
816
  qhead_per_kvhead_divmod,
817
  )
818
+ if const_expr(self.num_wg_dQ == self.num_wg_mma):
819
+ # Both WGs compute dQ
820
+ cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
821
+ self.mma(*mma_args, is_dQ_wg=True)
822
+ else:
823
+ # WG0 computes dQ, WG1 skips it
824
+ warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4
825
+ if warp_idx_in_mma < 4:
826
+ cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
827
+ self.mma(*mma_args, is_dQ_wg=True)
828
+ else:
829
+ cute.arch.setmaxregister_increase(self.num_mma_regs_wg1)
830
+ self.mma(*mma_args, is_dQ_wg=False)
831
 
832
  @cute.jit
833
  def load(
 
875
  if const_expr(self.qhead_per_kvhead == 1)
876
  else head_idx // qhead_per_kvhead_divmod
877
  )
878
+ mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
879
+ mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
880
  gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
 
881
  gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
882
 
883
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
884
+ mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[
885
+ None, head_idx
886
+ ]
887
+ mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx]
888
+ mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
889
+ None, head_idx
890
+ ]
891
  gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
 
892
  gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
 
893
  gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
 
894
  gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
895
 
896
  load_K, _, _ = copy_utils.tma_get_copy_fn(
 
916
 
917
  if const_expr(not self.use_block_sparsity):
918
  total_m_block_cnt = m_block_max - m_block_min
919
+ process_tile = (
920
+ const_expr(not self.is_local and not self.is_varlen_q)
921
+ or m_block_min < m_block_max
922
+ )
923
  else:
924
  total_m_block_cnt = get_total_q_block_count_bwd(
925
  blocksparse_tensors,
 
939
  )
940
  load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
941
  load_Q(first_m_block, producer_state=producer_state_Q)
942
+ # Wait for bwd preprocess to finish writing LSE and dPsum
943
+ cute.arch.griddepcontrol_wait()
944
  load_LSE(first_m_block, producer_state=producer_state_Q)
945
  producer_state_dO_cur = (
946
  producer_state_dO
 
1119
  fastdiv_mods=(None, None),
1120
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
1121
  qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
1122
+ is_dQ_wg: cutlass.Constexpr[bool] = True,
1123
  ):
1124
  warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1125
  warp_group_thread_layout = cute.make_layout(
1126
+ self.num_wg_mma, stride=self.num_threads_per_warp_group
1127
  )
1128
  thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
1129
  wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
1130
  wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
1131
  wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
1132
+ wg_mma_dQ = None
1133
+ if const_expr(is_dQ_wg):
1134
+ wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ > 1) else 0
1135
+ wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ))
1136
  # S = Q @ K.T
1137
  shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
1138
  _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
 
1178
  # dQ = dS @ K
1179
  sKt = layout_utils.transpose_view(sK)
1180
  shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
1181
+ mma_dsk_fn = None
1182
+ if const_expr(is_dQ_wg):
1183
+ _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
1184
+ wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
1185
+ )
1186
+ mma_dsk_fn = partial(
1187
+ gemm_zero_init,
1188
+ tiled_mma_dQ,
1189
+ shape_mnk_dQ[:2],
1190
+ tdQrdS,
1191
+ tdQrKt,
1192
+ swap_AB=self.dQ_swapAB,
1193
+ )
1194
 
1195
+ # Smem copy atom tiling for P/dS R2S
1196
  copy_P_r2s = None
1197
+ mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP)
1198
  if const_expr(sP is not None):
1199
  sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
1200
  copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
1201
+ tiled_mma_SdP,
1202
+ sP_cpy,
1203
+ tidx,
1204
+ self.arch,
1205
+ transpose=self.SdP_swapAB,
1206
+ position_independent=True,
1207
+ major_mode_size=mms_PdS,
1208
  )
1209
  sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
1210
  copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
1211
+ tiled_mma_SdP,
1212
+ sdS_cpy,
1213
+ tidx,
1214
+ self.arch,
1215
+ transpose=self.SdP_swapAB,
1216
+ position_independent=True,
1217
+ major_mode_size=mms_PdS,
1218
  )
1219
 
1220
  tLSEsLSE = layout_utils.mma_partition_C_vec(
 
1223
  tLSEsdPsum = layout_utils.mma_partition_C_vec(
1224
  sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
1225
  )
1226
+ # When shuffle=True, rows are distributed across 8 quads (4 threads each) within a warp.
1227
+ # Each thread loads only ceil(num_rows/8) values;
1228
+ shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2)
1229
+ if const_expr(self.shuffle_LSE):
1230
+ tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE)
1231
+ # ((2, 1), 1, 2) -> (((2, 1), 1), 2)
1232
+ tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2)
1233
+ if const_expr(self.shuffle_dPsum):
1234
+ tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum)
1235
+ tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2)
1236
+
1237
+ tdQsdQaccum = None
1238
+ if const_expr(is_dQ_wg):
1239
+ smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
1240
+ tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
1241
 
1242
  PdS_barrier = cutlass.pipeline.NamedBarrier(
1243
  barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
 
1276
  PdS_barrier=PdS_barrier,
1277
  # acc_dV=acc_dV,
1278
  # acc_dK=acc_dK,
1279
+ is_dQ_wg=is_dQ_wg,
1280
  )
1281
 
1282
  consumer_state_Q = cutlass.pipeline.make_pipeline_state(
 
1308
  m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
1309
 
1310
  if const_expr(not self.use_block_sparsity):
1311
+ process_tile = (
1312
+ const_expr(not self.is_local and not self.is_varlen_q)
1313
+ or m_block_min < m_block_max
1314
+ )
1315
  else:
1316
  total_m_block_cnt = get_total_q_block_count_bwd(
1317
  blocksparse_tensors,
 
1393
  qhead_per_kvhead_divmod,
1394
  )
1395
  else:
1396
+ # KV tile with zero Q blocks produces no dK/dV; write zeros.
1397
+ if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q):
1398
  acc_dK.fill(0.0)
1399
  acc_dV.fill(0.0)
1400
  self.epilogue_dKV(
 
1423
  if warp_idx == 4:
1424
  cute.arch.cp_async_bulk_wait_group(0, read=True)
1425
 
1426
+ @staticmethod
1427
+ @cute.jit
1428
+ def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32:
1429
+ """Retrieve the statistic for a given accumulator row.
1430
+
1431
+ When shuffle=False, direct register indexing.
1432
+ When shuffle=True, warp shuffle from the thread group that holds the value.
1433
+ """
1434
+ if const_expr(not shuffle):
1435
+ return tSrS[row]
1436
+ # tSrS: (((2, 1), 1), 1)), distributed across 8 threads in the warp
1437
+ vecsize = cute.size(tSrS, mode=[0, 0]) # 2
1438
+ idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1])))
1439
+ # register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ...
1440
+ return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4))
1441
+
1442
  @cute.jit
1443
  def mma_one_m_block(
1444
  self,
 
1457
  pipeline_dO: cutlass.pipeline.PipelineAsync,
1458
  tLSEsLSE: cute.Tensor,
1459
  tLSEsdPsum: cute.Tensor,
1460
+ tdQsdQaccum: Optional[cute.Tensor],
1461
  softmax_scale_log2: Float32,
1462
  PdS_barrier: cutlass.pipeline.NamedBarrier,
1463
+ is_dQ_wg: cutlass.Constexpr[bool] = True,
1464
  mask_fn: Optional[Callable] = None,
1465
  score_mod_fn: Optional[Callable] = None,
1466
  score_mod_bwd_fn: Optional[Callable] = None,
1467
  dKV_accumulate: Boolean = True,
1468
  ):
1469
  consumer_state_dO_cur = (
1470
+ consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO
1471
  )
1472
  smem_idx_Q = consumer_state_Q.index
1473
  smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0
 
1475
  # (1) [GEMM 1] S = Q @ K^T
1476
  pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))
1477
  acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)
1478
+ # If shuffle_LSE, OOB reads are OK since sLSE is already padded
1479
  tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])
1480
  # (2) [GEMM 2] dP = dO @ V.T
1481
  pipeline_dO.consumer_wait(
 
1494
  if cutlass.const_expr(mask_fn is not None):
1495
  mask_fn(acc_S, m_block=m_block)
1496
  acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
1497
+ lane_idx = cute.arch.lane_idx()
1498
  for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
1499
+ lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE)
1500
  for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
1501
  acc_S_mn[r, c] = cute.math.exp2(
1502
+ acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True
1503
  )
1504
  tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
1505
 
 
1516
  warpgroup.wait_group(0)
1517
  acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
1518
  for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
1519
+ dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum)
1520
  for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
1521
+ acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val)
1522
 
1523
  if const_expr(self.score_mod_bwd is not None):
1524
  score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
 
1550
  # smem fence to make sure sdS is written before it's read by WGMMA
1551
  cute.arch.fence_view_async_shared()
1552
  PdS_barrier.arrive_and_wait()
 
 
 
 
1553
 
1554
+ if const_expr(is_dQ_wg):
1555
+ # (6) [GEMM 4] dQ = dS @ K
1556
+ acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
1557
+ pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done
 
 
 
 
1558
 
1559
+ # (7) [GEMM 5] dK += dS.T @ Q
1560
+ if const_expr(not self.mma_dkv_is_rs):
1561
+ mma_dsq_fn(
1562
+ A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
1563
+ )
1564
+ else:
1565
+ mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
 
 
 
 
1566
 
1567
+ # dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem
1568
+ # When dQ_single_wg, only WG0 enters here so warp_group_idx == 0
1569
+ cute.arch.barrier(
1570
+ barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1571
+ number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
1572
+ )
1573
+ tdQrdQaccum_flat = cute.make_tensor(
1574
+ acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)
1575
+ )
1576
+ cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
1577
+ cute.arch.fence_view_async_shared()
1578
+ cute.arch.barrier_arrive(
1579
+ barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1580
+ number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
1581
+ )
1582
+
1583
+ warpgroup.wait_group(0)
1584
+ pipeline_Q.consumer_release(consumer_state_Q)
1585
+ else:
1586
+ # dQ_single_wg: WG1 skips dQ, only does dV wait + dK
1587
+ # (7) [GEMM 5] dK += dS.T @ Q
1588
+ if const_expr(not self.mma_dkv_is_rs):
1589
+ mma_dsq_fn(
1590
+ A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1
1591
+ )
1592
+ else:
1593
+ mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)
1594
+ pipeline_dO.consumer_release(consumer_state_dO_cur)
1595
+ warpgroup.wait_group(0)
1596
+ pipeline_Q.consumer_release(consumer_state_Q)
1597
 
1598
  consumer_state_Q.advance()
1599
  consumer_state_dO.advance()
 
1625
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1626
 
1627
  if const_expr(self.qhead_per_kvhead == 1):
1628
+ mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[
1629
+ None, None, head_idx
1630
+ ]
1631
+ mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[
1632
+ None, None, head_idx
1633
+ ]
1634
  gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
1635
  gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
1636
  store_dK, _, _ = copy_utils.tma_get_copy_fn(
 
1642
  sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)
1643
  sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)
1644
  copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(
1645
+ tiled_mma_dV,
1646
+ sdV,
1647
+ tidx,
1648
+ self.arch,
1649
+ transpose=self.dKV_swapAB,
1650
+ position_independent=True,
1651
  )
1652
  copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(
1653
+ tiled_mma_dK,
1654
+ sdK,
1655
+ tidx,
1656
+ self.arch,
1657
+ transpose=self.dKV_swapAB,
1658
+ position_independent=True,
1659
  )
1660
  cute.arch.cp_async_bulk_wait_group(1, read=True)
1661
  epi_barrier.arrive_and_wait()
 
1674
  store_dK()
1675
  cute.arch.cp_async_bulk_commit_group()
1676
  else:
1677
+ sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma
1678
+ sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma
1679
+ sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma))
1680
+ sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma))
1681
  head_idx_kv = head_idx // qhead_per_kvhead_divmod
1682
+ mdKaccum_cur = seqlen.offset_batch_K(
1683
+ mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim
1684
+ )[None, head_idx_kv]
1685
+ mdVaccum_cur = seqlen.offset_batch_K(
1686
+ mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv
1687
+ )[None, head_idx_kv]
1688
  gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))
1689
  gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))
 
1690
  gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))
1691
  gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))
1692
  # These two overlap each other
 
1695
  sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)
1696
  tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(
1697
  cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
1698
+ cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)),
1699
  cute.make_layout(128 // Float32.width),
1700
  )
1701
  thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)
 
1710
  epi_barrier.arrive_and_wait()
1711
  if warp_idx == 4:
1712
  with cute.arch.elect_one():
1713
+ for wg_idx in cutlass.range_constexpr(self.num_wg_mma):
1714
  copy_utils.cpasync_reduce_bulk_add_f32(
1715
  sdKaccum[None, wg_idx].iterator,
1716
  gdKaccum[None, wg_idx].iterator,
1717
+ self.tma_copy_bytes["dKacc"] // self.num_wg_mma,
1718
  )
1719
  cute.arch.cp_async_bulk_commit_group()
1720
 
 
1726
  epi_barrier.arrive_and_wait()
1727
  if warp_idx == 4:
1728
  with cute.arch.elect_one():
1729
+ for wg_idx in cutlass.range_constexpr(self.num_wg_mma):
1730
  copy_utils.cpasync_reduce_bulk_add_f32(
1731
  sdVaccum[None, wg_idx].iterator,
1732
  gdVaccum[None, wg_idx].iterator,
1733
+ self.tma_copy_bytes["dVacc"] // self.num_wg_mma,
1734
  )
1735
  cute.arch.cp_async_bulk_commit_group()
1736
 
 
1743
  TileSchedulerCls: cutlass.Constexpr[Callable],
1744
  SeqlenInfoCls: cutlass.Constexpr[Callable],
1745
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
1746
+ mdQ_semaphore: Optional[cute.Tensor] = None,
1747
  ):
1748
+ tidx, _, _ = cute.arch.thread_idx()
1749
+ # warp-local thread index (dQaccum_store runs on warp 1, global tidx 32-63)
1750
+ warp_local_tidx = tidx % cute.arch.WARP_SIZE
1751
+ read_flag = const_expr(not self.deterministic)
1752
+
1753
  tile_scheduler = TileSchedulerCls()
1754
  work_tile = tile_scheduler.initial_work_tile_info()
1755
  while work_tile.is_valid_tile:
1756
  n_block, head_idx, batch_idx, _ = work_tile.tile_idx
1757
  seqlen = SeqlenInfoCls(batch_idx)
1758
+ if const_expr(not seqlen.has_cu_seqlens_q):
1759
+ mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
1760
+ else:
1761
+ mdQaccum_cur = cute.domain_offset(
1762
+ (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
1763
+ )
1764
+ # ((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks)
1765
+ gdQaccum = cute.local_tile(
1766
+ mdQaccum_cur,
1767
+ (
1768
+ cute.make_layout(
1769
+ (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)
1770
+ ),
1771
+ ),
1772
+ (None,),
1773
  )
1774
+
1775
+ if const_expr(mdQ_semaphore is not None):
1776
+ # mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose
1777
+ mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]
1778
+
1779
  m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
1780
  if const_expr(not self.use_block_sparsity):
1781
+ process_tile = (
1782
+ const_expr(not self.is_local and not self.is_varlen_q)
1783
+ or m_block_min < m_block_max
1784
+ )
1785
  loop_count = m_block_max - m_block_min
1786
  else:
1787
  total_block_cnt = get_total_q_block_count_bwd(
 
1800
  m_block = m_block_min + iter_idx
1801
  m_block_safe = m_block
1802
 
1803
+ num_dQ_chunks = self.num_wg_dQ
1804
+ for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):
1805
+ if const_expr(not self.deterministic):
1806
+ # If deterministic, we already waited at the end of the prev iter
1807
+ cute.arch.cp_async_bulk_wait_group(
1808
+ num_dQ_chunks - 1 - warp_group_idx, read=read_flag
1809
+ )
1810
  cute.arch.barrier_arrive(
1811
  barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
1812
  number_of_threads=self.num_threads_per_warp_group
1813
  + cute.arch.WARP_SIZE,
1814
  )
1815
 
1816
+ # Semaphore acquire: wait for prior n_blocks to finish writing this m_block
1817
+ if const_expr(self.deterministic):
1818
+ if const_expr(self.spt):
1819
+ _, n_block_max_for_m_block = block_info.get_n_block_min_max(
1820
+ seqlen, m_block_safe
1821
+ )
1822
+ lock_value = n_block_max_for_m_block - 1 - n_block
1823
+ else:
1824
+ lock_value = n_block
1825
+ barrier.wait_eq(
1826
+ mdQ_semaphore_cur[(m_block_safe, None)].iterator,
1827
+ warp_local_tidx,
1828
+ 0, # flag_offset
1829
+ lock_value,
1830
+ )
1831
+
1832
+ for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):
1833
  cute.arch.barrier(
1834
  barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
1835
  number_of_threads=self.num_threads_per_warp_group
 
1838
  with cute.arch.elect_one():
1839
  copy_utils.cpasync_reduce_bulk_add_f32(
1840
  sdQaccum[None, warp_group_idx].iterator,
1841
+ gdQaccum[(None, warp_group_idx), m_block_safe].iterator,
1842
  self.tma_copy_bytes["dQ"],
1843
  )
1844
  cute.arch.cp_async_bulk_commit_group()
1845
+
1846
+ # Semaphore release: signal that this n_block is done with this m_block
1847
+ if const_expr(self.deterministic):
1848
+ cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
1849
+ barrier.arrive_inc(
1850
+ mdQ_semaphore_cur[(m_block_safe, None)].iterator,
1851
+ warp_local_tidx,
1852
+ 0, # flag_offset
1853
+ 1,
1854
+ )
1855
  else:
1856
+ assert not self.deterministic, (
1857
+ "Deterministic not implemented for block-sparse backward"
1858
+ )
1859
  dQaccum_store_block_sparse_bwd_sm90(
1860
  blocksparse_tensors,
1861
  batch_idx,
 
1865
  gdQaccum,
1866
  subtile_factor=self.subtile_factor,
1867
  m_block_max=m_block_max,
1868
+ num_dQ_warp_groups=self.num_wg_dQ,
1869
  num_threads_per_warp_group=self.num_threads_per_warp_group,
1870
  tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
1871
  )
1872
+
1873
+ # For local masking + deterministic (non-spt): signal remaining m_blocks
1874
+ # that this n_block won't visit, so they don't deadlock waiting.
1875
+ if const_expr(
1876
+ self.deterministic and not self.spt and block_info.window_size_left is not None
1877
+ ):
1878
+ m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)
1879
+ for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):
1880
+ barrier.arrive_inc(
1881
+ mdQ_semaphore_cur[(m_block, None)].iterator,
1882
+ warp_local_tidx,
1883
+ 0, # flag_offset
1884
+ 1,
1885
+ )
1886
+
1887
  tile_scheduler.advance_to_next_work()
1888
  work_tile = tile_scheduler.get_current_work()
1889
 
1890
+ if const_expr(not self.deterministic):
1891
+ cute.arch.cp_async_bulk_wait_group(0, read=True)
build/torch-cuda/flash_fwd.py CHANGED
@@ -15,42 +15,28 @@ import cuda.bindings.driver as cuda
15
  import cutlass
16
  import cutlass.cute as cute
17
  from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
18
- from cutlass.cute.nvgpu import cpasync, warp, warpgroup
19
  import cutlass.utils as utils_basic
20
- from cutlass.utils import LayoutEnum
21
- import cutlass.utils.hopper_helpers as sm90_utils_basic
22
 
23
  from .quack import copy_utils
24
  from .quack import layout_utils
25
- from .quack import sm90_utils
26
 
27
  from . import ampere_helpers as sm80_utils
28
  from .cute_dsl_utils import assume_tensor_aligned
29
  from . import utils
30
  from .mask import AttentionMask
31
- from .softmax import Softmax, apply_score_mod_inner
32
  from .seqlen_info import SeqlenInfoQK
33
  from .block_info import BlockInfo
34
- from .block_sparsity import BlockSparseTensors
35
- from .block_sparse_utils import (
36
- produce_block_sparse_loads,
37
- consume_block_sparse_loads,
38
- )
39
- from . import pipeline
40
  from .pack_gqa import PackGQA
41
  from .named_barrier import NamedBarrierFwd
42
- from .quack.cute_dsl_utils import ParamsBase
43
- from .tile_scheduler import (
44
- TileSchedulerArguments,
45
- SingleTileScheduler,
46
- SingleTileLPTScheduler,
47
- SingleTileVarlenScheduler,
48
- )
49
- from cutlass.cute import FastDivmodDivisor
50
 
51
 
52
  class FlashAttentionForwardBase:
53
- arch: int = 80
54
 
55
  def __init__(
56
  self,
@@ -116,6 +102,12 @@ class FlashAttentionForwardBase:
116
  self.vec_size: cutlass.Constexpr = getattr(
117
  score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
118
  )
 
 
 
 
 
 
119
 
120
  @staticmethod
121
  def can_implement(
@@ -318,7 +310,8 @@ class FlashAttentionForwardBase:
318
  mO: cute.Tensor,
319
  mLSE: Optional[cute.Tensor],
320
  softmax_scale: Float32,
321
- stream: cuda.CUstream,
 
322
  ):
323
  """Configures and launches the flash attention kernel.
324
 
@@ -351,7 +344,7 @@ class FlashAttentionForwardBase:
351
  cute.arch.barrier(
352
  barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
353
  )
354
- smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype)
355
  smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
356
  taccOrO = smem_thr_copy_O.retile(rO)
357
  taccOsO = smem_thr_copy_O.partition_D(sO)
@@ -366,11 +359,7 @@ class FlashAttentionForwardBase:
366
 
367
  # Write LSE from rmem -> gmem
368
  if const_expr(mLSE is not None):
369
- if const_expr(not seqlen.has_cu_seqlens_q):
370
- mLSE_cur = mLSE[None, head_idx, batch_idx]
371
- else:
372
- offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
373
- mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
374
  if const_expr(not self.pack_gqa):
375
  gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
376
  gLSE_expanded_layout = cute.append(
@@ -384,7 +373,7 @@ class FlashAttentionForwardBase:
384
  t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
385
  # Only the thread corresponding to column 0 writes out the lse to gmem
386
  if taccOcO[0][1] == 0:
387
- for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):
388
  if (
389
  t0accOcO[m, 0][0]
390
  < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
@@ -393,11 +382,8 @@ class FlashAttentionForwardBase:
393
  else:
394
  pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
395
 
396
- if const_expr(not seqlen.has_cu_seqlens_q):
397
- mO_cur = mO[None, None, head_idx, batch_idx]
398
- else:
399
- offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)
400
- mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx])
401
  # thr_mma = tiled_mma.get_slice(tidx)
402
  # taccOgO = thr_mma.partition_C(gO)
403
  # cute.autovec_copy(rO, taccOgO)
@@ -634,12 +620,19 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
634
  mV: cute.Tensor,
635
  mO: cute.Tensor,
636
  mLSE: Optional[cute.Tensor],
637
- stream: cuda.CUstream,
638
- softmax_scale: Optional[Float32] = None,
 
 
 
 
639
  window_size_left: Optional[Int32] = None,
640
  window_size_right: Optional[Int32] = None,
641
  learnable_sink: Optional[cute.Tensor] = None,
 
642
  aux_tensors=None,
 
 
643
  ):
644
  """Configures and launches the flash attention kernel.
645
 
@@ -648,7 +641,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
648
  """
649
  assert learnable_sink is None, "Learnable sink is not supported in this kernel"
650
  self._check_type(
651
- *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))
652
  )
653
  tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
654
  self.num_mma_threads = tiled_mma_pv.size
@@ -656,41 +649,54 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
656
  self.num_Q_load_threads = self.num_threads
657
  self.num_epilogue_threads = self.num_threads
658
  # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
659
- self.use_tma_O = self.arch >= 90
660
  self._setup_attributes()
661
  SharedStorage = self._get_shared_storage_cls()
662
  mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
663
- mQ, mK, mV, mO = [
664
- cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0]))
665
- for t in (mQ, mK, mV, mO)
 
 
 
666
  ]
667
- mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0]))
668
- # grid_dim: (m_block, num_head, batch_size)
669
- grid_dim = (
670
- cute.ceil_div(mQ.shape[0], self.tile_m),
671
- cute.size(mQ.shape[2]),
672
- cute.size(mQ.shape[3]),
673
- )
674
- LOG2_E = math.log2(math.e)
675
- if const_expr(self.score_mod is None):
676
- softmax_scale_log2 = Float32(softmax_scale * LOG2_E)
677
- softmax_scale = None
678
  else:
679
- # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk
680
- # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
681
- # and correctly apply the softmax_scale prior to score_mod in the softmax step
682
- softmax_scale_log2 = Float32(LOG2_E)
683
- softmax_scale = Float32(softmax_scale)
684
-
685
- fastdiv_mods = None
686
- if const_expr(aux_tensors is not None):
687
- seqlen_q = cute.size(mQ.shape[0]) // (
688
- self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
689
- )
690
- seqlen_k = cute.size(mK.shape[0])
691
- seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
692
- seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
693
- fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
 
 
 
 
 
 
 
 
 
 
 
694
 
695
  self.kernel(
696
  mQ,
@@ -698,6 +704,10 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
698
  mV,
699
  mO,
700
  mLSE,
 
 
 
 
701
  softmax_scale_log2,
702
  softmax_scale,
703
  window_size_left,
@@ -714,6 +724,8 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
714
  tiled_mma_qk,
715
  tiled_mma_pv,
716
  SharedStorage,
 
 
717
  aux_tensors,
718
  fastdiv_mods,
719
  ).launch(
@@ -731,6 +743,10 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
731
  mV: cute.Tensor,
732
  mO: cute.Tensor,
733
  mLSE: Optional[cute.Tensor],
 
 
 
 
734
  softmax_scale_log2: Float32,
735
  softmax_scale: Optional[Float32],
736
  window_size_left: Optional[Int32],
@@ -747,12 +763,17 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
747
  tiled_mma_qk: cute.TiledMma,
748
  tiled_mma_pv: cute.TiledMma,
749
  SharedStorage: cutlass.Constexpr,
 
 
750
  aux_tensors=None,
751
  fastdiv_mods=None,
752
  ):
753
  # Thread index, block index
754
  tidx, _, _ = cute.arch.thread_idx()
755
- m_block, num_head, batch_size = cute.arch.block_idx()
 
 
 
756
 
757
  block_info = BlockInfo(
758
  self.tile_m,
@@ -764,13 +785,21 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
764
  window_size_right,
765
  qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
766
  )
767
- seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
 
 
 
 
 
 
 
 
768
  n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
769
- # TODO: return early if n_block_max == 0
770
- # if self.is_causal:
771
- # if n_block_max <= 0:
772
- # return
773
- n_block = n_block_max - 1
774
 
775
  # ///////////////////////////////////////////////////////////////////////////////
776
  # Get the appropriate tiles for this thread block.
@@ -778,10 +807,20 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
778
  blkQ_shape = (self.tile_m, self.tile_hdim)
779
  blkK_shape = (self.tile_n, self.tile_hdim)
780
  blkV_shape = (self.tile_n, self.tile_hdimv)
781
- gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0))
782
  num_head_kv = num_head // self.qhead_per_kvhead
783
- gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0))
784
- gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0))
 
 
 
 
 
 
 
 
 
 
 
785
 
786
  # ///////////////////////////////////////////////////////////////////////////////
787
  # Get shared memory buffer
@@ -953,18 +992,20 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
953
  mask = AttentionMask(
954
  self.tile_m,
955
  self.tile_n,
956
- seqlen.seqlen_q,
957
- seqlen.seqlen_k,
958
  window_size_left,
959
  window_size_right,
960
  self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
961
  )
962
  mask_fn = partial(
963
  mask.apply_mask,
 
 
964
  m_block=m_block,
965
  thr_mma=thr_mma_qk,
966
  mask_causal=self.is_causal,
967
  mask_local=self.is_local,
 
968
  fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
969
  )
970
 
@@ -976,8 +1017,8 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
976
  smem_pipe_read,
977
  smem_pipe_write,
978
  is_first_n_block=True,
979
- check_inf=True,
980
- mask_fn=partial(mask_fn, mask_seqlen=True),
981
  )
982
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
983
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
@@ -992,15 +1033,17 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
992
  n_block,
993
  smem_pipe_read,
994
  smem_pipe_write,
995
- check_inf=True,
996
- mask_fn=partial(mask_fn, mask_seqlen=False),
997
  )
998
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
999
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
1000
  # The remaining iterations have no masking
1001
  for n_tile in cutlass.range(n_block, unroll=1):
1002
  compute_one_n_block(
1003
- n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True
 
 
1004
  )
1005
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1006
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
@@ -1144,1283 +1187,9 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
1144
  # load_K_next()
1145
 
1146
 
1147
- class FlashAttentionForwardSm90(FlashAttentionForwardBase):
1148
- arch = 90
1149
-
1150
- def __init__(
1151
- self,
1152
- *args,
1153
- intra_wg_overlap: bool = True,
1154
- mma_pv_is_rs: bool = True,
1155
- **kwargs,
1156
- ):
1157
- super().__init__(*args, **kwargs)
1158
- self.intra_wg_overlap = intra_wg_overlap
1159
- self.mma_pv_is_rs = mma_pv_is_rs
1160
- self.buffer_align_bytes = 1024
1161
-
1162
- def _get_smem_layout_atom(self):
1163
- sQ_layout_atom = warpgroup.make_smem_layout_atom(
1164
- sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
1165
- self.dtype,
1166
- )
1167
- sK_layout_atom = sQ_layout_atom
1168
- sV_layout_atom = warpgroup.make_smem_layout_atom(
1169
- sm90_utils_basic.get_smem_layout_atom(
1170
- LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
1171
- ),
1172
- self.dtype,
1173
- )
1174
- sO_layout_atom = sV_layout_atom
1175
- if not self.mma_pv_is_rs:
1176
- sP_layout_atom = warpgroup.make_smem_layout_atom(
1177
- sm90_utils_basic.get_smem_layout_atom(
1178
- LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
1179
- ),
1180
- self.dtype,
1181
- )
1182
- else:
1183
- sP_layout_atom = None
1184
- return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
1185
-
1186
- def _get_tiled_mma(self):
1187
- tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
1188
- self.dtype,
1189
- self.dtype,
1190
- warpgroup.OperandMajorMode.K,
1191
- warpgroup.OperandMajorMode.K,
1192
- Float32,
1193
- atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1194
- tiler_mn=(64, self.tile_n),
1195
- )
1196
- tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
1197
- self.dtype,
1198
- self.dtype,
1199
- warpgroup.OperandMajorMode.K,
1200
- warpgroup.OperandMajorMode.MN,
1201
- Float32,
1202
- atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
1203
- tiler_mn=(64, self.tile_hdimv),
1204
- a_source=warpgroup.OperandSource.RMEM
1205
- if self.mma_pv_is_rs
1206
- else warpgroup.OperandSource.SMEM,
1207
- )
1208
- return tiled_mma_qk, tiled_mma_pv
1209
-
1210
- def _get_shared_storage_cls(self):
1211
- sQ_struct, sK_struct, sV_struct = [
1212
- cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
1213
- for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
1214
-
1215
- ]
1216
- cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
1217
- sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
1218
- cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
1219
- sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
1220
- # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V,
1221
- mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2]
1222
- mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
1223
- mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
1224
-
1225
- @cute.struct
1226
- class SharedStorageQKV:
1227
- mbar_ptr: mbar_ptr_QO_struct
1228
- mbar_ptr_K: mbar_ptr_K_struct
1229
- mbar_ptr_V: mbar_ptr_V_struct
1230
- sV: sV_struct
1231
- sQ: sQ_struct
1232
- sK: sK_struct
1233
- sP: sP_struct
1234
-
1235
- @cute.struct
1236
- class SharedStorageSharedQV:
1237
- mbar_ptr: mbar_ptr_QO_struct
1238
- mbar_ptr_K: mbar_ptr_K_struct
1239
- mbar_ptr_V: mbar_ptr_V_struct
1240
- sQ: sQV_struct
1241
- sK: sK_struct
1242
- sP: sP_struct
1243
-
1244
- return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
1245
-
1246
- @cute.jit
1247
- def __call__(
1248
- self,
1249
- mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1250
- mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
1251
- mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
1252
- mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1253
- mLSE: Optional[cute.Tensor],
1254
- softmax_scale: Float32,
1255
- stream: cuda.CUstream,
1256
- mCuSeqlensQ: Optional[cute.Tensor] = None,
1257
- mCuSeqlensK: Optional[cute.Tensor] = None,
1258
- mSeqUsedQ: Optional[cute.Tensor] = None,
1259
- mSeqUsedK: Optional[cute.Tensor] = None,
1260
- mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
1261
- window_size_left: Int32 | int | None = None,
1262
- window_size_right: Int32 | int | None = None,
1263
- learnable_sink: Optional[cute.Tensor] = None,
1264
- blocksparse_tensors: Optional[BlockSparseTensors] = None,
1265
- aux_tensors: Optional[list] = None,
1266
- ):
1267
- """Configures and launches the flash attention kernel.
1268
-
1269
- mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:
1270
- (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
1271
- """
1272
-
1273
- self._check_type(
1274
- *(
1275
- t.element_type if t is not None else None
1276
- for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
1277
- )
1278
- )
1279
-
1280
- mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
1281
- QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
1282
- mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
1283
- KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
1284
- mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
1285
- LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
1286
- mLSE = layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None
1287
-
1288
- tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
1289
- self.num_mma_threads = tiled_mma_qk.size
1290
- self.num_threads_per_warp_group = 128
1291
- self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group
1292
- self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1)
1293
- self.num_producer_threads = 32
1294
- self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q
1295
- self.num_epilogue_threads = self.num_mma_threads
1296
- self.num_mma_regs = (
1297
- 256
1298
- if self.num_mma_warp_groups == 1
1299
- else (240 if self.num_mma_warp_groups == 2 else 160)
1300
- )
1301
- self.num_producer_regs = (
1302
- 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32)
1303
- )
1304
- # self.num_mma_regs = 232
1305
- # self.num_producer_regs = 40
1306
- self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
1307
-
1308
- self.use_scheduler_barrier = (
1309
- (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128)
1310
- if const_expr(self.intra_wg_overlap)
1311
- else (self.num_mma_warp_groups == 2)
1312
- )
1313
- self.use_tma_Q = self.arch >= 90 and not (
1314
- self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
1315
- )
1316
- self.use_tma_O = (
1317
- self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa
1318
- )
1319
- # TODO: rescale_O_before_gemm
1320
- self._setup_attributes()
1321
- # TODO: we prob don't need most of what's in _setup_attributes
1322
- self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
1323
- sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
1324
- for mX, shape, stage in [
1325
- (mQ, (self.tile_m, self.tile_hdim), None),
1326
- (mK, (self.tile_n, self.tile_hdim), self.num_stages),
1327
- (mV, (self.tile_n, self.tile_hdimv), self.num_stages),
1328
- (mO, (self.tile_m, self.tile_hdimv), None),
1329
- ]
1330
- ]
1331
- self.sP_layout = None
1332
- if const_expr(not self.mma_pv_is_rs):
1333
- self.sP_layout = sm90_utils.make_smem_layout(
1334
- mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
1335
- )
1336
-
1337
- SharedStorage = self._get_shared_storage_cls()
1338
-
1339
- if const_expr(self.pack_gqa):
1340
- shape_Q_packed = (
1341
- (self.qhead_per_kvhead, mQ.shape[0]),
1342
- mQ.shape[1],
1343
- mK.shape[2],
1344
- *mQ.shape[3:],
1345
- )
1346
- stride_Q_packed = (
1347
- (mQ.stride[2], mQ.stride[0]),
1348
- mQ.stride[1],
1349
- mQ.stride[2] * self.qhead_per_kvhead,
1350
- *mQ.stride[3:],
1351
- )
1352
- mQ = cute.make_tensor(
1353
- mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
1354
- )
1355
- shape_O_packed = (
1356
- (self.qhead_per_kvhead, mO.shape[0]),
1357
- mK.shape[1],
1358
- mK.shape[2],
1359
- *mO.shape[3:],
1360
- )
1361
- stride_O_packed = (
1362
- (mO.stride[2], mO.stride[0]),
1363
- mO.stride[1],
1364
- mO.stride[2] * self.qhead_per_kvhead,
1365
- *mO.stride[3:],
1366
- )
1367
- mO = cute.make_tensor(
1368
- mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
1369
- )
1370
- if const_expr(mLSE is not None):
1371
- shape_LSE_packed = (
1372
- (self.qhead_per_kvhead, mLSE.shape[0]),
1373
- mK.shape[2],
1374
- *mLSE.shape[2:],
1375
- )
1376
- stride_LSE_packed = (
1377
- (mLSE.stride[1], mLSE.stride[0]),
1378
- mLSE.stride[1] * self.qhead_per_kvhead,
1379
- *mLSE.stride[2:],
1380
- )
1381
- mLSE = cute.make_tensor(
1382
- mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
1383
- )
1384
-
1385
- # TMA
1386
- gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
1387
- gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
1388
- gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
1389
- self.tma_copy_bytes = {
1390
- name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
1391
- for name, mX, layout in [
1392
- ("Q", mQ, self.sQ_layout),
1393
- ("K", mK, self.sK_layout),
1394
- ("V", mV, self.sV_layout),
1395
- ]
1396
- }
1397
- tma_atom_Q, tma_tensor_Q = None, None
1398
- if const_expr(self.use_tma_Q):
1399
- tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
1400
- gmem_tiled_copy_Q,
1401
- mQ,
1402
- self.sQ_layout,
1403
- (self.tile_m, self.tile_hdim), # No mcast
1404
- )
1405
- tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
1406
- gmem_tiled_copy_KV,
1407
- mK,
1408
- cute.select(self.sK_layout, mode=[0, 1]),
1409
- (self.tile_n, self.tile_hdim),
1410
- 1, # No mcast for now
1411
- )
1412
- tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
1413
- gmem_tiled_copy_KV,
1414
- mV,
1415
- cute.select(self.sV_layout, mode=[0, 1]),
1416
- (self.tile_n, self.tile_hdimv),
1417
- 1, # No mcast for now
1418
- )
1419
- tma_atom_O, tma_tensor_O = None, None
1420
- if const_expr(self.use_tma_O):
1421
- tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom(
1422
- gmem_tiled_copy_O,
1423
- mO,
1424
- self.sO_layout,
1425
- (self.tile_m, self.tile_hdimv), # No mcast
1426
- )
1427
- if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
1428
- TileScheduler = SingleTileVarlenScheduler
1429
- else:
1430
- TileScheduler = (
1431
- SingleTileScheduler
1432
- if const_expr(not self.is_causal or self.is_local)
1433
- else SingleTileLPTScheduler
1434
- )
1435
- tile_sched_args = TileSchedulerArguments(
1436
- cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
1437
- cute.size(mQ.shape[2]),
1438
- cute.size(mQ.shape[3])
1439
- if const_expr(mCuSeqlensQ is None)
1440
- else cute.size(mCuSeqlensQ.shape[0] - 1),
1441
- 1, # num_splits
1442
- cute.size(mK.shape[0]),
1443
- mQ.shape[1],
1444
- mV.shape[1],
1445
- total_q=cute.size(mQ.shape[0])
1446
- if const_expr(mCuSeqlensQ is not None)
1447
- else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
1448
- tile_shape_mn=(self.tile_m, self.tile_n),
1449
- mCuSeqlensQ=mCuSeqlensQ,
1450
- mSeqUsedQ=mSeqUsedQ,
1451
- qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1452
- element_size=self.dtype.width // 8,
1453
- is_persistent=False,
1454
- lpt=self.is_causal or self.is_local,
1455
- )
1456
- tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
1457
- grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
1458
- LOG2_E = math.log2(math.e)
1459
- if const_expr(self.score_mod is None):
1460
- softmax_scale_log2 = softmax_scale * LOG2_E
1461
- softmax_scale = None
1462
- else:
1463
- # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk
1464
- # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
1465
- # and correctly apply the softmax_scale prior to score_mod in the softmax step
1466
- softmax_scale_log2 = LOG2_E
1467
- softmax_scale = softmax_scale
1468
- if const_expr(window_size_left is not None):
1469
- window_size_left = Int32(window_size_left)
1470
- if const_expr(window_size_right is not None):
1471
- window_size_right = Int32(window_size_right)
1472
-
1473
- fastdiv_mods = None
1474
- if const_expr(aux_tensors is not None):
1475
- seqlen_q = cute.size(mQ.shape[0]) // (
1476
- self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
1477
- )
1478
- seqlen_k = (
1479
- cute.size(mK.shape[0])
1480
- if const_expr(mPageTable is None)
1481
- else mK.shape[0] * mPageTable.shape[1]
1482
- )
1483
- seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
1484
- seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
1485
- fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
1486
-
1487
- self.kernel(
1488
- tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
1489
- tma_tensor_K,
1490
- tma_tensor_V,
1491
- tma_tensor_O if const_expr(self.use_tma_O) else mO,
1492
- mLSE,
1493
- mCuSeqlensQ,
1494
- mCuSeqlensK,
1495
- mSeqUsedQ,
1496
- mSeqUsedK,
1497
- tma_atom_Q,
1498
- tma_atom_K,
1499
- tma_atom_V,
1500
- tma_atom_O,
1501
- softmax_scale_log2,
1502
- softmax_scale,
1503
- window_size_left,
1504
- window_size_right,
1505
- learnable_sink,
1506
- blocksparse_tensors,
1507
- self.sQ_layout,
1508
- self.sK_layout,
1509
- self.sV_layout,
1510
- self.sO_layout,
1511
- self.sP_layout,
1512
- self.gmem_tiled_copy_Q,
1513
- self.gmem_tiled_copy_K,
1514
- self.gmem_tiled_copy_V,
1515
- self.gmem_tiled_copy_O,
1516
- tiled_mma_qk,
1517
- tiled_mma_pv,
1518
- tile_sched_params,
1519
- TileScheduler,
1520
- SharedStorage,
1521
- aux_tensors,
1522
- fastdiv_mods,
1523
- ).launch(
1524
- grid=grid_dim,
1525
- block=[self.num_threads, 1, 1],
1526
- stream=stream,
1527
- min_blocks_per_mp=1,
1528
- )
1529
-
1530
- @cute.kernel
1531
- def kernel(
1532
- self,
1533
- mQ: cute.Tensor,
1534
- mK: cute.Tensor,
1535
- mV: cute.Tensor,
1536
- mO: cute.Tensor,
1537
- mLSE: Optional[cute.Tensor],
1538
- mCuSeqlensQ: Optional[cute.Tensor],
1539
- mCuSeqlensK: Optional[cute.Tensor],
1540
- mSeqUsedQ: Optional[cute.Tensor],
1541
- mSeqUsedK: Optional[cute.Tensor],
1542
- tma_atom_Q: Optional[cute.CopyAtom],
1543
- tma_atom_K: Optional[cute.CopyAtom],
1544
- tma_atom_V: Optional[cute.CopyAtom],
1545
- tma_atom_O: Optional[cute.CopyAtom],
1546
- softmax_scale_log2: Float32,
1547
- softmax_scale: Optional[Float32],
1548
- window_size_left: Optional[Int32],
1549
- window_size_right: Optional[Int32],
1550
- learnable_sink: Optional[cute.Tensor],
1551
- blocksparse_tensors: Optional[BlockSparseTensors],
1552
- sQ_layout: cute.ComposedLayout,
1553
- sK_layout: cute.ComposedLayout,
1554
- sV_layout: cute.ComposedLayout,
1555
- sO_layout: cute.ComposedLayout,
1556
- sP_layout: cute.ComposedLayout | None,
1557
- gmem_tiled_copy_Q: cute.TiledCopy,
1558
- gmem_tiled_copy_K: cute.TiledCopy,
1559
- gmem_tiled_copy_V: cute.TiledCopy,
1560
- gmem_tiled_copy_O: cute.TiledCopy,
1561
- tiled_mma_qk: cute.TiledMma,
1562
- tiled_mma_pv: cute.TiledMma,
1563
- tile_sched_params: ParamsBase,
1564
- TileScheduler: cutlass.Constexpr[Callable],
1565
- SharedStorage: cutlass.Constexpr[Callable],
1566
- aux_tensors=Optional[list[cute.Tensor]],
1567
- fastdiv_mods=None,
1568
- ):
1569
- warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1570
- # Prefetch tma descriptor
1571
- if warp_idx == 0:
1572
- for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
1573
- if const_expr(tma_atom is not None):
1574
- cpasync.prefetch_descriptor(tma_atom)
1575
-
1576
- smem = cutlass.utils.SmemAllocator()
1577
- storage = smem.allocate(SharedStorage)
1578
-
1579
- # Mbarrier init
1580
- mbar_ptr_Q = storage.mbar_ptr.data_ptr()
1581
- if warp_idx == 1:
1582
- # if tidx < 2:
1583
- # # barrierO num threads should be self.num_mma_threads
1584
- # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads)
1585
- if const_expr(not self.use_tma_Q):
1586
- cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads)
1587
- # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads)
1588
- # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync
1589
- pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(
1590
- cutlass.pipeline.Agent.Thread
1591
- )
1592
- pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup(
1593
- cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
1594
- )
1595
- pipeline_k = pipeline.PipelineTmaAsync.create(
1596
- barrier_storage=storage.mbar_ptr_K.data_ptr(),
1597
- num_stages=self.num_stages,
1598
- producer_group=pipeline_kv_producer_group,
1599
- consumer_group=pipeline_kv_consumer_group,
1600
- tx_count=self.tma_copy_bytes["K"],
1601
- defer_sync=True,
1602
- )
1603
- pipeline_v = pipeline.PipelineTmaAsync.create(
1604
- barrier_storage=storage.mbar_ptr_V.data_ptr(),
1605
- num_stages=self.num_stages,
1606
- producer_group=pipeline_kv_producer_group,
1607
- consumer_group=pipeline_kv_consumer_group,
1608
- tx_count=self.tma_copy_bytes["V"],
1609
- defer_sync=False
1610
- )
1611
-
1612
- # ///////////////////////////////////////////////////////////////////////////////
1613
- # Get shared memory buffer
1614
- # ///////////////////////////////////////////////////////////////////////////////
1615
- sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
1616
- sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
1617
- if const_expr(not self.Q_in_regs):
1618
- sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
1619
- else:
1620
- sV = storage.sQ.get_tensor(
1621
- sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
1622
- )
1623
- # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
1624
- sVt = layout_utils.transpose_view(sV)
1625
- sP = None
1626
- if const_expr(sP_layout is not None):
1627
- sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
1628
- # reuse sQ's data iterator
1629
- sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
1630
-
1631
- block_info = BlockInfo(
1632
- self.tile_m,
1633
- self.tile_n,
1634
- self.is_causal,
1635
- self.is_local,
1636
- False, # is_split_kv
1637
- window_size_left,
1638
- window_size_right,
1639
- qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1640
- )
1641
- SeqlenInfoCls = partial(
1642
- SeqlenInfoQK.create,
1643
- seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
1644
- seqlen_k_static=mK.shape[0],
1645
- mCuSeqlensQ=mCuSeqlensQ,
1646
- mCuSeqlensK=mCuSeqlensK,
1647
- mSeqUsedQ=mSeqUsedQ,
1648
- mSeqUsedK=mSeqUsedK,
1649
- )
1650
- AttentionMaskCls = partial(
1651
- AttentionMask,
1652
- self.tile_m,
1653
- self.tile_n,
1654
- window_size_left=window_size_left,
1655
- window_size_right=window_size_right,
1656
- qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1657
- )
1658
- TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
1659
-
1660
- if warp_idx < 4: # Producer
1661
- cute.arch.setmaxregister_decrease(self.num_producer_regs)
1662
- self.load(
1663
- mQ,
1664
- mK,
1665
- mV,
1666
- sQ,
1667
- sK,
1668
- sV,
1669
- tma_atom_Q,
1670
- tma_atom_K,
1671
- tma_atom_V,
1672
- pipeline_k,
1673
- pipeline_v,
1674
- mbar_ptr_Q,
1675
- blocksparse_tensors,
1676
- block_info,
1677
- SeqlenInfoCls,
1678
- TileSchedulerCls,
1679
- )
1680
-
1681
- else: # Consumer
1682
- cute.arch.setmaxregister_increase(self.num_mma_regs)
1683
- # ///////////////////////////////////////////////////////////////////////////////
1684
- # Tile MMA compute thread partitions and allocate accumulators
1685
- # ///////////////////////////////////////////////////////////////////////////////
1686
- tidx, _, _ = cute.arch.thread_idx()
1687
- tidx = tidx - 128
1688
- self.mma(
1689
- tiled_mma_qk,
1690
- tiled_mma_pv,
1691
- mQ,
1692
- mO,
1693
- mLSE,
1694
- sQ,
1695
- sK,
1696
- sVt,
1697
- sP,
1698
- sO,
1699
- learnable_sink,
1700
- pipeline_k,
1701
- pipeline_v,
1702
- mbar_ptr_Q,
1703
- gmem_tiled_copy_Q,
1704
- gmem_tiled_copy_O,
1705
- tma_atom_O,
1706
- tidx,
1707
- softmax_scale_log2,
1708
- softmax_scale,
1709
- block_info,
1710
- SeqlenInfoCls,
1711
- AttentionMaskCls,
1712
- TileSchedulerCls,
1713
- blocksparse_tensors,
1714
- aux_tensors,
1715
- fastdiv_mods,
1716
- )
1717
-
1718
- @cute.jit
1719
- def load(
1720
- self,
1721
- mQ: cute.Tensor,
1722
- mK: cute.Tensor,
1723
- mV: cute.Tensor,
1724
- sQ: cute.Tensor,
1725
- sK: cute.Tensor,
1726
- sV: cute.Tensor,
1727
- tma_atom_Q: cute.CopyAtom,
1728
- tma_atom_K: cute.CopyAtom,
1729
- tma_atom_V: cute.CopyAtom,
1730
- pipeline_k: cutlass.pipeline.PipelineAsync,
1731
- pipeline_v: cutlass.pipeline.PipelineAsync,
1732
- mbar_ptr_Q: cutlass.Pointer,
1733
- blocksparse_tensors: Optional[BlockSparseTensors],
1734
- block_info: BlockInfo,
1735
- SeqlenInfoCls: Callable,
1736
- TileSchedulerCls: Callable,
1737
- ):
1738
- warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
1739
- if warp_idx_in_wg == 0:
1740
- q_producer_phase = Int32(1)
1741
- kv_producer_state = pipeline.make_pipeline_state(
1742
- cutlass.pipeline.PipelineUserType.Producer, self.num_stages
1743
- )
1744
- tile_scheduler = TileSchedulerCls()
1745
- work_tile = tile_scheduler.initial_work_tile_info()
1746
- while work_tile.is_valid_tile:
1747
- # if work_tile.is_valid_tile:
1748
- m_block, head_idx, batch_idx, _ = work_tile.tile_idx
1749
- seqlen = SeqlenInfoCls(batch_idx)
1750
- mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
1751
- head_idx_kv = (
1752
- head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
1753
- )
1754
- mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
1755
- mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
1756
- gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
1757
- gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
1758
- load_Q = None
1759
- if const_expr(self.use_tma_Q):
1760
- gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
1761
- load_Q, _, _ = copy_utils.tma_get_copy_fn(
1762
- tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
1763
- )
1764
- # TODO: mcast
1765
- # TODO check warp_idx if we have 128 producer threads
1766
- load_K, _, _ = copy_utils.tma_get_copy_fn(
1767
- tma_atom_K, 0, cute.make_layout(1), gK, sK
1768
- )
1769
- load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k)
1770
- load_V, _, _ = copy_utils.tma_get_copy_fn(
1771
- tma_atom_V, 0, cute.make_layout(1), gV, sV
1772
- )
1773
- load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v)
1774
-
1775
- if const_expr(not self.use_block_sparsity):
1776
- n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
1777
- # if cute.arch.thread_idx()[0] == 0:
1778
- # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
1779
- # First iteration: load both Q & K with the same mbarrier
1780
- n_block = n_block_max - 1
1781
- pipeline_k.producer_acquire(
1782
- kv_producer_state,
1783
- extra_tx_count=self.tma_copy_bytes["Q"]
1784
- if const_expr(self.use_tma_Q)
1785
- else 0,
1786
- )
1787
- if const_expr(self.use_tma_Q):
1788
- load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
1789
- load_K(src_idx=n_block, producer_state=kv_producer_state)
1790
-
1791
- if const_expr(not self.intra_wg_overlap):
1792
- pipeline_v.producer_acquire(kv_producer_state)
1793
- load_V(src_idx=n_block, producer_state=kv_producer_state)
1794
- kv_producer_state.advance()
1795
- for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1796
- n_block = n_block_max - 1 - i - 1
1797
- pipeline_k.producer_acquire(kv_producer_state)
1798
- load_K(src_idx=n_block, producer_state=kv_producer_state)
1799
- pipeline_v.producer_acquire(kv_producer_state)
1800
- load_V(src_idx=n_block, producer_state=kv_producer_state)
1801
- kv_producer_state.advance()
1802
- else:
1803
- for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1804
- n_block_prev = n_block_max - i - 1
1805
- n_block = n_block_prev - 1
1806
- kv_producer_state_prev = kv_producer_state.clone()
1807
- kv_producer_state.advance()
1808
- pipeline_k.producer_acquire(kv_producer_state)
1809
- load_K(src_idx=n_block, producer_state=kv_producer_state)
1810
- pipeline_v.producer_acquire(kv_producer_state_prev)
1811
- load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
1812
- n_block = n_block_min
1813
- pipeline_v.producer_acquire(kv_producer_state)
1814
- load_V(src_idx=n_block, producer_state=kv_producer_state)
1815
- kv_producer_state.advance()
1816
- else:
1817
- kv_producer_state = produce_block_sparse_loads(
1818
- blocksparse_tensors,
1819
- batch_idx,
1820
- head_idx,
1821
- m_block,
1822
- kv_producer_state,
1823
- load_Q,
1824
- load_K,
1825
- load_V,
1826
- pipeline_k,
1827
- pipeline_v,
1828
- self.use_tma_Q,
1829
- self.tma_copy_bytes["Q"],
1830
- self.intra_wg_overlap,
1831
- self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1832
- self.q_subtile_factor if self.q_subtile_factor is not None else 1,
1833
- )
1834
-
1835
- tile_scheduler.prefetch_next_work()
1836
- tile_scheduler.advance_to_next_work()
1837
- work_tile = tile_scheduler.get_current_work()
1838
- # End of persistent scheduler loop
1839
-
1840
- @cute.jit
1841
- def mma(
1842
- self,
1843
- tiled_mma_qk: cute.TiledMma,
1844
- tiled_mma_pv: cute.TiledMma,
1845
- # softmax: Softmax,
1846
- # acc_O: cute.Tensor,
1847
- mQ: cute.Tensor,
1848
- mO: cute.Tensor,
1849
- mLSE: Optional[cute.Tensor],
1850
- sQ: cute.Tensor,
1851
- sK: cute.Tensor,
1852
- sVt: cute.Tensor,
1853
- sP: Optional[cute.Tensor],
1854
- sO: cute.Tensor,
1855
- learnable_sink: Optional[cute.Tensor],
1856
- pipeline_k: cutlass.pipeline.PipelineAsync,
1857
- pipeline_v: cutlass.pipeline.PipelineAsync,
1858
- mbar_ptr_Q: cutlass.Pointer,
1859
- gmem_tiled_copy_Q: cute.TiledCopy,
1860
- gmem_tiled_copy_O: cute.TiledCopy,
1861
- tma_atom_O: Optional[cute.CopyAtom],
1862
- tidx: Int32,
1863
- softmax_scale_log2: Float32,
1864
- softmax_scale: Optional[Float32],
1865
- block_info: BlockInfo,
1866
- SeqlenInfoCls: Callable,
1867
- AttentionMaskCls: Callable,
1868
- TileSchedulerCls: Callable,
1869
- blocksparse_tensors: Optional[BlockSparseTensors],
1870
- aux_tensors: Optional[list],
1871
- fastdiv_mods=None,
1872
- ):
1873
- warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
1874
- warp_group_thread_layout = cute.make_layout(
1875
- self.num_mma_warp_groups, stride=self.num_threads_per_warp_group
1876
- )
1877
- thr_mma_qk = tiled_mma_qk.get_slice(tidx)
1878
- wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
1879
- wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
1880
- _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
1881
- wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
1882
- )
1883
- mma_qk_fn = partial(
1884
- sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
1885
- )
1886
- acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
1887
- wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
1888
- )
1889
- mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
1890
-
1891
- # ///////////////////////////////////////////////////////////////////////////////
1892
- # Smem copy atom tiling
1893
- # ///////////////////////////////////////////////////////////////////////////////
1894
- smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype)
1895
- smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
1896
- tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
1897
- smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
1898
-
1899
- self.mma_init()
1900
-
1901
- mma_one_n_block_all = partial(
1902
- self.mma_one_n_block_intrawg_overlap
1903
- if const_expr(self.intra_wg_overlap)
1904
- else self.mma_one_n_block,
1905
- mma_qk_fn=mma_qk_fn,
1906
- pipeline_k=pipeline_k,
1907
- pipeline_v=pipeline_v,
1908
- acc_O=acc_O,
1909
- tOrP=tOrP,
1910
- smem_copy_params=smem_copy_params,
1911
- check_inf=True,
1912
- )
1913
-
1914
- q_consumer_phase = Int32(0)
1915
- kv_consumer_state = pipeline.make_pipeline_state(
1916
- cutlass.pipeline.PipelineUserType.Consumer, self.num_stages
1917
- )
1918
-
1919
- tile_scheduler = TileSchedulerCls()
1920
- work_tile = tile_scheduler.initial_work_tile_info()
1921
- softmax = Softmax.create(
1922
- softmax_scale_log2,
1923
- num_rows=acc_O.shape[0][0] * acc_O.shape[1],
1924
- softmax_scale=softmax_scale,
1925
- )
1926
-
1927
- process_first_half_block = partial(
1928
- self.first_half_block_overlap,
1929
- mma_qk_fn=mma_qk_fn,
1930
- pipeline_k=pipeline_k,
1931
- tOrP=tOrP,
1932
- smem_copy_params=smem_copy_params,
1933
- softmax=softmax,
1934
- )
1935
- process_last_half_block = partial(
1936
- self.last_half_block_overlap,
1937
- pipeline_v=pipeline_v,
1938
- mma_pv_fn=mma_pv_fn,
1939
- )
1940
- while work_tile.is_valid_tile:
1941
- # if work_tile.is_valid_tile:
1942
-
1943
- # shape: (atom_v_m * rest_m)
1944
- m_block, head_idx, batch_idx, _ = work_tile.tile_idx
1945
- seqlen = SeqlenInfoCls(batch_idx)
1946
-
1947
- # Recompute fastdiv_mods if necessary for varlen with aux_tensors
1948
- recompute_fastdiv_mods_q = cutlass.const_expr(
1949
- aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
1950
- )
1951
- recompute_fastdiv_mods_k = cutlass.const_expr(
1952
- aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
1953
- )
1954
- if cutlass.const_expr(fastdiv_mods is not None):
1955
- seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
1956
- fastdiv_mods = (
1957
- seqlen_q_divmod
1958
- if not recompute_fastdiv_mods_q
1959
- else FastDivmodDivisor(seqlen.seqlen_q),
1960
- seqlen_k_divmod
1961
- if not recompute_fastdiv_mods_k
1962
- else FastDivmodDivisor(seqlen.seqlen_k),
1963
- )
1964
-
1965
- mask = AttentionMaskCls(seqlen)
1966
- mask_fn = partial(
1967
- mask.apply_mask,
1968
- batch_idx=batch_idx,
1969
- head_idx=head_idx,
1970
- m_block=m_block,
1971
- thr_mma=thr_mma_qk,
1972
- mask_causal=self.is_causal,
1973
- mask_local=self.is_local,
1974
- aux_tensors=aux_tensors,
1975
- fastdiv_mods=fastdiv_mods,
1976
- )
1977
- score_mod_fn = None
1978
- if const_expr(self.score_mod is not None):
1979
- score_mod_fn = partial(
1980
- self.apply_score_mod,
1981
- thr_mma_qk,
1982
- batch_idx,
1983
- head_idx,
1984
- m_block,
1985
- softmax_scale=softmax_scale,
1986
- aux_tensors=aux_tensors,
1987
- fastdiv_mods=fastdiv_mods,
1988
- )
1989
- mma_one_n_block = partial(
1990
- mma_one_n_block_all,
1991
- seqlen=seqlen,
1992
- softmax=softmax,
1993
- score_mod_fn=score_mod_fn,
1994
- )
1995
- # Load Q if not TMA_Q
1996
- if const_expr(not self.use_tma_Q):
1997
- pack_gqa = PackGQA(
1998
- self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
1999
- )
2000
- mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
2001
- # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
2002
- # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
2003
- # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
2004
- # headdim=mQ.shape[1])
2005
- pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
2006
- cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)
2007
-
2008
- n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
2009
- if const_expr(not self.use_tma_Q):
2010
- cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
2011
- q_consumer_phase ^= 1
2012
- # For performance reason, we separate out two kinds of iterations:
2013
- # those that need masking on S, and those that don't.
2014
- # We need masking on S for the very last block when K and V has length not multiple of tile_n.
2015
- # We also need masking on S if it's causal, for the last several blocks.
2016
- # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True
2017
- O_should_accumulate = False
2018
-
2019
- # ==========================================
2020
- # MAINLOOP
2021
- # ==========================================
2022
- if const_expr(not self.use_block_sparsity):
2023
- # ==========================================
2024
- # No block-sparsity (original path)
2025
- # ==========================================
2026
- # First iteration with seqlen masking
2027
- if const_expr(self.intra_wg_overlap):
2028
- kv_consumer_state = process_first_half_block(
2029
- n_block=n_block_max - 1,
2030
- seqlen=seqlen,
2031
- kv_consumer_state=kv_consumer_state,
2032
- mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
2033
- score_mod_fn=score_mod_fn,
2034
- is_first_block=True,
2035
- )
2036
- # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
2037
- # acc_O.fill(0.0)
2038
- else:
2039
- self.warp_scheduler_barrier_sync()
2040
- kv_consumer_state = mma_one_n_block(
2041
- kv_consumer_state,
2042
- n_block=n_block_max - 1,
2043
- seqlen=seqlen,
2044
- mma_pv_fn=partial(mma_pv_fn, zero_init=True),
2045
- is_first_n_block=True,
2046
- mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
2047
- )
2048
- O_should_accumulate = True
2049
- # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
2050
- n_block_max -= 1
2051
- # Next couple of iterations with causal masking
2052
- if const_expr(self.is_causal or self.is_local):
2053
- n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
2054
- seqlen, m_block, n_block_min
2055
- )
2056
- # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
2057
- for n_tile in cutlass.range(
2058
- n_block_max - n_block_min_causal_local_mask, unroll=1
2059
- ):
2060
- kv_consumer_state = mma_one_n_block(
2061
- kv_consumer_state,
2062
- n_block=n_block_max - 1 - n_tile,
2063
- seqlen=seqlen,
2064
- mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2065
- mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2066
- )
2067
- O_should_accumulate = True
2068
- n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
2069
- # The remaining iterations have no masking
2070
- n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
2071
- seqlen, m_block, n_block_min
2072
- )
2073
- # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
2074
- for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
2075
- kv_consumer_state = mma_one_n_block(
2076
- kv_consumer_state,
2077
- n_block=n_block_max - 1 - n_tile,
2078
- seqlen=seqlen,
2079
- mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2080
- mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2081
- )
2082
- O_should_accumulate = True
2083
- # Separate iterations with local masking on the left
2084
- if const_expr(self.is_local and block_info.window_size_left is not None):
2085
- n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
2086
- for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
2087
- kv_consumer_state = mma_one_n_block(
2088
- kv_consumer_state,
2089
- n_block=n_block_max - 1 - n_tile,
2090
- seqlen=seqlen,
2091
- mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
2092
- mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
2093
- )
2094
- O_should_accumulate = True
2095
- # Last "half" iteration
2096
- if const_expr(self.intra_wg_overlap):
2097
- kv_consumer_state = process_last_half_block(
2098
- kv_consumer_state=kv_consumer_state,
2099
- zero_init=not O_should_accumulate,
2100
- )
2101
- O_should_accumulate = True
2102
- else:
2103
- self.warp_scheduler_barrier_arrive()
2104
-
2105
- else:
2106
- # ==========================================
2107
- # Block sparsity
2108
- # ==========================================
2109
- kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
2110
- blocksparse_tensors,
2111
- batch_idx,
2112
- head_idx,
2113
- m_block,
2114
- seqlen,
2115
- kv_consumer_state,
2116
- mma_pv_fn,
2117
- mma_one_n_block,
2118
- process_first_half_block,
2119
- process_last_half_block,
2120
- mask_fn,
2121
- score_mod_fn,
2122
- O_should_accumulate,
2123
- self.mask_mod,
2124
- fastdiv_mods,
2125
- self.intra_wg_overlap,
2126
- self.warp_scheduler_barrier_sync,
2127
- self.warp_scheduler_barrier_arrive,
2128
- self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
2129
- self.q_subtile_factor if self.q_subtile_factor is not None else 1,
2130
- )
2131
-
2132
- # Handle empty case (when no blocks to process)
2133
- if not processed_any:
2134
- softmax.reset()
2135
- acc_O.fill(0.0)
2136
-
2137
- sink_val = None
2138
- if const_expr(learnable_sink is not None):
2139
- if const_expr(not self.pack_gqa):
2140
- sink_val = Float32(learnable_sink[head_idx])
2141
- else: # Each thread might have a different sink value due to different q_head
2142
- sink_val = cute.make_fragment_like(softmax.row_max, Float32)
2143
- cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
2144
- tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
2145
- for r in cutlass.range(cute.size(sink_val), unroll_full=True):
2146
- row = m_block * self.tile_m + tScS_mn[r][0]
2147
- q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
2148
- sink_val[r] = Float32(learnable_sink[q_head_idx])
2149
-
2150
- # normalize acc_O by row_sum and calculate the lse
2151
- row_scale = softmax.finalize(sink_val=sink_val)
2152
- softmax.rescale_O(acc_O, row_scale)
2153
-
2154
- # ///////////////////////////////////////////////////////////////////////////////
2155
- # Epilogue
2156
- # ///////////////////////////////////////////////////////////////////////////////
2157
- self.epilogue(
2158
- acc_O,
2159
- softmax.row_sum,
2160
- mO,
2161
- mLSE,
2162
- sO,
2163
- seqlen,
2164
- gmem_tiled_copy_O,
2165
- tma_atom_O,
2166
- tiled_mma_pv,
2167
- tidx,
2168
- m_block,
2169
- head_idx,
2170
- batch_idx,
2171
- )
2172
-
2173
- tile_scheduler.advance_to_next_work()
2174
- work_tile = tile_scheduler.get_current_work()
2175
-
2176
-
2177
- @cute.jit
2178
- def first_half_block_overlap(
2179
- self,
2180
- n_block: Int32,
2181
- mma_qk_fn: Callable,
2182
- kv_consumer_state,
2183
- pipeline_k,
2184
- tOrP: cute.Tensor,
2185
- smem_copy_params: SimpleNamespace,
2186
- softmax: Softmax,
2187
- seqlen: SeqlenInfoQK,
2188
- mask_fn: Callable = None,
2189
- score_mod_fn: Optional[Callable] = None,
2190
- is_first_block: bool = False,
2191
- ):
2192
- """Processes the first half block when using intra-warpgroup-overlap"""
2193
-
2194
- pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
2195
- acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
2196
- pipeline_k.consumer_release(kv_consumer_state)
2197
-
2198
- # Apply score modification if present
2199
- if const_expr(score_mod_fn is not None):
2200
- score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2201
-
2202
- # Apply mask; mask_seqlen always True for first block
2203
- # Caveat: if full block further right than mask block, seqlen masking is redundant;
2204
- # however, masking is being applied anyway, so essentially no perf hit
2205
- mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
2206
-
2207
- softmax.online_softmax(acc_S, is_first=is_first_block)
2208
-
2209
- tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
2210
- tOrP_cur = (
2211
- tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2212
- )
2213
- tOrP_cur.store(tOrP_acc.load().to(self.dtype))
2214
-
2215
- # if pv gemm not rs
2216
- if const_expr(not self.mma_pv_is_rs):
2217
- tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2218
- cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2219
- # Fence and barrier to make smem store visible to WGMMA
2220
- cute.arch.fence_view_async_shared()
2221
- cute.arch.sync_warp()
2222
-
2223
- return kv_consumer_state
2224
-
2225
- @cute.jit
2226
- def last_half_block_overlap(
2227
- self,
2228
- kv_consumer_state,
2229
- pipeline_v,
2230
- mma_pv_fn: Callable,
2231
- zero_init: bool,
2232
- ):
2233
- """Processes the final PV GEMM when using intra-warpgroup-overlap"""
2234
-
2235
- pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
2236
- mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
2237
- pipeline_v.consumer_release(kv_consumer_state)
2238
- kv_consumer_state.advance()
2239
- return kv_consumer_state
2240
-
2241
- @cute.jit
2242
- def mma_one_n_block(
2243
- self,
2244
- smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
2245
- n_block: Int32,
2246
- mma_qk_fn: Callable,
2247
- mma_pv_fn: Callable,
2248
- pipeline_k: cutlass.pipeline.PipelineAsync,
2249
- pipeline_v: cutlass.pipeline.PipelineAsync,
2250
- acc_O: cute.Tensor,
2251
- tOrP: cute.Tensor,
2252
- smem_copy_params: SimpleNamespace,
2253
- softmax: Softmax,
2254
- seqlen: SeqlenInfoQK,
2255
- score_mod_fn: Optional[Callable] = None,
2256
- mask_fn: Optional[Callable] = None,
2257
- is_first_n_block: cutlass.Constexpr = False,
2258
- check_inf: cutlass.Constexpr = True,
2259
- ):
2260
- pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
2261
- # S = Q @ K.T
2262
- acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
2263
- self.warp_scheduler_barrier_arrive()
2264
- warpgroup.wait_group(0)
2265
- pipeline_k.consumer_release(smem_pipe_read)
2266
-
2267
- # handle score mods and masking
2268
- if const_expr(score_mod_fn is not None):
2269
- score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2270
- if const_expr(mask_fn is not None):
2271
- mask_fn(acc_S=acc_S, n_block=n_block)
2272
-
2273
- row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
2274
- # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
2275
- tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
2276
- tOrP_cur = (
2277
- tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2278
- )
2279
- # tOrP.store(tOrP_acc.load().to(self.dtype))
2280
- # the "to(self.dtype)" conversion fails to vectorize for block sizes other
2281
- # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
2282
- # 2 elements. So we just call ptx directly.
2283
- utils.cvt_f16(tOrP_acc, tOrP_cur)
2284
- if const_expr(not self.mma_pv_is_rs):
2285
- tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2286
- cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2287
- softmax.rescale_O(acc_O, row_scale)
2288
- if const_expr(not self.mma_pv_is_rs):
2289
- # Fence and barrier to make sure smem store is visible to WGMMA
2290
- cute.arch.fence_view_async_shared()
2291
- cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
2292
- pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
2293
- self.warp_scheduler_barrier_sync()
2294
- # O += P @ V
2295
- mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
2296
- pipeline_v.consumer_release(smem_pipe_read)
2297
- smem_pipe_read.advance()
2298
- return smem_pipe_read
2299
-
2300
- @cute.jit
2301
- def mma_one_n_block_intrawg_overlap(
2302
- self,
2303
- smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,
2304
- n_block: Int32,
2305
- mma_qk_fn: Callable,
2306
- mma_pv_fn: Callable,
2307
- pipeline_k: cutlass.pipeline.PipelineAsync,
2308
- pipeline_v: cutlass.pipeline.PipelineAsync,
2309
- acc_O: cute.Tensor,
2310
- tOrP: cute.Tensor,
2311
- smem_copy_params: SimpleNamespace,
2312
- softmax: Softmax,
2313
- seqlen: SeqlenInfoQK,
2314
- score_mod_fn: Optional[Callable] = None,
2315
- mask_fn: Optional[Callable] = None,
2316
- check_inf: cutlass.Constexpr = True,
2317
- ):
2318
- smem_pipe_read_v = smem_pipe_read.clone()
2319
- smem_pipe_read.advance()
2320
- pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
2321
- self.warp_scheduler_barrier_sync()
2322
- # S = Q @ K.T
2323
- acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
2324
- pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
2325
- # O += P @ V
2326
- mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
2327
- self.warp_scheduler_barrier_arrive()
2328
- warpgroup.wait_group(1)
2329
- pipeline_k.consumer_release(smem_pipe_read)
2330
-
2331
- # handle score mods and masking
2332
- if const_expr(score_mod_fn is not None):
2333
- score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
2334
- if const_expr(mask_fn is not None):
2335
- mask_fn(acc_S=acc_S, n_block=n_block)
2336
- # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
2337
-
2338
- row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
2339
- warpgroup.wait_group(0)
2340
- pipeline_v.consumer_release(smem_pipe_read_v)
2341
- tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
2342
- tOrP_cur = (
2343
- tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
2344
- )
2345
- # tOrP_cur.store(tOrP_acc.load().to(self.dtype))
2346
- # the "to(self.dtype)" conversion fails to vectorize for block sizes other
2347
- # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
2348
- # 2 elements. So we just call ptx directly.
2349
- utils.cvt_f16(tOrP_acc, tOrP_cur)
2350
- if const_expr(not self.mma_pv_is_rs):
2351
- tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
2352
- cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
2353
- softmax.rescale_O(acc_O, row_scale)
2354
- if const_expr(not self.mma_pv_is_rs):
2355
- # Fence and barrier to make sure smem store is visible to WGMMA
2356
- cute.arch.fence_view_async_shared()
2357
- cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
2358
- return smem_pipe_read
2359
-
2360
- @cute.jit
2361
- def mma_init(self):
2362
- warp_group_idx = utils.canonical_warp_group_idx(sync=False)
2363
- if const_expr(self.use_scheduler_barrier):
2364
- if warp_group_idx == 1:
2365
- cute.arch.barrier_arrive(
2366
- barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
2367
- number_of_threads=2 * self.num_threads_per_warp_group,
2368
- )
2369
-
2370
- @cute.jit
2371
- def apply_score_mod(
2372
- self,
2373
- thr_mma_qk,
2374
- batch_idx,
2375
- head_idx,
2376
- m_block,
2377
- acc_S,
2378
- n_block,
2379
- softmax_scale,
2380
- seqlen,
2381
- aux_tensors: Optional[list] = None,
2382
- fastdiv_mods=None,
2383
- ):
2384
- # Prepare index tensor
2385
- cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
2386
- cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
2387
- tScS = thr_mma_qk.partition_C(cS)
2388
-
2389
- apply_score_mod_inner(
2390
- acc_S,
2391
- tScS,
2392
- self.score_mod,
2393
- batch_idx,
2394
- head_idx,
2395
- softmax_scale,
2396
- self.vec_size,
2397
- self.qk_acc_dtype,
2398
- aux_tensors,
2399
- fastdiv_mods,
2400
- seqlen_info=seqlen,
2401
- constant_q_idx=None,
2402
- qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
2403
- )
2404
-
2405
- def warp_scheduler_barrier_sync(self):
2406
- if const_expr(self.use_scheduler_barrier):
2407
- cute.arch.barrier(
2408
- barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
2409
- - 1
2410
- + utils.canonical_warp_group_idx(sync=False),
2411
- number_of_threads=2 * self.num_threads_per_warp_group,
2412
- )
2413
-
2414
- def warp_scheduler_barrier_arrive(self):
2415
- if const_expr(self.use_scheduler_barrier):
2416
- assert self.num_mma_warp_groups in [2, 3]
2417
- cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
2418
- if const_expr(self.num_mma_warp_groups == 2):
2419
- next_wg = 1 - cur_wg
2420
- else:
2421
- t = cur_wg + 1
2422
- next_wg = t % self.num_mma_warp_groups
2423
- cute.arch.barrier_arrive(
2424
- barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
2425
- number_of_threads=2 * self.num_threads_per_warp_group,
2426
- )
 
15
  import cutlass
16
  import cutlass.cute as cute
17
  from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
18
+ from cutlass.cute.nvgpu import cpasync, warp
19
  import cutlass.utils as utils_basic
20
+ from cutlass.base_dsl.arch import Arch
21
+ from cutlass.cutlass_dsl import BaseDSL
22
 
23
  from .quack import copy_utils
24
  from .quack import layout_utils
 
25
 
26
  from . import ampere_helpers as sm80_utils
27
  from .cute_dsl_utils import assume_tensor_aligned
28
  from . import utils
29
  from .mask import AttentionMask
30
+ from .softmax import Softmax
31
  from .seqlen_info import SeqlenInfoQK
32
  from .block_info import BlockInfo
 
 
 
 
 
 
33
  from .pack_gqa import PackGQA
34
  from .named_barrier import NamedBarrierFwd
35
+ from .block_sparsity import BlockSparseTensors
36
+ from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments
 
 
 
 
 
 
37
 
38
 
39
  class FlashAttentionForwardBase:
 
40
 
41
  def __init__(
42
  self,
 
102
  self.vec_size: cutlass.Constexpr = getattr(
103
  score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2
104
  )
105
+ if self.vec_size > 2:
106
+ raise ValueError(
107
+ f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 "
108
+ "due to accumulator thread ownership pattern."
109
+ )
110
+ self.arch = BaseDSL._get_dsl().get_arch_enum()
111
 
112
  @staticmethod
113
  def can_implement(
 
310
  mO: cute.Tensor,
311
  mLSE: Optional[cute.Tensor],
312
  softmax_scale: Float32,
313
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
314
+ stream: cuda.CUstream = None,
315
  ):
316
  """Configures and launches the flash attention kernel.
317
 
 
344
  cute.arch.barrier(
345
  barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
346
  )
347
+ smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
348
  smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
349
  taccOrO = smem_thr_copy_O.retile(rO)
350
  taccOsO = smem_thr_copy_O.partition_D(sO)
 
359
 
360
  # Write LSE from rmem -> gmem
361
  if const_expr(mLSE is not None):
362
+ mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
 
 
 
 
363
  if const_expr(not self.pack_gqa):
364
  gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
365
  gLSE_expanded_layout = cute.append(
 
373
  t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
374
  # Only the thread corresponding to column 0 writes out the lse to gmem
375
  if taccOcO[0][1] == 0:
376
+ for m in cutlass.range(cute.size(taccOgLSE.shape[1]), unroll_full=True):
377
  if (
378
  t0accOcO[m, 0][0]
379
  < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]
 
382
  else:
383
  pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)
384
 
385
+ ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
386
+ mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
 
 
 
387
  # thr_mma = tiled_mma.get_slice(tidx)
388
  # taccOgO = thr_mma.partition_C(gO)
389
  # cute.autovec_copy(rO, taccOgO)
 
620
  mV: cute.Tensor,
621
  mO: cute.Tensor,
622
  mLSE: Optional[cute.Tensor],
623
+ softmax_scale: Float32,
624
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
625
+ mCuSeqlensK: Optional[cute.Tensor] = None,
626
+ mSeqUsedQ: Optional[cute.Tensor] = None,
627
+ mSeqUsedK: Optional[cute.Tensor] = None,
628
+ mPageTable: Optional[cute.Tensor] = None,
629
  window_size_left: Optional[Int32] = None,
630
  window_size_right: Optional[Int32] = None,
631
  learnable_sink: Optional[cute.Tensor] = None,
632
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
633
  aux_tensors=None,
634
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
635
+ stream: cuda.CUstream = None,
636
  ):
637
  """Configures and launches the flash attention kernel.
638
 
 
641
  """
642
  assert learnable_sink is None, "Learnable sink is not supported in this kernel"
643
  self._check_type(
644
+ *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))
645
  )
646
  tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
647
  self.num_mma_threads = tiled_mma_pv.size
 
649
  self.num_Q_load_threads = self.num_threads
650
  self.num_epilogue_threads = self.num_threads
651
  # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None
652
+ self.use_tma_O = self.arch >= Arch.sm_90
653
  self._setup_attributes()
654
  SharedStorage = self._get_shared_storage_cls()
655
  mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
656
+ # Layout permutation: 4D non-varlen vs 3D varlen
657
+ QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
658
+ KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
659
+ mQ, mO = [
660
+ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose))
661
+ for t in (mQ, mO)
662
  ]
663
+ mK, mV = [
664
+ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))
665
+ for t in (mK, mV)
666
+ ]
667
+ if const_expr(mLSE is not None):
668
+ LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
669
+ mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
670
+ # TileScheduler for varlen, simple grid for non-varlen
671
+ if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
672
+ TileScheduler = SingleTileVarlenScheduler
 
673
  else:
674
+ TileScheduler = SingleTileScheduler
675
+ num_batch = (
676
+ mCuSeqlensQ.shape[0] - 1
677
+ if const_expr(mCuSeqlensQ is not None)
678
+ else mQ.shape[3]
679
+ )
680
+ tile_sched_args = TileSchedulerArguments(
681
+ num_block=cute.ceil_div(mQ.shape[0], self.tile_m),
682
+ num_head=cute.size(mQ.shape[2]),
683
+ num_batch=num_batch,
684
+ num_splits=1,
685
+ seqlen_k=0,
686
+ headdim=mQ.shape[1],
687
+ headdim_v=mV.shape[1],
688
+ total_q=cute.size(mQ.shape[0])
689
+ if const_expr(mCuSeqlensQ is not None)
690
+ else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
691
+ tile_shape_mn=(self.tile_m, self.tile_n),
692
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
693
+ mCuSeqlensQ=mCuSeqlensQ,
694
+ mSeqUsedQ=mSeqUsedQ,
695
+ )
696
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
697
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
698
+ softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
699
+ fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors)
700
 
701
  self.kernel(
702
  mQ,
 
704
  mV,
705
  mO,
706
  mLSE,
707
+ mCuSeqlensQ,
708
+ mCuSeqlensK,
709
+ mSeqUsedQ,
710
+ mSeqUsedK,
711
  softmax_scale_log2,
712
  softmax_scale,
713
  window_size_left,
 
724
  tiled_mma_qk,
725
  tiled_mma_pv,
726
  SharedStorage,
727
+ tile_sched_params,
728
+ TileScheduler,
729
  aux_tensors,
730
  fastdiv_mods,
731
  ).launch(
 
743
  mV: cute.Tensor,
744
  mO: cute.Tensor,
745
  mLSE: Optional[cute.Tensor],
746
+ mCuSeqlensQ: Optional[cute.Tensor],
747
+ mCuSeqlensK: Optional[cute.Tensor],
748
+ mSeqUsedQ: Optional[cute.Tensor],
749
+ mSeqUsedK: Optional[cute.Tensor],
750
  softmax_scale_log2: Float32,
751
  softmax_scale: Optional[Float32],
752
  window_size_left: Optional[Int32],
 
763
  tiled_mma_qk: cute.TiledMma,
764
  tiled_mma_pv: cute.TiledMma,
765
  SharedStorage: cutlass.Constexpr,
766
+ tile_sched_params,
767
+ TileScheduler: cutlass.Constexpr[Callable],
768
  aux_tensors=None,
769
  fastdiv_mods=None,
770
  ):
771
  # Thread index, block index
772
  tidx, _, _ = cute.arch.thread_idx()
773
+
774
+ tile_scheduler = TileScheduler.create(tile_sched_params)
775
+ work_tile = tile_scheduler.initial_work_tile_info()
776
+ m_block, num_head, batch_size, _ = work_tile.tile_idx
777
 
778
  block_info = BlockInfo(
779
  self.tile_m,
 
785
  window_size_right,
786
  qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
787
  )
788
+ seqlen = SeqlenInfoQK.create(
789
+ batch_idx=batch_size,
790
+ seqlen_q_static=mQ.shape[0],
791
+ seqlen_k_static=mK.shape[0],
792
+ mCuSeqlensQ=mCuSeqlensQ,
793
+ mCuSeqlensK=mCuSeqlensK,
794
+ mSeqUsedQ=mSeqUsedQ,
795
+ mSeqUsedK=mSeqUsedK,
796
+ )
797
  n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
798
+ # For varlen, wasted grid tiles (where batch_idx >= num_batch) will have
799
+ # seqlen_q=seqlen_k=0 and n_block_max=0. Clamp to 0 so we don't use a
800
+ # negative block index for K/V loads; the load/store predicates already
801
+ # guard all memory accesses when seqlen is 0.
802
+ n_block = cutlass.max(n_block_max - 1, 0)
803
 
804
  # ///////////////////////////////////////////////////////////////////////////////
805
  # Get the appropriate tiles for this thread block.
 
807
  blkQ_shape = (self.tile_m, self.tile_hdim)
808
  blkK_shape = (self.tile_n, self.tile_hdim)
809
  blkV_shape = (self.tile_n, self.tile_hdimv)
 
810
  num_head_kv = num_head // self.qhead_per_kvhead
811
+ if const_expr(not seqlen.has_cu_seqlens_q):
812
+ mQ_cur = mQ[None, None, num_head, batch_size]
813
+ else:
814
+ mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head])
815
+ if const_expr(not seqlen.has_cu_seqlens_k):
816
+ mK_cur = mK[None, None, num_head_kv, batch_size]
817
+ mV_cur = mV[None, None, num_head_kv, batch_size]
818
+ else:
819
+ mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv])
820
+ mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv])
821
+ gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0))
822
+ gK = cute.local_tile(mK_cur, blkK_shape, (None, 0))
823
+ gV = cute.local_tile(mV_cur, blkV_shape, (None, 0))
824
 
825
  # ///////////////////////////////////////////////////////////////////////////////
826
  # Get shared memory buffer
 
992
  mask = AttentionMask(
993
  self.tile_m,
994
  self.tile_n,
995
+ seqlen,
 
996
  window_size_left,
997
  window_size_right,
998
  self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
999
  )
1000
  mask_fn = partial(
1001
  mask.apply_mask,
1002
+ batch_idx=batch_size,
1003
+ head_idx=num_head,
1004
  m_block=m_block,
1005
  thr_mma=thr_mma_qk,
1006
  mask_causal=self.is_causal,
1007
  mask_local=self.is_local,
1008
+ aux_tensors=aux_tensors,
1009
  fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,
1010
  )
1011
 
 
1017
  smem_pipe_read,
1018
  smem_pipe_write,
1019
  is_first_n_block=True,
1020
+ seqlen=seqlen,
1021
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
1022
  )
1023
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1024
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
 
1033
  n_block,
1034
  smem_pipe_read,
1035
  smem_pipe_write,
1036
+ seqlen=seqlen,
1037
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
1038
  )
1039
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1040
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
1041
  # The remaining iterations have no masking
1042
  for n_tile in cutlass.range(n_block, unroll=1):
1043
  compute_one_n_block(
1044
+ n_block - n_tile - 1, smem_pipe_read, smem_pipe_write,
1045
+ seqlen=seqlen, is_first_n_block=False,
1046
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False)
1047
  )
1048
  smem_pipe_read = self.advance_pipeline(smem_pipe_read)
1049
  smem_pipe_write = self.advance_pipeline(smem_pipe_write)
 
1187
  # load_K_next()
1188
 
1189
 
1190
+ # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility
1191
+ def __getattr__(name):
1192
+ if name == "FlashAttentionForwardSm90":
1193
+ from .flash_fwd_sm90 import FlashAttentionForwardSm90
1194
+ return FlashAttentionForwardSm90
1195
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/flash_fwd_combine.py CHANGED
@@ -10,7 +10,7 @@ import cuda.bindings.driver as cuda
10
  import cutlass
11
  import cutlass.cute as cute
12
  from cutlass.cute.nvgpu import cpasync
13
- from cutlass import Float32, Int32, const_expr
14
 
15
  from . import utils
16
  from .cute_dsl_utils import assume_tensor_aligned
@@ -24,7 +24,7 @@ class FlashAttentionForwardCombine:
24
  dtype: Type[cutlass.Numeric],
25
  dtype_partial: Type[cutlass.Numeric],
26
  head_dim: int,
27
- m_block_size: int = 8,
28
  k_block_size: int = 64,
29
  log_max_splits: int = 4,
30
  num_threads: int = 256,
@@ -36,7 +36,7 @@ class FlashAttentionForwardCombine:
36
  :param dtype: output data type
37
  :param dtype_partial: partial accumulation data type
38
  :param head_dim: head dimension
39
- :param m_block_size: m block size
40
  :param k_block_size: k block size
41
  :param log_max_splits: log2 of maximum splits
42
  :param num_threads: number of threads
@@ -46,7 +46,7 @@ class FlashAttentionForwardCombine:
46
  self.dtype = dtype
47
  self.dtype_partial = dtype_partial
48
  self.head_dim = head_dim
49
- self.m_block_size = m_block_size
50
  self.k_block_size = k_block_size
51
  self.max_splits = 1 << log_max_splits
52
  self.num_threads = num_threads
@@ -58,7 +58,7 @@ class FlashAttentionForwardCombine:
58
  dtype,
59
  dtype_partial,
60
  head_dim,
61
- m_block_size,
62
  k_block_size,
63
  log_max_splits,
64
  num_threads,
@@ -72,12 +72,12 @@ class FlashAttentionForwardCombine:
72
  return False
73
  if num_threads % 32 != 0:
74
  return False
75
- if m_block_size % 8 != 0:
76
  return False
77
  max_splits = 1 << log_max_splits
78
  if max_splits > 256:
79
  return False
80
- if (m_block_size * max_splits) % num_threads != 0:
81
  return False
82
  return True
83
 
@@ -124,15 +124,11 @@ class FlashAttentionForwardCombine:
124
  lse_copy_bits = Float32.width # 1 element per copy, width is in bits
125
  m_block_smem = (
126
  128
127
- if self.m_block_size % 128 == 0
128
  else (
129
  64
130
- if self.m_block_size % 64 == 0
131
- else (
132
- 32
133
- if self.m_block_size % 32 == 0
134
- else (16 if self.m_block_size % 16 == 0 else 8)
135
- )
136
  )
137
  )
138
  gmem_threads_per_row_lse = m_block_smem
@@ -183,12 +179,12 @@ class FlashAttentionForwardCombine:
183
  smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
184
  )
185
  self.smem_layout_lse = cute.tile_to_shape(
186
- smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
187
  )
188
 
189
  # O partial shared memory layout (simple layout for pipeline stages)
190
  self.smem_layout_o = cute.make_ordered_layout(
191
- (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
192
  )
193
 
194
  @cute.jit
@@ -201,7 +197,9 @@ class FlashAttentionForwardCombine:
201
  cu_seqlens: Optional[cute.Tensor] = None,
202
  seqused: Optional[cute.Tensor] = None,
203
  num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
 
204
  semaphore_to_reset: Optional[cute.Tensor] = None,
 
205
  stream: cuda.CUstream = None,
206
  ):
207
  # Type checking
@@ -269,7 +267,7 @@ class FlashAttentionForwardCombine:
269
  sLSE: cute.struct.Align[
270
  cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
271
  ]
272
- sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
273
  sO: cute.struct.Align[
274
  cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
275
  ]
@@ -290,7 +288,7 @@ class FlashAttentionForwardCombine:
290
  head_divmod = FastDivmodDivisor(num_head)
291
 
292
  grid_dim = (
293
- cute.ceil_div(seqlen * num_head, self.m_block_size),
294
  cute.ceil_div(self.head_dim, self.k_block_size),
295
  batch_size,
296
  )
@@ -303,6 +301,7 @@ class FlashAttentionForwardCombine:
303
  cu_seqlens,
304
  seqused,
305
  num_splits_dynamic_ptr,
 
306
  semaphore_to_reset,
307
  SharedStorage,
308
  self.smem_layout_lse,
@@ -331,6 +330,7 @@ class FlashAttentionForwardCombine:
331
  cu_seqlens: Optional[cute.Tensor],
332
  seqused: Optional[cute.Tensor],
333
  num_splits_dynamic_ptr: Optional[cute.Tensor],
 
334
  semaphore_to_reset: Optional[cute.Tensor],
335
  SharedStorage: cutlass.Constexpr,
336
  smem_layout_lse: cute.Layout | cute.ComposedLayout,
@@ -345,7 +345,14 @@ class FlashAttentionForwardCombine:
345
  ):
346
  # Thread and block indices
347
  tidx, _, _ = cute.arch.thread_idx()
348
- m_block, k_block, batch_idx = cute.arch.block_idx()
 
 
 
 
 
 
 
349
 
350
  # ///////////////////////////////////////////////////////////////////////////////
351
  # Get shared memory buffer
@@ -353,22 +360,23 @@ class FlashAttentionForwardCombine:
353
  smem = cutlass.utils.SmemAllocator()
354
  storage = smem.allocate(SharedStorage)
355
  sLSE = storage.sLSE.get_tensor(smem_layout_lse)
356
- sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,))
357
  sO = storage.sO.get_tensor(smem_layout_o)
358
 
359
- # Handle semaphore reset
360
  if const_expr(semaphore_to_reset is not None):
361
  if (
362
  tidx == 0
363
  and m_block == cute.arch.grid_dim()[0] - 1
364
  and k_block == cute.arch.grid_dim()[1] - 1
365
- and batch_idx == cute.arch.grid_dim()[2] - 1
366
  ):
 
367
  semaphore_to_reset[0] = 0
368
 
369
- # Get number of splits
370
  num_splits = (
371
- num_splits_dynamic_ptr[batch_idx]
372
  if const_expr(num_splits_dynamic_ptr is not None)
373
  else mLSE_partial.shape[1]
374
  )
@@ -378,6 +386,7 @@ class FlashAttentionForwardCombine:
378
  seqlen_static=mO_partial.shape[0],
379
  cu_seqlens=cu_seqlens,
380
  seqused=seqused,
 
381
  )
382
  seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
383
 
@@ -387,29 +396,27 @@ class FlashAttentionForwardCombine:
387
 
388
  # Early exit for single split if dynamic
389
  if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
390
- const_expr(not varlen) or m_block * self.m_block_size < max_idx
391
  ):
 
 
 
392
  # ===============================
393
  # Step 1: Load LSE_partial from gmem to shared memory
394
  # ===============================
395
 
396
- if const_expr(cu_seqlens is None):
397
- mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx]
398
- else:
399
- mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial)
400
  mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
401
-
402
  gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
403
  tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
404
-
405
  # Create identity tensor for coordinate tracking
406
- cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size))
407
  tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
408
 
409
  # Load LSE partial values
410
  for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
411
  mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
412
- idx = m_block * self.m_block_size + mi
413
  if idx < max_idx:
414
  # Calculate actual sequence position and head using FastDivmodDivisor
415
  if const_expr(not varlen):
@@ -436,22 +443,19 @@ class FlashAttentionForwardCombine:
436
  # ===============================
437
 
438
  gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
439
- cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size))
440
  tOcO = gmem_thr_copy_O_partial.partition_D(cO)
441
  tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
442
- if const_expr(cu_seqlens is None):
443
- mO_partial_cur = mO_partial[None, None, None, None, batch_idx]
444
- else:
445
- mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial)
446
 
447
  # Precompute these values to avoid recomputing them in the loop
448
  num_rows = const_expr(cute.size(tOcO, mode=[1]))
449
- tOmidx = cute.make_fragment(num_rows, cutlass.Int32)
450
- tOhidx = cute.make_fragment(num_rows, cutlass.Int32)
451
- tOrOptr = cute.make_fragment(num_rows, cutlass.Int64)
452
  for m in cutlass.range(num_rows, unroll_full=True):
453
  mi = tOcO[0, m, 0][0] # m coordinate
454
- idx = m_block * self.m_block_size + mi
455
  if const_expr(not varlen):
456
  tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
457
  else:
@@ -463,11 +467,12 @@ class FlashAttentionForwardCombine:
463
  if idx >= max_idx:
464
  tOhidx[m] = -1
465
 
466
- tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean)
467
  if const_expr(not self.is_even_k):
 
468
  for k in cutlass.range(cute.size(tOpO), unroll_full=True):
469
  tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
470
- # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
471
 
472
  load_O_partial = partial(
473
  self.load_O_partial,
@@ -501,17 +506,17 @@ class FlashAttentionForwardCombine:
501
 
502
  s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
503
  ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
504
- ts2rrLSE = cute.make_fragment_like(ts2rsLSE)
505
  cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
506
 
507
  # ===============================
508
  # Step 4: Compute final LSE along split dimension
509
  # ===============================
510
 
511
- lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32)
512
  ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
513
  # We compute the max valid split for each row to short-circuit the computation later
514
- max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32)
515
  assert cute.size(ts2rrLSE, mode=[0]) == 1
516
  # Compute max, scales, and final LSE for each row
517
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
@@ -561,7 +566,7 @@ class FlashAttentionForwardCombine:
561
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
562
  if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
563
  mi = ts2rcLSE[0, 0, m][1]
564
- if mi < self.m_block_size:
565
  sMaxValidSplit[mi] = max_valid_split[m]
566
 
567
  # ===============================
@@ -577,7 +582,7 @@ class FlashAttentionForwardCombine:
577
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
578
  if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
579
  mi = ts2rcLSE[0, 0, m][1]
580
- idx = m_block * self.m_block_size + mi
581
  if idx < max_idx:
582
  if const_expr(not varlen):
583
  head_idx, m_idx = divmod(idx, seqlen_divmod)
@@ -594,11 +599,11 @@ class FlashAttentionForwardCombine:
594
 
595
  # Get max valid split for this thread
596
  thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
597
- for m in cutlass.range(1, cute.size(tOcO, mode=[1])):
598
  thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
599
 
600
- tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0])
601
- tOrO = cute.make_fragment_like(tOrO_partial, Float32)
602
  tOrO.fill(0.0)
603
 
604
  stage_load = self.stages - 1
@@ -607,7 +612,7 @@ class FlashAttentionForwardCombine:
607
  # Main accumulation loop
608
  for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
609
  # Get scales for this split
610
- scale = cute.make_fragment(num_rows, Float32)
611
  for m in cutlass.range(num_rows, unroll_full=True):
612
  scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
613
 
@@ -637,8 +642,9 @@ class FlashAttentionForwardCombine:
637
  # Step 7: Write final O to gmem
638
  # ===============================
639
 
640
- rO = cute.make_fragment_like(tOrO, self.dtype)
641
  rO.store(tOrO.load().to(self.dtype))
 
642
  if const_expr(cu_seqlens is None):
643
  mO_cur = mO[None, None, None, batch_idx]
644
  else:
@@ -665,7 +671,7 @@ class FlashAttentionForwardCombine:
665
  tOrOptr: cute.Tensor,
666
  tOsO_partial: cute.Tensor,
667
  tOhidx: cute.Tensor,
668
- tOpO: cute.Tensor,
669
  tOcO: cute.Tensor,
670
  mO_cur_partial_layout: cute.Layout,
671
  split: Int32,
@@ -684,7 +690,7 @@ class FlashAttentionForwardCombine:
684
  mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
685
  for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
686
  k_idx = tOcO[0, 0, k][1] // elems_per_load
687
- if const_expr(self.is_even_k) or tOpO[k]:
688
  cute.copy(
689
  gmem_tiled_copy_O_partial,
690
  mO_partial_cur_copy[None, k_idx, split],
 
10
  import cutlass
11
  import cutlass.cute as cute
12
  from cutlass.cute.nvgpu import cpasync
13
+ from cutlass import Float32, Int32, Boolean, const_expr
14
 
15
  from . import utils
16
  from .cute_dsl_utils import assume_tensor_aligned
 
24
  dtype: Type[cutlass.Numeric],
25
  dtype_partial: Type[cutlass.Numeric],
26
  head_dim: int,
27
+ tile_m: int = 8,
28
  k_block_size: int = 64,
29
  log_max_splits: int = 4,
30
  num_threads: int = 256,
 
36
  :param dtype: output data type
37
  :param dtype_partial: partial accumulation data type
38
  :param head_dim: head dimension
39
+ :param tile_m: m block size
40
  :param k_block_size: k block size
41
  :param log_max_splits: log2 of maximum splits
42
  :param num_threads: number of threads
 
46
  self.dtype = dtype
47
  self.dtype_partial = dtype_partial
48
  self.head_dim = head_dim
49
+ self.tile_m = tile_m
50
  self.k_block_size = k_block_size
51
  self.max_splits = 1 << log_max_splits
52
  self.num_threads = num_threads
 
58
  dtype,
59
  dtype_partial,
60
  head_dim,
61
+ tile_m,
62
  k_block_size,
63
  log_max_splits,
64
  num_threads,
 
72
  return False
73
  if num_threads % 32 != 0:
74
  return False
75
+ if tile_m % 8 != 0:
76
  return False
77
  max_splits = 1 << log_max_splits
78
  if max_splits > 256:
79
  return False
80
+ if (tile_m * max_splits) % num_threads != 0:
81
  return False
82
  return True
83
 
 
124
  lse_copy_bits = Float32.width # 1 element per copy, width is in bits
125
  m_block_smem = (
126
  128
127
+ if self.tile_m % 128 == 0
128
  else (
129
  64
130
+ if self.tile_m % 64 == 0
131
+ else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))
 
 
 
 
132
  )
133
  )
134
  gmem_threads_per_row_lse = m_block_smem
 
179
  smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
180
  )
181
  self.smem_layout_lse = cute.tile_to_shape(
182
+ smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1)
183
  )
184
 
185
  # O partial shared memory layout (simple layout for pipeline stages)
186
  self.smem_layout_o = cute.make_ordered_layout(
187
+ (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)
188
  )
189
 
190
  @cute.jit
 
197
  cu_seqlens: Optional[cute.Tensor] = None,
198
  seqused: Optional[cute.Tensor] = None,
199
  num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
200
+ varlen_batch_idx: Optional[cute.Tensor] = None,
201
  semaphore_to_reset: Optional[cute.Tensor] = None,
202
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
203
  stream: cuda.CUstream = None,
204
  ):
205
  # Type checking
 
267
  sLSE: cute.struct.Align[
268
  cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
269
  ]
270
+ sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128]
271
  sO: cute.struct.Align[
272
  cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
273
  ]
 
288
  head_divmod = FastDivmodDivisor(num_head)
289
 
290
  grid_dim = (
291
+ cute.ceil_div(seqlen * num_head, self.tile_m),
292
  cute.ceil_div(self.head_dim, self.k_block_size),
293
  batch_size,
294
  )
 
301
  cu_seqlens,
302
  seqused,
303
  num_splits_dynamic_ptr,
304
+ varlen_batch_idx,
305
  semaphore_to_reset,
306
  SharedStorage,
307
  self.smem_layout_lse,
 
330
  cu_seqlens: Optional[cute.Tensor],
331
  seqused: Optional[cute.Tensor],
332
  num_splits_dynamic_ptr: Optional[cute.Tensor],
333
+ varlen_batch_idx: Optional[cute.Tensor],
334
  semaphore_to_reset: Optional[cute.Tensor],
335
  SharedStorage: cutlass.Constexpr,
336
  smem_layout_lse: cute.Layout | cute.ComposedLayout,
 
345
  ):
346
  # Thread and block indices
347
  tidx, _, _ = cute.arch.thread_idx()
348
+ m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()
349
+
350
+ # Map virtual batch index to real batch index (for persistent tile schedulers)
351
+ batch_idx = (
352
+ varlen_batch_idx[maybe_virtual_batch]
353
+ if const_expr(varlen_batch_idx is not None)
354
+ else maybe_virtual_batch
355
+ )
356
 
357
  # ///////////////////////////////////////////////////////////////////////////////
358
  # Get shared memory buffer
 
360
  smem = cutlass.utils.SmemAllocator()
361
  storage = smem.allocate(SharedStorage)
362
  sLSE = storage.sLSE.get_tensor(smem_layout_lse)
363
+ sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))
364
  sO = storage.sO.get_tensor(smem_layout_o)
365
 
366
+ # Handle semaphore reset — wait for dependent grids first
367
  if const_expr(semaphore_to_reset is not None):
368
  if (
369
  tidx == 0
370
  and m_block == cute.arch.grid_dim()[0] - 1
371
  and k_block == cute.arch.grid_dim()[1] - 1
372
+ and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1
373
  ):
374
+ cute.arch.griddepcontrol_wait()
375
  semaphore_to_reset[0] = 0
376
 
377
+ # Get number of splits (use maybe_virtual_batch for per-batch-slot splits)
378
  num_splits = (
379
+ num_splits_dynamic_ptr[maybe_virtual_batch]
380
  if const_expr(num_splits_dynamic_ptr is not None)
381
  else mLSE_partial.shape[1]
382
  )
 
386
  seqlen_static=mO_partial.shape[0],
387
  cu_seqlens=cu_seqlens,
388
  seqused=seqused,
389
+ # Don't need to pass in tile size since we won't use offset_padded
390
  )
391
  seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
392
 
 
396
 
397
  # Early exit for single split if dynamic
398
  if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
399
+ const_expr(not varlen) or m_block * self.tile_m < max_idx
400
  ):
401
+ # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)
402
+ cute.arch.griddepcontrol_wait()
403
+
404
  # ===============================
405
  # Step 1: Load LSE_partial from gmem to shared memory
406
  # ===============================
407
 
408
+ mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)
 
 
 
409
  mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))
 
410
  gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)
411
  tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)
 
412
  # Create identity tensor for coordinate tracking
413
+ cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m))
414
  tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)
415
 
416
  # Load LSE partial values
417
  for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):
418
  mi = tLSEcLSE[0, 0, m][1] # Get m coordinate
419
+ idx = m_block * self.tile_m + mi
420
  if idx < max_idx:
421
  # Calculate actual sequence position and head using FastDivmodDivisor
422
  if const_expr(not varlen):
 
443
  # ===============================
444
 
445
  gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)
446
+ cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))
447
  tOcO = gmem_thr_copy_O_partial.partition_D(cO)
448
  tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)
449
+ mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4)
 
 
 
450
 
451
  # Precompute these values to avoid recomputing them in the loop
452
  num_rows = const_expr(cute.size(tOcO, mode=[1]))
453
+ tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
454
+ tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)
455
+ tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64)
456
  for m in cutlass.range(num_rows, unroll_full=True):
457
  mi = tOcO[0, m, 0][0] # m coordinate
458
+ idx = m_block * self.tile_m + mi
459
  if const_expr(not varlen):
460
  tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)
461
  else:
 
467
  if idx >= max_idx:
468
  tOhidx[m] = -1
469
 
470
+ tOpO = None
471
  if const_expr(not self.is_even_k):
472
+ tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean)
473
  for k in cutlass.range(cute.size(tOpO), unroll_full=True):
474
  tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size
475
+ # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)
476
 
477
  load_O_partial = partial(
478
  self.load_O_partial,
 
506
 
507
  s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)
508
  ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)
509
+ ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)
510
  cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)
511
 
512
  # ===============================
513
  # Step 4: Compute final LSE along split dimension
514
  # ===============================
515
 
516
+ lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)
517
  ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)
518
  # We compute the max valid split for each row to short-circuit the computation later
519
+ max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)
520
  assert cute.size(ts2rrLSE, mode=[0]) == 1
521
  # Compute max, scales, and final LSE for each row
522
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
 
566
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
567
  if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
568
  mi = ts2rcLSE[0, 0, m][1]
569
+ if mi < self.tile_m:
570
  sMaxValidSplit[mi] = max_valid_split[m]
571
 
572
  # ===============================
 
582
  for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):
583
  if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes
584
  mi = ts2rcLSE[0, 0, m][1]
585
+ idx = m_block * self.tile_m + mi
586
  if idx < max_idx:
587
  if const_expr(not varlen):
588
  head_idx, m_idx = divmod(idx, seqlen_divmod)
 
599
 
600
  # Get max valid split for this thread
601
  thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]
602
+ for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):
603
  thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])
604
 
605
+ tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])
606
+ tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)
607
  tOrO.fill(0.0)
608
 
609
  stage_load = self.stages - 1
 
612
  # Main accumulation loop
613
  for s in cutlass.range(thr_max_valid_split + 1, unroll=4):
614
  # Get scales for this split
615
+ scale = cute.make_rmem_tensor(num_rows, Float32)
616
  for m in cutlass.range(num_rows, unroll_full=True):
617
  scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem
618
 
 
642
  # Step 7: Write final O to gmem
643
  # ===============================
644
 
645
+ rO = cute.make_rmem_tensor_like(tOrO, self.dtype)
646
  rO.store(tOrO.load().to(self.dtype))
647
+ mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)
648
  if const_expr(cu_seqlens is None):
649
  mO_cur = mO[None, None, None, batch_idx]
650
  else:
 
671
  tOrOptr: cute.Tensor,
672
  tOsO_partial: cute.Tensor,
673
  tOhidx: cute.Tensor,
674
+ tOpO: Optional[cute.Tensor],
675
  tOcO: cute.Tensor,
676
  mO_cur_partial_layout: cute.Layout,
677
  split: Int32,
 
690
  mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
691
  for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
692
  k_idx = tOcO[0, 0, k][1] // elems_per_load
693
+ if const_expr(tOpO is None) or tOpO[k]:
694
  cute.copy(
695
  gmem_tiled_copy_O_partial,
696
  mO_partial_cur_copy[None, k_idx, split],
build/torch-cuda/flash_fwd_sm100.py CHANGED
@@ -13,9 +13,8 @@
13
  # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
14
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
15
 
16
- import enum
17
  import math
18
- from typing import Type, Tuple, Callable, Optional, Literal
19
  from functools import partial
20
 
21
  import cuda.bindings.driver as cuda
@@ -28,6 +27,7 @@ import cutlass.cute.nvgpu.tcgen05 as tcgen05
28
  import cutlass.utils.blackwell_helpers as sm100_utils_basic
29
  from cutlass import pipeline
30
  from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
 
31
  from cutlass.base_dsl.arch import Arch
32
  from cutlass.cutlass_dsl import BaseDSL
33
 
@@ -35,7 +35,9 @@ from .quack import copy_utils, layout_utils
35
 
36
  from .paged_kv import PagedKVManager
37
  from .cute_dsl_utils import assume_tensor_aligned
 
38
  from . import pipeline as pipeline_custom
 
39
  from .mask import AttentionMask
40
  from .softmax import SoftmaxSm100, apply_score_mod_inner
41
  from .seqlen_info import SeqlenInfoQK
@@ -47,33 +49,45 @@ from .block_sparse_utils import (
47
  softmax_block_sparse_sm100,
48
  handle_block_sparse_empty_tile_correction_sm100,
49
  )
50
- from .pack_gqa import PackGQA
51
  from . import mma_sm100_desc as sm100_desc
52
  from . import blackwell_helpers as sm100_utils
 
53
  from cutlass.cute import FastDivmodDivisor
54
  from .quack.cute_dsl_utils import ParamsBase
55
  from .tile_scheduler import (
 
 
56
  TileSchedulerArguments,
 
57
  SingleTileScheduler,
58
  StaticPersistentTileScheduler,
59
  SingleTileLPTScheduler,
60
  SingleTileVarlenScheduler,
61
  )
62
-
63
-
64
- class NamedBarrierFwd(enum.IntEnum):
65
- Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
66
- TmemPtr = enum.auto()
67
- SoftmaxStatsW0 = enum.auto()
68
- SoftmaxStatsW1 = enum.auto()
69
- SoftmaxStatsW2 = enum.auto()
70
- SoftmaxStatsW3 = enum.auto()
71
- SoftmaxStatsW4 = enum.auto()
72
- SoftmaxStatsW5 = enum.auto()
73
- SoftmaxStatsW6 = enum.auto()
74
- SoftmaxStatsW7 = enum.auto()
75
- # WarpSchedulerWG1 = enum.auto()
76
- # WarpSchedulerWG2 = enum.auto()
 
 
 
 
 
 
 
 
77
 
78
 
79
  class FlashAttentionForwardSm100:
@@ -99,6 +113,7 @@ class FlashAttentionForwardSm100:
99
  paged_kv_non_tma: bool = False,
100
  is_varlen_q: bool = False,
101
  use_2cta_instrs: bool = False,
 
102
  ):
103
  self.use_tma_KV = not paged_kv_non_tma
104
  # self.dtype = dtype
@@ -145,10 +160,6 @@ class FlashAttentionForwardSm100:
145
  self.is_split_kv = is_split_kv
146
  self.pack_gqa = pack_gqa
147
  self.q_subtile_factor = q_subtile_factor
148
- if pack_gqa:
149
- assert m_block_size % self.qhead_per_kvhead == 0, (
150
- "For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
151
- )
152
  assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
153
  "SplitKV is not supported for hdim >= 192"
154
  )
@@ -160,8 +171,10 @@ class FlashAttentionForwardSm100:
160
  # Does S1 need to wait for S0 to finish
161
  # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
162
  is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
163
- # self.enable_ex2_emu = self.head_dim_padded <= 128 and not is_sm103
164
- self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
 
 
165
  self.s0_s1_barrier = False
166
  self.overlap_sO_sQ = (
167
  (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
@@ -174,6 +187,32 @@ class FlashAttentionForwardSm100:
174
  "Paged KV does not support irregular head dim"
175
  )
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  self.softmax0_warp_ids = (0, 1, 2, 3)
178
  self.softmax1_warp_ids = (4, 5, 6, 7)
179
  self.correction_warp_ids = (8, 9, 10, 11)
@@ -195,8 +234,10 @@ class FlashAttentionForwardSm100:
195
  )
196
  )
197
 
 
 
198
  if self.q_stage == 1:
199
- if not self.use_tma_KV:
200
  self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
201
  self.load_warp_ids = self.softmax1_warp_ids
202
  else:
@@ -212,6 +253,8 @@ class FlashAttentionForwardSm100:
212
  elif self.is_varlen_q: # fallback
213
  self.epilogue_warp_ids = (13, 14)
214
 
 
 
215
  self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
216
  self.tmem_o_offset = [
217
  self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
@@ -227,31 +270,26 @@ class FlashAttentionForwardSm100:
227
  # vec buffer for row_max & row_sum
228
  self.tmem_vec_offset = self.tmem_s_offset
229
 
 
 
 
 
 
230
  if self.head_dim_padded < 96:
231
  self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
232
  self.num_regs_correction = 64
233
  self.num_regs_other = 48 if not paged_kv_non_tma else 80
234
  else:
235
- # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184
236
- if not self.enable_ex2_emu:
237
- self.num_regs_softmax = 192 if not paged_kv_non_tma else 184
 
 
 
238
  else:
239
- # self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
240
- self.num_regs_softmax = 192 if not paged_kv_non_tma else 184
241
- # self.num_regs_softmax = 176
242
- # self.num_regs_correction = 96
243
- # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80
244
- if not self.enable_ex2_emu:
245
- self.num_regs_correction = 80 if not paged_kv_non_tma else 64
246
- else:
247
- # self.num_regs_correction = 64
248
- self.num_regs_correction = 80 if not paged_kv_non_tma else 64
249
- # self.num_regs_other = 32
250
- # self.num_regs_other = 64
251
- # self.num_regs_other = 80
252
- self.num_regs_other = 48 if not paged_kv_non_tma else 80
253
- # self.num_regs_other = 96 if self.is_causal or self.is_local else 80
254
- # self.num_regs_other = 64 if self.is_causal or self.is_local else 80
255
 
256
  self.buffer_align_bytes = 1024
257
 
@@ -289,7 +327,7 @@ class FlashAttentionForwardSm100:
289
  self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
290
  )
291
  self.uneven_kv_smem_offset = (
292
- self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2
293
  if self.uneven_kv_smem
294
  else 0
295
  )
@@ -304,7 +342,6 @@ class FlashAttentionForwardSm100:
304
  mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
305
  mLSE: Optional[cute.Tensor],
306
  softmax_scale: Float32,
307
- stream: cuda.CUstream,
308
  mCuSeqlensQ: Optional[cute.Tensor] = None,
309
  mCuSeqlensK: Optional[cute.Tensor] = None,
310
  mSeqUsedQ: Optional[cute.Tensor] = None,
@@ -315,6 +352,8 @@ class FlashAttentionForwardSm100:
315
  learnable_sink: Optional[cute.Tensor] = None,
316
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
317
  aux_tensors: Optional[list] = None,
 
 
318
  ):
319
  """Execute the Fused Multi-Head Attention operation on the provided tensors.
320
 
@@ -367,22 +406,21 @@ class FlashAttentionForwardSm100:
367
  if const_expr(self.q_dtype != self.v_dtype):
368
  raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
369
  self._setup_attributes()
370
- self.use_tma_O = self.arch >= Arch.sm_90 and mCuSeqlensQ is None and mSeqUsedQ is None
371
- # This can be tuned
372
- # This is currently very ad-hoc, we should tune it systematically
 
 
 
 
373
  self.ex2_emu_freq = 0
374
- # self.ex2_emu_start_frg = 1 if self.is_causal else 0
375
- self.ex2_emu_start_frg = 1
376
  if const_expr(self.enable_ex2_emu):
377
- self.ex2_emu_freq = 16
378
- if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs):
379
- self.ex2_emu_freq = 12
380
  if const_expr(
381
  self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
382
  ):
383
- self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
384
- if const_expr(self.head_dim_padded > 64 and self.is_causal):
385
- self.ex2_emu_freq = 10
386
 
387
  cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
388
  q_major_mode = tcgen05.OperandMajorMode.K
@@ -462,50 +500,11 @@ class FlashAttentionForwardSm100:
462
  )
463
 
464
  if const_expr(self.pack_gqa):
465
- shape_Q_packed = (
466
- (self.qhead_per_kvhead, mQ.shape[0]),
467
- mQ.shape[1],
468
- mK.shape[2],
469
- *mQ.shape[3:],
470
- )
471
- stride_Q_packed = (
472
- (mQ.stride[2], mQ.stride[0]),
473
- mQ.stride[1],
474
- mQ.stride[2] * self.qhead_per_kvhead,
475
- *mQ.stride[3:],
476
- )
477
- mQ = cute.make_tensor(
478
- mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)
479
- )
480
- shape_O_packed = (
481
- (self.qhead_per_kvhead, mO.shape[0]),
482
- mO.shape[1],
483
- mK.shape[2],
484
- *mO.shape[3:],
485
- )
486
- stride_O_packed = (
487
- (mO.stride[2], mO.stride[0]),
488
- mO.stride[1],
489
- mO.stride[2] * self.qhead_per_kvhead,
490
- *mO.stride[3:],
491
- )
492
- mO = cute.make_tensor(
493
- mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)
494
- )
495
  if const_expr(mLSE is not None):
496
- shape_LSE_packed = (
497
- (self.qhead_per_kvhead, mLSE.shape[0]),
498
- mK.shape[2],
499
- *mLSE.shape[2:],
500
- )
501
- stride_LSE_packed = (
502
- (mLSE.stride[1], mLSE.stride[0]),
503
- mLSE.stride[1] * self.qhead_per_kvhead,
504
- *mLSE.stride[2:],
505
- )
506
- mLSE = cute.make_tensor(
507
- mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
508
- )
509
 
510
  self.tma_copy_bytes = {
511
  name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
@@ -522,14 +521,24 @@ class FlashAttentionForwardSm100:
522
  tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
523
  tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
524
 
525
- tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
526
- tma_load_op,
527
- mQ,
528
- cute.select(sQ_layout, mode=[0, 1, 2]),
529
- self.mma_tiler_qk,
530
- tiled_mma_qk,
531
- cta_layout_vmnk.shape,
532
- )
 
 
 
 
 
 
 
 
 
 
533
 
534
  tma_atom_K = None
535
  tma_atom_V = None
@@ -578,19 +587,10 @@ class FlashAttentionForwardSm100:
578
  vO_layout = cute.make_layout((1, async_copy_elems))
579
  gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
580
 
581
- if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
582
- TileScheduler = SingleTileVarlenScheduler
583
- else:
584
- if const_expr(self.is_causal or self.is_local):
585
- TileScheduler = SingleTileLPTScheduler
586
- else:
587
- TileScheduler = (
588
- SingleTileScheduler
589
- if const_expr(not self.is_persistent)
590
- else StaticPersistentTileScheduler
591
- )
592
  tile_sched_args = TileSchedulerArguments(
593
- cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),
594
  cute.size(mQ.shape[2]),
595
  cute.size(mQ.shape[3])
596
  if const_expr(mCuSeqlensQ is None)
@@ -613,8 +613,11 @@ class FlashAttentionForwardSm100:
613
  lpt=self.is_causal or self.is_local,
614
  is_split_kv=self.is_split_kv,
615
  cluster_shape_mn=self.cluster_shape_mn,
 
 
 
 
616
  )
617
- tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
618
  self.tile_scheduler_cls = TileScheduler
619
  grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
620
 
@@ -624,6 +627,9 @@ class FlashAttentionForwardSm100:
624
  cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
625
  )
626
 
 
 
 
627
  @cute.struct
628
  class SharedStorage:
629
  # m_barriers for pipelines
@@ -643,6 +649,13 @@ class FlashAttentionForwardSm100:
643
  # Smem tensors
644
  # store row max and row sum
645
  sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
 
 
 
 
 
 
 
646
  sO: cute.struct.Align[
647
  cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes
648
  ]
@@ -657,35 +670,10 @@ class FlashAttentionForwardSm100:
657
 
658
  self.shared_storage = SharedStorage
659
 
660
- LOG2_E = math.log2(math.e)
661
- if const_expr(self.score_mod is None):
662
- softmax_scale_log2 = softmax_scale * LOG2_E
663
- softmax_scale = None
664
- else:
665
- # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk
666
- # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base
667
- # and correctly apply the softmax_scale prior to score_mod in the softmax step
668
- softmax_scale_log2 = LOG2_E
669
- softmax_scale = softmax_scale
670
-
671
- if const_expr(window_size_left is not None):
672
- window_size_left = Int32(window_size_left)
673
- if const_expr(window_size_right is not None):
674
- window_size_right = Int32(window_size_right)
675
-
676
- fastdiv_mods = None
677
- if cutlass.const_expr(aux_tensors is not None):
678
- seqlen_q = cute.size(mQ.shape[0]) // (
679
- self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
680
- )
681
- seqlen_k = (
682
- cute.size(mK.shape[0])
683
- if const_expr(mPageTable is None)
684
- else mK.shape[0] * mPageTable.shape[1]
685
- )
686
- seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
687
- seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
688
- fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
689
 
690
  head_divmod = None
691
  if cutlass.const_expr(self.pack_gqa):
@@ -722,6 +710,7 @@ class FlashAttentionForwardSm100:
722
  tP_layout,
723
  sV_layout,
724
  sO_layout,
 
725
  gmem_tiled_copy_O,
726
  tiled_mma_qk,
727
  tiled_mma_pv,
@@ -752,7 +741,7 @@ class FlashAttentionForwardSm100:
752
  mSeqUsedQ: Optional[cute.Tensor],
753
  mSeqUsedK: Optional[cute.Tensor],
754
  mPageTable: Optional[cute.Tensor],
755
- tma_atom_Q: cute.CopyAtom,
756
  tma_atom_K: Optional[cute.CopyAtom],
757
  tma_atom_V: Optional[cute.CopyAtom],
758
  tma_atom_O: Optional[cute.CopyAtom],
@@ -767,6 +756,7 @@ class FlashAttentionForwardSm100:
767
  tP_layout: cute.ComposedLayout,
768
  sV_layout: cute.ComposedLayout,
769
  sO_layout: cute.ComposedLayout,
 
770
  gmem_tiled_copy_O: Optional[cute.TiledCopy],
771
  tiled_mma_qk: cute.TiledMma,
772
  tiled_mma_pv: cute.TiledMma,
@@ -814,7 +804,7 @@ class FlashAttentionForwardSm100:
814
  storage = smem.allocate(self.shared_storage)
815
 
816
  tmem_alloc_barrier = pipeline.NamedBarrier(
817
- barrier_id=int(NamedBarrierFwd.TmemPtr),
818
  num_threads=cute.arch.WARP_SIZE * len(
819
  (self.mma_warp_id,
820
  *self.softmax0_warp_ids,
@@ -833,8 +823,8 @@ class FlashAttentionForwardSm100:
833
 
834
  ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
835
  mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))
836
- load_warps = ThreadCooperativeGroup(len(self.load_warp_ids))
837
  tma_warp = ThreadCooperativeGroup(1)
 
838
  softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))
839
  softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))
840
  # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
@@ -857,15 +847,25 @@ class FlashAttentionForwardSm100:
857
  softmax_correction_threads_cluster = ThreadCooperativeGroup(
858
  cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size
859
  )
860
- pipeline_q = pipeline_custom.PipelineTmaUmma.create(
861
- barrier_storage=storage.mbar_load_Q.data_ptr(),
862
- num_stages=self.q_stage,
863
- producer_group=tma_warp,
864
- consumer_group=mma_warp,
865
- tx_count=self.tma_copy_bytes["Q"],
866
- cta_layout_vmnk=cta_layout_vmnk,
867
- defer_sync=True,
868
- )
 
 
 
 
 
 
 
 
 
 
869
  if const_expr(self.use_tma_KV):
870
  pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
871
  barrier_storage=storage.mbar_load_KV.data_ptr(),
@@ -877,13 +877,10 @@ class FlashAttentionForwardSm100:
877
  defer_sync=True,
878
  )
879
  else:
880
- cpasync_producer_group = pipeline.CooperativeGroup(
881
- pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
882
- )
883
  pipeline_kv = pipeline.PipelineAsyncUmma.create(
884
  barrier_storage=storage.mbar_load_KV.data_ptr(),
885
  num_stages=self.kv_stage,
886
- producer_group=cpasync_producer_group,
887
  consumer_group=mma_warp,
888
  cta_layout_vmnk=cta_layout_vmnk,
889
  defer_sync=True,
@@ -938,7 +935,7 @@ class FlashAttentionForwardSm100:
938
  )
939
  # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats
940
  sm_stats_barrier = pipeline_custom.NamedBarrier(
941
- barrier_id=int(NamedBarrierFwd.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2
942
  )
943
  pipeline_o_epi = None
944
  if const_expr(not self.use_correction_warps_for_epi):
@@ -1019,17 +1016,69 @@ class FlashAttentionForwardSm100:
1019
  window_size_right=window_size_right,
1020
  qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1021
  )
1022
- TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
1023
-
1024
  # Cluster wait before tensor memory alloc
1025
  pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)
1026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1027
  # ///////////////////////////////////////////////////////////////////////////////
1028
- # EMPTY
1029
  # ///////////////////////////////////////////////////////////////////////////////
1030
- for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
1031
- if warp_idx == self.empty_warp_ids[i]:
1032
  cute.arch.setmaxregister_decrease(self.num_regs_other)
 
 
 
 
 
 
 
 
 
 
 
 
1033
 
1034
  # ///////////////////////////////////////////////////////////////////////////////
1035
  # LOAD
@@ -1049,13 +1098,14 @@ class FlashAttentionForwardSm100:
1049
  tma_atom_Q,
1050
  tma_atom_K,
1051
  tma_atom_V,
 
1052
  pipeline_q,
1053
  pipeline_kv,
1054
  block_info,
1055
  num_splits,
1056
  SeqlenInfoCls,
1057
- TileSchedulerCls,
1058
  blocksparse_tensors,
 
1059
  )
1060
 
1061
  # ///////////////////////////////////////////////////////////////////////////////
@@ -1085,8 +1135,8 @@ class FlashAttentionForwardSm100:
1085
  block_info,
1086
  num_splits,
1087
  SeqlenInfoCls,
1088
- TileSchedulerCls,
1089
  blocksparse_tensors,
 
1090
  )
1091
  # Dealloc the tensor memory buffer
1092
  tmem.relinquish_alloc_permit()
@@ -1108,8 +1158,8 @@ class FlashAttentionForwardSm100:
1108
  block_info,
1109
  num_splits,
1110
  SeqlenInfoCls,
1111
- TileSchedulerCls,
1112
  mma_tile_coord_v,
 
1113
  )
1114
 
1115
  # ///////////////////////////////////////////////////////////////////////////////
@@ -1141,11 +1191,11 @@ class FlashAttentionForwardSm100:
1141
  num_splits=num_splits,
1142
  SeqlenInfoCls=SeqlenInfoCls,
1143
  AttentionMaskCls=AttentionMaskCls,
1144
- TileSchedulerCls=TileSchedulerCls,
1145
  aux_tensors=aux_tensors,
1146
  fastdiv_mods=fastdiv_mods,
1147
  head_divmod=head_divmod,
1148
  blocksparse_tensors=blocksparse_tensors,
 
1149
  )
1150
 
1151
  if const_expr(not self.s0_s1_barrier):
@@ -1189,8 +1239,8 @@ class FlashAttentionForwardSm100:
1189
  block_info,
1190
  num_splits,
1191
  SeqlenInfoCls,
1192
- TileSchedulerCls,
1193
  blocksparse_tensors,
 
1194
  )
1195
  tmem_alloc_barrier.arrive()
1196
 
@@ -1208,35 +1258,38 @@ class FlashAttentionForwardSm100:
1208
  sK: cute.Tensor,
1209
  sV: cute.Tensor,
1210
  mPageTable: Optional[cute.Tensor],
1211
- tma_atom_Q: cute.CopyAtom,
1212
  tma_atom_K: Optional[cute.CopyAtom],
1213
  tma_atom_V: Optional[cute.CopyAtom],
 
1214
  pipeline_q: pipeline.PipelineAsync,
1215
  pipeline_kv: pipeline.PipelineAsync,
1216
  block_info: BlockInfo,
1217
  num_splits: Int32,
1218
  SeqlenInfoCls: Callable,
1219
- TileSchedulerCls: Callable,
1220
  blocksparse_tensors: Optional[BlockSparseTensors],
 
1221
  ):
1222
  num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
1223
  tidx = cute.arch.thread_idx()[0] % num_load_threads
1224
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
 
 
 
 
 
 
 
 
1225
  q_producer_phase = Int32(1)
1226
  kv_producer_state = pipeline.make_pipeline_state(
1227
  pipeline.PipelineUserType.Producer, self.kv_stage
1228
  )
1229
- tile_scheduler = TileSchedulerCls()
1230
  work_tile = tile_scheduler.initial_work_tile_info()
1231
  while work_tile.is_valid_tile:
1232
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
1233
  seqlen = SeqlenInfoCls(batch_idx)
1234
  mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
1235
- tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
1236
- gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128)
1237
- gQ = layout_utils.select(
1238
- cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
1239
- ) # (128, 128, 2)
1240
 
1241
  head_idx_kv = (
1242
  head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
@@ -1258,12 +1311,32 @@ class FlashAttentionForwardSm100:
1258
  gV = cute.local_tile(
1259
  mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
1260
  )
1261
- tSgQ = thr_mma_qk.partition_A(gQ)
1262
  tSgK = thr_mma_qk.partition_B(gK)
1263
  tOgV = thr_mma_pv.partition_B(gV)
1264
- load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
1265
- tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
1266
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1267
 
1268
  if const_expr(self.use_tma_KV):
1269
  tKsK, tKgK = cpasync.tma_partition(
@@ -1302,7 +1375,6 @@ class FlashAttentionForwardSm100:
1302
  tKsK, tKgK = None, None
1303
  tVsV, tVgV = None, None
1304
 
1305
- load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
1306
  load_K = partial(
1307
  self.load_KV,
1308
  tma_atom_K,
@@ -1337,24 +1409,19 @@ class FlashAttentionForwardSm100:
1337
  )
1338
  if const_expr(not self.use_tma_KV):
1339
  paged_kv_manager.load_page_table(n_block_first)
1340
- load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0
 
1341
  # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0
1342
- if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
1343
- # load_Q(block=0, stage=0) # Q0
1344
- pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
1345
- # pipeline_q.sync_object_empty.wait(0, q_producer_phase)
1346
- tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(0)
1347
- # tma_bar_ptr = pipeline_kv.producer_get_barrier(kv_producer_state)
1348
- load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr)
1349
- kv_producer_state.advance()
1350
- if const_expr(self.q_stage == 2) and (const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]):
1351
- # load_Q(block=1, stage=1) # Q1
1352
- pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase)
1353
- tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1)
1354
- load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr)
1355
  q_producer_phase ^= 1
1356
- load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0
1357
- kv_producer_state.advance()
 
1358
  for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1359
  n_block = n_block_max - 2 - i
1360
  page_idx = (
@@ -1365,10 +1432,11 @@ class FlashAttentionForwardSm100:
1365
  if const_expr(not self.use_tma_KV):
1366
  paged_kv_manager.load_page_table(n_block)
1367
  # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
1368
- load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki
1369
- kv_producer_state.advance()
1370
- load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi
1371
- kv_producer_state.advance()
 
1372
 
1373
  else:
1374
  kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
@@ -1387,14 +1455,14 @@ class FlashAttentionForwardSm100:
1387
  self.q_subtile_factor if self.q_subtile_factor is not None else 1,
1388
  )
1389
 
1390
- tile_scheduler.prefetch_next_work()
1391
- tile_scheduler.advance_to_next_work()
1392
- work_tile = tile_scheduler.get_current_work()
1393
  # End of persistent scheduler loop
1394
 
1395
- pipeline_kv.producer_tail(kv_producer_state)
1396
- # This is equivalent to pipeline_q.producer_tail
1397
- if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
 
1398
  pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)
1399
 
1400
  @cute.jit
@@ -1417,8 +1485,8 @@ class FlashAttentionForwardSm100:
1417
  block_info: BlockInfo,
1418
  num_splits: Int32,
1419
  SeqlenInfoCls: Callable,
1420
- TileSchedulerCls: Callable,
1421
  blocksparse_tensors: Optional[BlockSparseTensors],
 
1422
  ):
1423
  tSrQ = tiled_mma_qk.make_fragment_A(sQ)
1424
  tSrK = tiled_mma_qk.make_fragment_B(sK)
@@ -1507,7 +1575,6 @@ class FlashAttentionForwardSm100:
1507
  )
1508
  P_full_O_rescaled_phase = Int32(0)
1509
 
1510
- tile_scheduler = TileSchedulerCls()
1511
  work_tile = tile_scheduler.initial_work_tile_info()
1512
  while work_tile.is_valid_tile:
1513
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
@@ -1678,8 +1745,7 @@ class FlashAttentionForwardSm100:
1678
  # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)
1679
 
1680
  # Advance to next tile
1681
- tile_scheduler.advance_to_next_work()
1682
- work_tile = tile_scheduler.get_current_work()
1683
  # End of persistent scheduler loop
1684
 
1685
  # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end
@@ -1708,11 +1774,11 @@ class FlashAttentionForwardSm100:
1708
  num_splits: Int32,
1709
  SeqlenInfoCls: Callable,
1710
  AttentionMaskCls: Callable,
1711
- TileSchedulerCls: Callable,
1712
  aux_tensors: Optional[list] = None,
1713
  fastdiv_mods=(None, None),
1714
  head_divmod=None,
1715
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
 
1716
  ):
1717
  """Compute softmax on attention scores from QK matrix multiplication.
1718
 
@@ -1772,7 +1838,6 @@ class FlashAttentionForwardSm100:
1772
 
1773
  warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
1774
 
1775
- tile_scheduler = TileSchedulerCls()
1776
  work_tile = tile_scheduler.initial_work_tile_info()
1777
  while work_tile.is_valid_tile:
1778
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
@@ -2015,8 +2080,7 @@ class FlashAttentionForwardSm100:
2015
  # gLSE[tidx] = lse
2016
 
2017
  # Advance to next tile
2018
- tile_scheduler.advance_to_next_work()
2019
- work_tile = tile_scheduler.get_current_work()
2020
  # End of persistent scheduler loop
2021
 
2022
  # This is equivalent to pipeline_sm_stats.producer_tail
@@ -2186,8 +2250,8 @@ class FlashAttentionForwardSm100:
2186
  block_info: BlockInfo,
2187
  num_splits: Int32,
2188
  SeqlenInfoCls: Callable,
2189
- TileSchedulerCls: Callable,
2190
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
 
2191
  ):
2192
  tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
2193
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
@@ -2217,7 +2281,6 @@ class FlashAttentionForwardSm100:
2217
  o_corr_consumer_phase = Int32(0)
2218
  corr_epi_producer_phase = Int32(1)
2219
 
2220
- tile_scheduler = TileSchedulerCls()
2221
  work_tile = tile_scheduler.initial_work_tile_info()
2222
  while work_tile.is_valid_tile:
2223
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
@@ -2228,12 +2291,14 @@ class FlashAttentionForwardSm100:
2228
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2229
  else:
2230
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2231
- tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
2232
- gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
2233
- gO = layout_utils.select(
2234
- cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
2235
- ) # (128, 128, 2)
2236
- gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
 
 
2237
 
2238
  # Default LSE to -inf for invalid split_idx tiles
2239
  stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
@@ -2334,6 +2399,7 @@ class FlashAttentionForwardSm100:
2334
  pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
2335
  if const_expr(not self.use_correction_warps_for_epi):
2336
  pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
 
2337
  self.correction_epilogue(
2338
  thr_mma_pv,
2339
  tOtO[None, None, None, stage],
@@ -2344,7 +2410,7 @@ class FlashAttentionForwardSm100:
2344
  scale,
2345
  sO[None, None, stage],
2346
  mO_cur,
2347
- gO[None, None, stage],
2348
  gmem_tiled_copy_O,
2349
  )
2350
  # Signal for the next work tile that O buffers in tmem are already read, so
@@ -2414,7 +2480,6 @@ class FlashAttentionForwardSm100:
2414
  mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
2415
  for stage in cutlass.range_constexpr(self.q_stage):
2416
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
2417
- gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
2418
  row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
2419
  # if tidx == 0 and stage <= 1:
2420
  # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
@@ -2429,13 +2494,24 @@ class FlashAttentionForwardSm100:
2429
  if const_expr(not self.pack_gqa)
2430
  else seqlen.seqlen_q * self.qhead_per_kvhead
2431
  )
2432
- if tidx < seqlen_q - m_tile_idx * self.m_block_size:
2433
- # This actually just works with PackGQA too
2434
- gLSE[tidx] = lse
 
 
 
 
 
 
 
 
 
 
 
 
2435
 
2436
  # Advance to next tile
2437
- tile_scheduler.advance_to_next_work()
2438
- work_tile = tile_scheduler.get_current_work()
2439
  # End of persistent scheduler loop
2440
 
2441
  # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps
@@ -2574,7 +2650,7 @@ class FlashAttentionForwardSm100:
2574
  if const_expr(self.use_correction_warps_for_epi):
2575
  assert(not self.use_tma_O)
2576
  assert(gmem_tiled_copy_O is not None)
2577
- cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue),
2578
  number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
2579
  mma_tile_coord_v = thr_mma.thr_idx
2580
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
@@ -2586,7 +2662,7 @@ class FlashAttentionForwardSm100:
2586
  def _store_O_to_gmem(
2587
  self,
2588
  sO_stage: cute.Tensor,
2589
- gO: cute.Tensor,
2590
  mO_cur: cute.Tensor,
2591
  gmem_tiled_copy_O: cute.TiledCopy,
2592
  tidx: Int32,
@@ -2597,7 +2673,6 @@ class FlashAttentionForwardSm100:
2597
  gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
2598
  tOsO = gmem_thr_copy_O.partition_S(sO_stage)
2599
  cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
2600
- tOgO = gmem_thr_copy_O.partition_D(gO)
2601
  tOcO = gmem_thr_copy_O.partition_S(cO)
2602
  t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
2603
  tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
@@ -2613,6 +2688,8 @@ class FlashAttentionForwardSm100:
2613
  cute.autovec_copy(tOsO, tOrO)
2614
  # copy acc O from rmem to gmem
2615
  if const_expr(not self.pack_gqa):
 
 
2616
  for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
2617
  if (
2618
  t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
@@ -2641,11 +2718,10 @@ class FlashAttentionForwardSm100:
2641
  block_info: BlockInfo,
2642
  num_splits: int,
2643
  SeqlenInfoCls: Callable,
2644
- TileSchedulerCls: Callable,
2645
  mma_tile_coord_v: Int32 = 0,
 
2646
  ):
2647
  epi_consumer_phase = Int32(0)
2648
- tile_scheduler = TileSchedulerCls()
2649
  work_tile = tile_scheduler.initial_work_tile_info()
2650
  while work_tile.is_valid_tile:
2651
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
@@ -2657,12 +2733,14 @@ class FlashAttentionForwardSm100:
2657
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2658
  else:
2659
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2660
- tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
2661
- gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
2662
- gO = layout_utils.select(
2663
- cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
2664
- ) # (128, 128, 2)
2665
- gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
 
 
2666
 
2667
  if const_expr(self.use_tma_O):
2668
  store_O, _, _ = copy_utils.tma_get_copy_fn(
@@ -2689,8 +2767,9 @@ class FlashAttentionForwardSm100:
2689
  pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
2690
  # 2. copy O0 / O1 to gmem
2691
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
 
2692
  self._store_O_to_gmem(
2693
- sO[None, None, stage], gO[None, None, stage], mO_cur, gmem_tiled_copy_O,
2694
  tidx, seqlen.seqlen_q, m_tile_idx,
2695
  )
2696
  pipeline_o_epi.consumer_release_w_index(stage)
@@ -2698,8 +2777,39 @@ class FlashAttentionForwardSm100:
2698
  epi_consumer_phase ^= 1
2699
 
2700
  # Advance to next tile
2701
- tile_scheduler.advance_to_next_work()
2702
- work_tile = tile_scheduler.get_current_work()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2703
 
2704
  def load_Q(
2705
  self,
@@ -2712,6 +2822,39 @@ class FlashAttentionForwardSm100:
2712
  pipeline_q.producer_acquire_w_index_phase(stage, phase)
2713
  load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))
2714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2715
  @cute.jit
2716
  def load_KV(
2717
  self,
@@ -2754,7 +2897,10 @@ class FlashAttentionForwardSm100:
2754
  else:
2755
  assert paged_kv_manager is not None
2756
  assert extra_tx_count is None
2757
- paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V)
 
 
 
2758
  cute.arch.cp_async_commit_group()
2759
  pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)
2760
 
@@ -2765,6 +2911,9 @@ class FlashAttentionForwardSm100:
2765
  # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
2766
  # phase == 0, or left by offset if phase == 1.
2767
  offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
 
 
 
2768
  return cute.make_tensor(sX.iterator + offset, sX.layout)
2769
  else:
2770
  return sX
@@ -2774,12 +2923,12 @@ class FlashAttentionForwardSm100:
2774
  # warp_group_idx = utils.canonical_warp_group_idx(sync=False)
2775
  # if warp_group_idx == 0:
2776
  # cute.arch.barrier_arrive(
2777
- # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128,
2778
  # )
2779
 
2780
  # def warp_scheduler_barrier_sync(self):
2781
  # cute.arch.barrier(
2782
- # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
2783
  # number_of_threads=2 * 128
2784
  # )
2785
 
@@ -2787,7 +2936,7 @@ class FlashAttentionForwardSm100:
2787
  # cur_wg = utils.canonical_warp_group_idx(sync=False)
2788
  # next_wg = 1 - cur_wg
2789
  # cute.arch.barrier_arrive(
2790
- # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
2791
  # )
2792
 
2793
  @cute.jit
 
13
  # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
14
  # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py
15
 
 
16
  import math
17
+ from typing import Tuple, Callable, Optional, Literal
18
  from functools import partial
19
 
20
  import cuda.bindings.driver as cuda
 
27
  import cutlass.utils.blackwell_helpers as sm100_utils_basic
28
  from cutlass import pipeline
29
  from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
30
+ from cutlass.utils import ClcDynamicPersistentTileScheduler
31
  from cutlass.base_dsl.arch import Arch
32
  from cutlass.cutlass_dsl import BaseDSL
33
 
 
35
 
36
  from .paged_kv import PagedKVManager
37
  from .cute_dsl_utils import assume_tensor_aligned
38
+ from . import utils
39
  from . import pipeline as pipeline_custom
40
+ import cutlass.pipeline as cutlass_pipeline
41
  from .mask import AttentionMask
42
  from .softmax import SoftmaxSm100, apply_score_mod_inner
43
  from .seqlen_info import SeqlenInfoQK
 
49
  softmax_block_sparse_sm100,
50
  handle_block_sparse_empty_tile_correction_sm100,
51
  )
52
+ from .pack_gqa import PackGQA, pack_gqa_layout
53
  from . import mma_sm100_desc as sm100_desc
54
  from . import blackwell_helpers as sm100_utils
55
+ from .named_barrier import NamedBarrierFwdSm100
56
  from cutlass.cute import FastDivmodDivisor
57
  from .quack.cute_dsl_utils import ParamsBase
58
  from .tile_scheduler import (
59
+ ClcState,
60
+ SchedulingMode,
61
  TileSchedulerArguments,
62
+ TileSchedulerProtocol,
63
  SingleTileScheduler,
64
  StaticPersistentTileScheduler,
65
  SingleTileLPTScheduler,
66
  SingleTileVarlenScheduler,
67
  )
68
+ from .fa_logging import fa_log, fa_printf
69
+ from .utils import smid
70
+
71
+ # === TUNING KNOBS (agent-editable) ===
72
+ # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool)
73
+ # Values:
74
+ # ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation).
75
+ # SM103 has fast native exp2, so set freq=0 there.
76
+ # ex2_emu_start_frg: int — fragment index to start emulation from
77
+ # num_regs_softmax: int — register count for softmax warps (multiple of 8)
78
+ # num_regs_correction: int — register count for correction warps (multiple of 8)
79
+ # num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction
80
+ _TUNING_CONFIG = {
81
+ (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88},
82
+ (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72},
83
+ (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80},
84
+ (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72},
85
+ (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80},
86
+ (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64},
87
+ (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64},
88
+ (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72},
89
+ }
90
+ # === END TUNING KNOBS ===
91
 
92
 
93
  class FlashAttentionForwardSm100:
 
113
  paged_kv_non_tma: bool = False,
114
  is_varlen_q: bool = False,
115
  use_2cta_instrs: bool = False,
116
+ use_clc_scheduler: bool = False,
117
  ):
118
  self.use_tma_KV = not paged_kv_non_tma
119
  # self.dtype = dtype
 
160
  self.is_split_kv = is_split_kv
161
  self.pack_gqa = pack_gqa
162
  self.q_subtile_factor = q_subtile_factor
 
 
 
 
163
  assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (
164
  "SplitKV is not supported for hdim >= 192"
165
  )
 
171
  # Does S1 need to wait for S0 to finish
172
  # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
173
  is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
174
+ self.is_sm103 = is_sm103
175
+ # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic
176
+ _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
177
+ self.enable_ex2_emu = _default_enable_ex2_emu
178
  self.s0_s1_barrier = False
179
  self.overlap_sO_sQ = (
180
  (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
 
187
  "Paged KV does not support irregular head dim"
188
  )
189
 
190
+ self.use_clc_scheduler = (
191
+ use_clc_scheduler
192
+ and self.use_tma_KV
193
+ and not self.overlap_sO_sQ
194
+ )
195
+ self.sched_stages = 1
196
+ if self.use_clc_scheduler:
197
+ assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}"
198
+ assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}"
199
+ assert self.cluster_shape_mn[0] == self.cta_group_size, (
200
+ f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}"
201
+ )
202
+
203
+ self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC
204
+
205
+ if is_varlen_q:
206
+ self.TileScheduler = SingleTileVarlenScheduler
207
+ elif self.is_causal or self.is_local or self.use_clc_scheduler:
208
+ self.TileScheduler = SingleTileLPTScheduler
209
+ elif self.is_persistent:
210
+ self.TileScheduler = StaticPersistentTileScheduler
211
+ else:
212
+ self.TileScheduler = SingleTileScheduler
213
+
214
+ fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}")
215
+
216
  self.softmax0_warp_ids = (0, 1, 2, 3)
217
  self.softmax1_warp_ids = (4, 5, 6, 7)
218
  self.correction_warp_ids = (8, 9, 10, 11)
 
234
  )
235
  )
236
 
237
+ self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
238
+
239
  if self.q_stage == 1:
240
+ if not self.use_tma_KV or not self.use_tma_Q:
241
  self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids
242
  self.load_warp_ids = self.softmax1_warp_ids
243
  else:
 
253
  elif self.is_varlen_q: # fallback
254
  self.epilogue_warp_ids = (13, 14)
255
 
256
+ self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None
257
+
258
  self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
259
  self.tmem_o_offset = [
260
  self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
 
270
  # vec buffer for row_max & row_sum
271
  self.tmem_vec_offset = self.tmem_s_offset
272
 
273
+ # Look up tuning config for register counts and ex2_emu params
274
+ _tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103)
275
+ self._tune = _TUNING_CONFIG.get(_tune_key, {})
276
+ if "ex2_emu_freq" in self._tune:
277
+ self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0
278
  if self.head_dim_padded < 96:
279
  self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
280
  self.num_regs_correction = 64
281
  self.num_regs_other = 48 if not paged_kv_non_tma else 80
282
  else:
283
+ if not paged_kv_non_tma and "num_regs_softmax" in self._tune:
284
+ self.num_regs_softmax = self._tune["num_regs_softmax"]
285
+ self.num_regs_correction = self._tune["num_regs_correction"]
286
+ elif not paged_kv_non_tma:
287
+ self.num_regs_softmax = 192
288
+ self.num_regs_correction = 80
289
  else:
290
+ self.num_regs_softmax = 184
291
+ self.num_regs_correction = 64
292
+ self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  self.buffer_align_bytes = 1024
295
 
 
327
  self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3
328
  )
329
  self.uneven_kv_smem_offset = (
330
+ self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2
331
  if self.uneven_kv_smem
332
  else 0
333
  )
 
342
  mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
343
  mLSE: Optional[cute.Tensor],
344
  softmax_scale: Float32,
 
345
  mCuSeqlensQ: Optional[cute.Tensor] = None,
346
  mCuSeqlensK: Optional[cute.Tensor] = None,
347
  mSeqUsedQ: Optional[cute.Tensor] = None,
 
352
  learnable_sink: Optional[cute.Tensor] = None,
353
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
354
  aux_tensors: Optional[list] = None,
355
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
356
+ stream: cuda.CUstream = None,
357
  ):
358
  """Execute the Fused Multi-Head Attention operation on the provided tensors.
359
 
 
406
  if const_expr(self.q_dtype != self.v_dtype):
407
  raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
408
  self._setup_attributes()
409
+ self.use_tma_O = (
410
+ self.arch >= Arch.sm_90
411
+ and mCuSeqlensQ is None
412
+ and mSeqUsedQ is None
413
+ and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
414
+ and not (self.pack_gqa and self.is_split_kv)
415
+ )
416
  self.ex2_emu_freq = 0
417
+ self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1)
 
418
  if const_expr(self.enable_ex2_emu):
419
+ self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16)
 
 
420
  if const_expr(
421
  self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
422
  ):
423
+ self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10)
 
 
424
 
425
  cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
426
  q_major_mode = tcgen05.OperandMajorMode.K
 
500
  )
501
 
502
  if const_expr(self.pack_gqa):
503
+ nheads_kv = mK.shape[2]
504
+ mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
505
+ mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  if const_expr(mLSE is not None):
507
+ mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
  self.tma_copy_bytes = {
510
  name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
 
521
  tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
522
  tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
523
 
524
+ if const_expr(self.use_tma_Q):
525
+ tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
526
+ tma_load_op,
527
+ mQ,
528
+ cute.select(sQ_layout, mode=[0, 1, 2]),
529
+ self.mma_tiler_qk,
530
+ tiled_mma_qk,
531
+ cta_layout_vmnk.shape,
532
+ )
533
+ gmem_tiled_copy_Q = None
534
+ else:
535
+ tma_atom_Q = None
536
+ async_copy_elems = 128 // self.q_dtype.width
537
+ num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
538
+ threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads)
539
+ gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
540
+ self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
541
+ )
542
 
543
  tma_atom_K = None
544
  tma_atom_V = None
 
587
  vO_layout = cute.make_layout((1, async_copy_elems))
588
  gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
589
 
590
+ TileScheduler = self.TileScheduler
591
+ _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1)
 
 
 
 
 
 
 
 
 
592
  tile_sched_args = TileSchedulerArguments(
593
+ cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor),
594
  cute.size(mQ.shape[2]),
595
  cute.size(mQ.shape[3])
596
  if const_expr(mCuSeqlensQ is None)
 
613
  lpt=self.is_causal or self.is_local,
614
  is_split_kv=self.is_split_kv,
615
  cluster_shape_mn=self.cluster_shape_mn,
616
+ use_cluster_idx=not self.is_persistent and self.cta_group_size > 1,
617
+ )
618
+ tile_sched_params = TileScheduler.to_underlying_arguments(
619
+ tile_sched_args, scheduling_mode=self.scheduling_mode
620
  )
 
621
  self.tile_scheduler_cls = TileScheduler
622
  grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
623
 
 
627
  cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)
628
  )
629
 
630
+ clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0
631
+ clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0
632
+
633
  @cute.struct
634
  class SharedStorage:
635
  # m_barriers for pipelines
 
649
  # Smem tensors
650
  # store row max and row sum
651
  sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
652
+ # CLC buffers placed here to utilize padding before sO's 1024-byte alignment.
653
+ # This avoids adding bytes at the end when we're at the smem limit.
654
+ # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty).
655
+ clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size]
656
+ # CLC response storage (16 bytes per stage, stored as 4 Int32s).
657
+ clc_response: cute.struct.MemRange[Int32, clc_response_size]
658
+ # Large TMA buffers with 1024-byte alignment
659
  sO: cute.struct.Align[
660
  cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes
661
  ]
 
670
 
671
  self.shared_storage = SharedStorage
672
 
673
+ softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
674
+ window_size_left = Int32(window_size_left) if window_size_left is not None else None
675
+ window_size_right = Int32(window_size_right) if window_size_right is not None else None
676
+ fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  head_divmod = None
679
  if cutlass.const_expr(self.pack_gqa):
 
710
  tP_layout,
711
  sV_layout,
712
  sO_layout,
713
+ gmem_tiled_copy_Q,
714
  gmem_tiled_copy_O,
715
  tiled_mma_qk,
716
  tiled_mma_pv,
 
741
  mSeqUsedQ: Optional[cute.Tensor],
742
  mSeqUsedK: Optional[cute.Tensor],
743
  mPageTable: Optional[cute.Tensor],
744
+ tma_atom_Q: Optional[cute.CopyAtom],
745
  tma_atom_K: Optional[cute.CopyAtom],
746
  tma_atom_V: Optional[cute.CopyAtom],
747
  tma_atom_O: Optional[cute.CopyAtom],
 
756
  tP_layout: cute.ComposedLayout,
757
  sV_layout: cute.ComposedLayout,
758
  sO_layout: cute.ComposedLayout,
759
+ gmem_tiled_copy_Q: Optional[cute.TiledCopy],
760
  gmem_tiled_copy_O: Optional[cute.TiledCopy],
761
  tiled_mma_qk: cute.TiledMma,
762
  tiled_mma_pv: cute.TiledMma,
 
804
  storage = smem.allocate(self.shared_storage)
805
 
806
  tmem_alloc_barrier = pipeline.NamedBarrier(
807
+ barrier_id=int(NamedBarrierFwdSm100.TmemPtr),
808
  num_threads=cute.arch.WARP_SIZE * len(
809
  (self.mma_warp_id,
810
  *self.softmax0_warp_ids,
 
823
 
824
  ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
825
  mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))
 
826
  tma_warp = ThreadCooperativeGroup(1)
827
+ load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE)
828
  softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))
829
  softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))
830
  # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
 
847
  softmax_correction_threads_cluster = ThreadCooperativeGroup(
848
  cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size
849
  )
850
+ if const_expr(self.use_tma_Q):
851
+ pipeline_q = pipeline_custom.PipelineTmaUmma.create(
852
+ barrier_storage=storage.mbar_load_Q.data_ptr(),
853
+ num_stages=self.q_stage,
854
+ producer_group=tma_warp,
855
+ consumer_group=mma_warp,
856
+ tx_count=self.tma_copy_bytes["Q"],
857
+ cta_layout_vmnk=cta_layout_vmnk,
858
+ defer_sync=True,
859
+ )
860
+ else:
861
+ pipeline_q = pipeline_custom.PipelineAsyncUmma.create(
862
+ barrier_storage=storage.mbar_load_Q.data_ptr(),
863
+ num_stages=self.q_stage,
864
+ producer_group=load_threads,
865
+ consumer_group=mma_warp,
866
+ cta_layout_vmnk=cta_layout_vmnk,
867
+ defer_sync=True,
868
+ )
869
  if const_expr(self.use_tma_KV):
870
  pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
871
  barrier_storage=storage.mbar_load_KV.data_ptr(),
 
877
  defer_sync=True,
878
  )
879
  else:
 
 
 
880
  pipeline_kv = pipeline.PipelineAsyncUmma.create(
881
  barrier_storage=storage.mbar_load_KV.data_ptr(),
882
  num_stages=self.kv_stage,
883
+ producer_group=load_threads,
884
  consumer_group=mma_warp,
885
  cta_layout_vmnk=cta_layout_vmnk,
886
  defer_sync=True,
 
935
  )
936
  # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats
937
  sm_stats_barrier = pipeline_custom.NamedBarrier(
938
+ barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2
939
  )
940
  pipeline_o_epi = None
941
  if const_expr(not self.use_correction_warps_for_epi):
 
1016
  window_size_right=window_size_right,
1017
  qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1018
  )
 
 
1019
  # Cluster wait before tensor memory alloc
1020
  pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)
1021
 
1022
+ if const_expr(self.use_clc_scheduler):
1023
+ clc_response_ptr = storage.clc_response.data_ptr()
1024
+ clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr()
1025
+
1026
+ clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup(
1027
+ cutlass_pipeline.Agent.Thread
1028
+ )
1029
+ num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE
1030
+ # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume
1031
+ num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size
1032
+ clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup(
1033
+ cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps
1034
+ )
1035
+
1036
+ block_idx = cute.arch.block_idx()
1037
+ clc = ClcState.create(
1038
+ hw_scheduler=ClcDynamicPersistentTileScheduler.create(
1039
+ self.tile_scheduler_cls.clc_problem_shape(tile_sched_params),
1040
+ block_idx,
1041
+ cute.arch.grid_dim(),
1042
+ clc_response_ptr,
1043
+ ),
1044
+ pipeline=cutlass_pipeline.PipelineClcFetchAsync.create(
1045
+ barrier_storage=clc_mbar_ptr,
1046
+ num_stages=self.sched_stages,
1047
+ producer_group=clc_pipeline_producer_group,
1048
+ consumer_group=clc_pipeline_consumer_group,
1049
+ tx_count=16,
1050
+ cta_layout_vmnk=cta_layout_vmnk,
1051
+ ),
1052
+ consumer_state=cutlass_pipeline.make_pipeline_state(
1053
+ cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages
1054
+ ),
1055
+ producer_state=cutlass_pipeline.make_pipeline_state(
1056
+ cutlass_pipeline.PipelineUserType.Producer, self.sched_stages
1057
+ ),
1058
+ )
1059
+ tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc)
1060
+ else:
1061
+ tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params)
1062
+ assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}"
1063
+
1064
  # ///////////////////////////////////////////////////////////////////////////////
1065
+ # EMPTY / CLC SCHEDULER WARP
1066
  # ///////////////////////////////////////////////////////////////////////////////
1067
+ if const_expr(self.use_clc_scheduler):
1068
+ if warp_idx == self.clc_scheduler_warp_id:
1069
  cute.arch.setmaxregister_decrease(self.num_regs_other)
1070
+ if is_leader_cta:
1071
+ self.clc_scheduler_warp(tile_scheduler)
1072
+ else:
1073
+ self.empty_warp(tile_scheduler)
1074
+ for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
1075
+ if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id:
1076
+ cute.arch.setmaxregister_decrease(self.num_regs_other)
1077
+ self.empty_warp(tile_scheduler)
1078
+ else:
1079
+ for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
1080
+ if warp_idx == self.empty_warp_ids[i]:
1081
+ cute.arch.setmaxregister_decrease(self.num_regs_other)
1082
 
1083
  # ///////////////////////////////////////////////////////////////////////////////
1084
  # LOAD
 
1098
  tma_atom_Q,
1099
  tma_atom_K,
1100
  tma_atom_V,
1101
+ gmem_tiled_copy_Q,
1102
  pipeline_q,
1103
  pipeline_kv,
1104
  block_info,
1105
  num_splits,
1106
  SeqlenInfoCls,
 
1107
  blocksparse_tensors,
1108
+ tile_scheduler=tile_scheduler,
1109
  )
1110
 
1111
  # ///////////////////////////////////////////////////////////////////////////////
 
1135
  block_info,
1136
  num_splits,
1137
  SeqlenInfoCls,
 
1138
  blocksparse_tensors,
1139
+ tile_scheduler=tile_scheduler,
1140
  )
1141
  # Dealloc the tensor memory buffer
1142
  tmem.relinquish_alloc_permit()
 
1158
  block_info,
1159
  num_splits,
1160
  SeqlenInfoCls,
 
1161
  mma_tile_coord_v,
1162
+ tile_scheduler=tile_scheduler,
1163
  )
1164
 
1165
  # ///////////////////////////////////////////////////////////////////////////////
 
1191
  num_splits=num_splits,
1192
  SeqlenInfoCls=SeqlenInfoCls,
1193
  AttentionMaskCls=AttentionMaskCls,
 
1194
  aux_tensors=aux_tensors,
1195
  fastdiv_mods=fastdiv_mods,
1196
  head_divmod=head_divmod,
1197
  blocksparse_tensors=blocksparse_tensors,
1198
+ tile_scheduler=tile_scheduler,
1199
  )
1200
 
1201
  if const_expr(not self.s0_s1_barrier):
 
1239
  block_info,
1240
  num_splits,
1241
  SeqlenInfoCls,
 
1242
  blocksparse_tensors,
1243
+ tile_scheduler=tile_scheduler,
1244
  )
1245
  tmem_alloc_barrier.arrive()
1246
 
 
1258
  sK: cute.Tensor,
1259
  sV: cute.Tensor,
1260
  mPageTable: Optional[cute.Tensor],
1261
+ tma_atom_Q: Optional[cute.CopyAtom],
1262
  tma_atom_K: Optional[cute.CopyAtom],
1263
  tma_atom_V: Optional[cute.CopyAtom],
1264
+ gmem_tiled_copy_Q: Optional[cute.TiledCopy],
1265
  pipeline_q: pipeline.PipelineAsync,
1266
  pipeline_kv: pipeline.PipelineAsync,
1267
  block_info: BlockInfo,
1268
  num_splits: Int32,
1269
  SeqlenInfoCls: Callable,
 
1270
  blocksparse_tensors: Optional[BlockSparseTensors],
1271
+ tile_scheduler: TileSchedulerProtocol,
1272
  ):
1273
  num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
1274
  tidx = cute.arch.thread_idx()[0] % num_load_threads
1275
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
1276
+ issue_kv_for_this_warp = (
1277
+ const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or
1278
+ warp_idx == self.load_warp_ids[0]
1279
+ )
1280
+ issue_q_for_this_warp = (
1281
+ const_expr(not self.use_tma_Q or len(self.load_warp_ids) == 1) or
1282
+ warp_idx == self.load_warp_ids[0]
1283
+ )
1284
  q_producer_phase = Int32(1)
1285
  kv_producer_state = pipeline.make_pipeline_state(
1286
  pipeline.PipelineUserType.Producer, self.kv_stage
1287
  )
 
1288
  work_tile = tile_scheduler.initial_work_tile_info()
1289
  while work_tile.is_valid_tile:
1290
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
1291
  seqlen = SeqlenInfoCls(batch_idx)
1292
  mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
 
 
 
 
 
1293
 
1294
  head_idx_kv = (
1295
  head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
 
1311
  gV = cute.local_tile(
1312
  mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
1313
  )
 
1314
  tSgK = thr_mma_qk.partition_B(gK)
1315
  tOgV = thr_mma_pv.partition_B(gV)
1316
+ if const_expr(self.use_tma_Q):
1317
+ tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
1318
+ gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128)
1319
+ gQ = layout_utils.select(
1320
+ cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
1321
+ ) # (128, 128, 2)
1322
+ tSgQ = thr_mma_qk.partition_A(gQ)
1323
+ load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
1324
+ tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
1325
+ )
1326
+ load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
1327
+ else:
1328
+ assert gmem_tiled_copy_Q is not None
1329
+ load_Q = partial(
1330
+ self.load_Q_non_tma,
1331
+ mQ_cur,
1332
+ sQ,
1333
+ gmem_tiled_copy_Q,
1334
+ pipeline_q,
1335
+ tidx,
1336
+ seqlen.seqlen_q,
1337
+ m_block,
1338
+ phase=q_producer_phase,
1339
+ )
1340
 
1341
  if const_expr(self.use_tma_KV):
1342
  tKsK, tKgK = cpasync.tma_partition(
 
1375
  tKsK, tKgK = None, None
1376
  tVsV, tVgV = None, None
1377
 
 
1378
  load_K = partial(
1379
  self.load_KV,
1380
  tma_atom_K,
 
1409
  )
1410
  if const_expr(not self.use_tma_KV):
1411
  paged_kv_manager.load_page_table(n_block_first)
1412
+ if issue_kv_for_this_warp:
1413
+ load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0
1414
  # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0
1415
+ if issue_q_for_this_warp:
1416
+ load_Q(block=0, stage=0)
1417
+ if issue_kv_for_this_warp:
1418
+ kv_producer_state.advance()
1419
+ if const_expr(self.q_stage == 2) and issue_q_for_this_warp:
1420
+ load_Q(block=1, stage=1)
 
 
 
 
 
 
 
1421
  q_producer_phase ^= 1
1422
+ if issue_kv_for_this_warp:
1423
+ load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0
1424
+ kv_producer_state.advance()
1425
  for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
1426
  n_block = n_block_max - 2 - i
1427
  page_idx = (
 
1432
  if const_expr(not self.use_tma_KV):
1433
  paged_kv_manager.load_page_table(n_block)
1434
  # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
1435
+ if issue_kv_for_this_warp:
1436
+ load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki
1437
+ kv_producer_state.advance()
1438
+ load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi
1439
+ kv_producer_state.advance()
1440
 
1441
  else:
1442
  kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
 
1455
  self.q_subtile_factor if self.q_subtile_factor is not None else 1,
1456
  )
1457
 
1458
+
1459
+ work_tile = tile_scheduler.advance_to_next_work()
 
1460
  # End of persistent scheduler loop
1461
 
1462
+ if issue_kv_for_this_warp:
1463
+ pipeline_kv.producer_tail(kv_producer_state)
1464
+ # This is equivalent to pipeline_q.producer_tail for the TMA-Q producer warp.
1465
+ if issue_q_for_this_warp:
1466
  pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)
1467
 
1468
  @cute.jit
 
1485
  block_info: BlockInfo,
1486
  num_splits: Int32,
1487
  SeqlenInfoCls: Callable,
 
1488
  blocksparse_tensors: Optional[BlockSparseTensors],
1489
+ tile_scheduler=None,
1490
  ):
1491
  tSrQ = tiled_mma_qk.make_fragment_A(sQ)
1492
  tSrK = tiled_mma_qk.make_fragment_B(sK)
 
1575
  )
1576
  P_full_O_rescaled_phase = Int32(0)
1577
 
 
1578
  work_tile = tile_scheduler.initial_work_tile_info()
1579
  while work_tile.is_valid_tile:
1580
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
 
1745
  # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)
1746
 
1747
  # Advance to next tile
1748
+ work_tile = tile_scheduler.advance_to_next_work()
 
1749
  # End of persistent scheduler loop
1750
 
1751
  # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end
 
1774
  num_splits: Int32,
1775
  SeqlenInfoCls: Callable,
1776
  AttentionMaskCls: Callable,
 
1777
  aux_tensors: Optional[list] = None,
1778
  fastdiv_mods=(None, None),
1779
  head_divmod=None,
1780
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
1781
+ tile_scheduler=None,
1782
  ):
1783
  """Compute softmax on attention scores from QK matrix multiplication.
1784
 
 
1838
 
1839
  warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
1840
 
 
1841
  work_tile = tile_scheduler.initial_work_tile_info()
1842
  while work_tile.is_valid_tile:
1843
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
 
2080
  # gLSE[tidx] = lse
2081
 
2082
  # Advance to next tile
2083
+ work_tile = tile_scheduler.advance_to_next_work()
 
2084
  # End of persistent scheduler loop
2085
 
2086
  # This is equivalent to pipeline_sm_stats.producer_tail
 
2250
  block_info: BlockInfo,
2251
  num_splits: Int32,
2252
  SeqlenInfoCls: Callable,
 
2253
  blocksparse_tensors: Optional[BlockSparseTensors] = None,
2254
+ tile_scheduler=None,
2255
  ):
2256
  tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))
2257
  warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
 
2281
  o_corr_consumer_phase = Int32(0)
2282
  corr_epi_producer_phase = Int32(1)
2283
 
 
2284
  work_tile = tile_scheduler.initial_work_tile_info()
2285
  while work_tile.is_valid_tile:
2286
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
 
2291
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2292
  else:
2293
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2294
+ gO = None
2295
+ if const_expr(self.use_tma_O or not self.pack_gqa):
2296
+ tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
2297
+ gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
2298
+ gO = layout_utils.select(
2299
+ cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
2300
+ ) # (128, 128, 2)
2301
+ gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
2302
 
2303
  # Default LSE to -inf for invalid split_idx tiles
2304
  stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
 
2399
  pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)
2400
  if const_expr(not self.use_correction_warps_for_epi):
2401
  pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
2402
+ gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
2403
  self.correction_epilogue(
2404
  thr_mma_pv,
2405
  tOtO[None, None, None, stage],
 
2410
  scale,
2411
  sO[None, None, stage],
2412
  mO_cur,
2413
+ gO_stage,
2414
  gmem_tiled_copy_O,
2415
  )
2416
  # Signal for the next work tile that O buffers in tmem are already read, so
 
2480
  mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])
2481
  for stage in cutlass.range_constexpr(self.q_stage):
2482
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
 
2483
  row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]
2484
  # if tidx == 0 and stage <= 1:
2485
  # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
 
2494
  if const_expr(not self.pack_gqa)
2495
  else seqlen.seqlen_q * self.qhead_per_kvhead
2496
  )
2497
+ if const_expr(not self.pack_gqa or self.m_block_size % self.qhead_per_kvhead == 0):
2498
+ gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,))
2499
+ if tidx < seqlen_q - m_tile_idx * self.m_block_size:
2500
+ # This actually just works with PackGQA too
2501
+ gLSE[tidx] = lse
2502
+ else:
2503
+ idx = m_tile_idx * self.m_block_size + tidx
2504
+ if idx < seqlen_q:
2505
+ m_idx = idx // self.qhead_per_kvhead
2506
+ h_idx = idx - m_idx * self.qhead_per_kvhead
2507
+ lse_ptr_i64 = utils.elem_pointer(mLSE_cur, ((h_idx, m_idx),)).toint()
2508
+ lse_gmem_ptr = cute.make_ptr(
2509
+ mLSE_cur.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4
2510
+ )
2511
+ cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse
2512
 
2513
  # Advance to next tile
2514
+ work_tile = tile_scheduler.advance_to_next_work()
 
2515
  # End of persistent scheduler loop
2516
 
2517
  # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps
 
2650
  if const_expr(self.use_correction_warps_for_epi):
2651
  assert(not self.use_tma_O)
2652
  assert(gmem_tiled_copy_O is not None)
2653
+ cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue),
2654
  number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
2655
  mma_tile_coord_v = thr_mma.thr_idx
2656
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
 
2662
  def _store_O_to_gmem(
2663
  self,
2664
  sO_stage: cute.Tensor,
2665
+ gO: Optional[cute.Tensor],
2666
  mO_cur: cute.Tensor,
2667
  gmem_tiled_copy_O: cute.TiledCopy,
2668
  tidx: Int32,
 
2673
  gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
2674
  tOsO = gmem_thr_copy_O.partition_S(sO_stage)
2675
  cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
 
2676
  tOcO = gmem_thr_copy_O.partition_S(cO)
2677
  t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
2678
  tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
 
2688
  cute.autovec_copy(tOsO, tOrO)
2689
  # copy acc O from rmem to gmem
2690
  if const_expr(not self.pack_gqa):
2691
+ assert gO is not None
2692
+ tOgO = gmem_thr_copy_O.partition_D(gO)
2693
  for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
2694
  if (
2695
  t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
 
2718
  block_info: BlockInfo,
2719
  num_splits: int,
2720
  SeqlenInfoCls: Callable,
 
2721
  mma_tile_coord_v: Int32 = 0,
2722
+ tile_scheduler=None,
2723
  ):
2724
  epi_consumer_phase = Int32(0)
 
2725
  work_tile = tile_scheduler.initial_work_tile_info()
2726
  while work_tile.is_valid_tile:
2727
  m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
 
2733
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
2734
  else:
2735
  mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
2736
+ gO = None
2737
+ if const_expr(self.use_tma_O or not self.pack_gqa):
2738
+ tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
2739
+ gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128)
2740
+ gO = layout_utils.select(
2741
+ cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
2742
+ ) # (128, 128, 2)
2743
+ gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]
2744
 
2745
  if const_expr(self.use_tma_O):
2746
  store_O, _, _ = copy_utils.tma_get_copy_fn(
 
2767
  pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
2768
  # 2. copy O0 / O1 to gmem
2769
  m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
2770
+ gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
2771
  self._store_O_to_gmem(
2772
+ sO[None, None, stage], gO_stage, mO_cur, gmem_tiled_copy_O,
2773
  tidx, seqlen.seqlen_q, m_tile_idx,
2774
  )
2775
  pipeline_o_epi.consumer_release_w_index(stage)
 
2777
  epi_consumer_phase ^= 1
2778
 
2779
  # Advance to next tile
2780
+ work_tile = tile_scheduler.advance_to_next_work()
2781
+
2782
+ @cute.jit
2783
+ def clc_scheduler_warp(
2784
+ self,
2785
+ tile_scheduler: TileSchedulerProtocol,
2786
+ ):
2787
+ work_tile = tile_scheduler.initial_work_tile_info()
2788
+ while work_tile.is_valid_tile:
2789
+ tile_scheduler.prefetch_next_work()
2790
+ work_tile = tile_scheduler.advance_to_next_work()
2791
+ if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE:
2792
+ fa_printf(
2793
+ 3,
2794
+ "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n",
2795
+ smid(),
2796
+ cute.arch.block_idx()[0],
2797
+ work_tile.tile_idx[0],
2798
+ work_tile.tile_idx[1],
2799
+ work_tile.tile_idx[2],
2800
+ work_tile.tile_idx[3],
2801
+ work_tile.is_valid_tile,
2802
+ )
2803
+ tile_scheduler.producer_tail()
2804
+
2805
+ @cute.jit
2806
+ def empty_warp(
2807
+ self,
2808
+ tile_scheduler: TileSchedulerProtocol,
2809
+ ):
2810
+ work_tile = tile_scheduler.initial_work_tile_info()
2811
+ while work_tile.is_valid_tile:
2812
+ work_tile = tile_scheduler.advance_to_next_work()
2813
 
2814
  def load_Q(
2815
  self,
 
2822
  pipeline_q.producer_acquire_w_index_phase(stage, phase)
2823
  load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))
2824
 
2825
+ def load_Q_non_tma(
2826
+ self,
2827
+ mQ: cute.Tensor,
2828
+ sQ: cute.Tensor,
2829
+ gmem_tiled_copy_Q: cute.TiledCopy,
2830
+ pipeline_q: pipeline.PipelineAsync,
2831
+ tidx: Int32,
2832
+ seqlen_q: Int32,
2833
+ m_block: Int32,
2834
+ block: Int32,
2835
+ stage: int,
2836
+ phase: Int32,
2837
+ ):
2838
+ assert self.cta_group_size == 1, "cta_group_size must be 1 for non-tma Q load"
2839
+ pipeline_q.producer_acquire_w_index_phase(stage, phase)
2840
+ pack_gqa = PackGQA(
2841
+ self.m_block_size,
2842
+ self.head_dim_padded,
2843
+ self.check_hdim_oob,
2844
+ self.qhead_per_kvhead,
2845
+ )
2846
+ sQ_stage = sQ[None, None, None, stage]
2847
+ sQ_pi = cute.make_tensor(
2848
+ sQ_stage.iterator,
2849
+ cute.make_layout(
2850
+ (sQ_stage.shape[0][0], (sQ_stage.shape[0][1], sQ_stage.shape[2])),
2851
+ stride=(sQ_stage.stride[0][0], (sQ_stage.stride[0][1], sQ_stage.stride[2])),
2852
+ ),
2853
+ )
2854
+ pack_gqa.load_Q(mQ, sQ_pi, gmem_tiled_copy_Q, tidx, m_block * self.q_stage + block, seqlen_q)
2855
+ cute.arch.cp_async_commit_group()
2856
+ pipeline_q.sync_object_full.arrive_cp_async_mbarrier(stage)
2857
+
2858
  @cute.jit
2859
  def load_KV(
2860
  self,
 
2897
  else:
2898
  assert paged_kv_manager is not None
2899
  assert extra_tx_count is None
2900
+ sX_cur = sX[None, None, None, stage]
2901
+ if const_expr(self.uneven_kv_smem):
2902
+ sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1)
2903
+ paged_kv_manager.load_KV(block, sX_cur, K_or_V)
2904
  cute.arch.cp_async_commit_group()
2905
  pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)
2906
 
 
2911
  # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
2912
  # phase == 0, or left by offset if phase == 1.
2913
  offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
2914
+ # Hint that the offset is 128-bit aligned so that
2915
+ # ptr + offset preserves the alignment needed by cp.async.
2916
+ offset = cute.assume(offset, divby=128 // self.k_dtype.width)
2917
  return cute.make_tensor(sX.iterator + offset, sX.layout)
2918
  else:
2919
  return sX
 
2923
  # warp_group_idx = utils.canonical_warp_group_idx(sync=False)
2924
  # if warp_group_idx == 0:
2925
  # cute.arch.barrier_arrive(
2926
+ # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128,
2927
  # )
2928
 
2929
  # def warp_scheduler_barrier_sync(self):
2930
  # cute.arch.barrier(
2931
+ # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
2932
  # number_of_threads=2 * 128
2933
  # )
2934
 
 
2936
  # cur_wg = utils.canonical_warp_group_idx(sync=False)
2937
  # next_wg = 1 - cur_wg
2938
  # cute.arch.barrier_arrive(
2939
+ # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
2940
  # )
2941
 
2942
  @cute.jit
build/torch-cuda/flash_fwd_sm120.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # SM120 (Blackwell GeForce / DGX Spark) forward pass.
3
+ #
4
+ # SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
5
+ # a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
6
+ # FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly.
7
+
8
+ import cutlass
9
+ import cutlass.utils as utils_basic
10
+
11
+ from .flash_fwd import FlashAttentionForwardSm80
12
+
13
+
14
+ class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
15
+ # Keep arch = 80 to use CpAsync code paths (no TMA for output).
16
+ # The compilation target is determined by the GPU at compile time, not this field.
17
+ arch = 80
18
+
19
+ @staticmethod
20
+ def can_implement(
21
+ dtype,
22
+ head_dim,
23
+ head_dim_v,
24
+ tile_m,
25
+ tile_n,
26
+ num_stages,
27
+ num_threads,
28
+ is_causal,
29
+ Q_in_regs=False,
30
+ ) -> bool:
31
+ """Check if the kernel can be implemented on SM120.
32
+
33
+ Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
34
+ """
35
+ if dtype not in [cutlass.Float16, cutlass.BFloat16]:
36
+ return False
37
+ if head_dim % 8 != 0:
38
+ return False
39
+ if head_dim_v % 8 != 0:
40
+ return False
41
+ if tile_n % 16 != 0:
42
+ return False
43
+ if num_threads % 32 != 0:
44
+ return False
45
+ # Shared memory usage: Q tile + (K tile + V tile)
46
+ smem_usage_Q = tile_m * head_dim * 2
47
+ smem_usage_K = tile_n * head_dim * num_stages * 2
48
+ smem_usage_V = tile_n * head_dim_v * num_stages * 2
49
+ smem_usage_QV = (
50
+ (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)
51
+ )
52
+ smem_usage = smem_usage_QV + smem_usage_K
53
+ # SM120 has 99 KB shared memory (vs 163 KB on SM80)
54
+ smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
55
+ if smem_usage > smem_capacity:
56
+ return False
57
+ if (tile_m * 2) % num_threads != 0:
58
+ return False
59
+ return True
build/torch-cuda/flash_fwd_sm90.py ADDED
@@ -0,0 +1,1534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ # SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py.
3
+
4
+ from types import SimpleNamespace
5
+ from typing import Callable, Literal, Optional
6
+ from functools import partial
7
+
8
+ import cuda.bindings.driver as cuda
9
+
10
+ import cutlass
11
+ import cutlass.cute as cute
12
+ from cutlass import Float32, Int32, const_expr
13
+ from cutlass.cute.nvgpu import cpasync, warpgroup
14
+ from cutlass.utils import LayoutEnum
15
+ import cutlass.utils.hopper_helpers as sm90_utils_basic
16
+ from cutlass import pipeline
17
+ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
18
+ from cutlass.base_dsl.arch import Arch
19
+
20
+ from .quack import copy_utils
21
+ from .quack import layout_utils
22
+ from .quack import sm90_utils
23
+
24
+ from .cute_dsl_utils import assume_tensor_aligned
25
+ from . import utils
26
+ from .mask import AttentionMask
27
+ from .softmax import Softmax, apply_score_mod_inner
28
+ from .seqlen_info import SeqlenInfoQK
29
+ from .block_info import BlockInfo
30
+ from .block_sparsity import BlockSparseTensors
31
+ from .block_sparse_utils import (
32
+ produce_block_sparse_loads,
33
+ consume_block_sparse_loads,
34
+ )
35
+ from . import pipeline as pipeline_custom
36
+ from .pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom
37
+ from .paged_kv import PagedKVManager
38
+ from .named_barrier import NamedBarrierFwd
39
+ from .quack.cute_dsl_utils import ParamsBase
40
+ from .tile_scheduler import (
41
+ TileSchedulerArguments,
42
+ SingleTileScheduler,
43
+ SingleTileLPTScheduler,
44
+ SingleTileVarlenScheduler,
45
+ )
46
+ from cutlass.cute import FastDivmodDivisor
47
+
48
+ from .flash_fwd import FlashAttentionForwardBase
49
+
50
+
51
+ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
52
+ def __init__(
53
+ self,
54
+ *args,
55
+ intra_wg_overlap: bool = True,
56
+ mma_pv_is_rs: bool = True,
57
+ paged_kv_non_tma: bool = False,
58
+ **kwargs,
59
+ ):
60
+ super().__init__(*args, **kwargs)
61
+ self.intra_wg_overlap = intra_wg_overlap
62
+ self.mma_pv_is_rs = mma_pv_is_rs
63
+ self.buffer_align_bytes = 1024
64
+ self.use_tma_KV = not paged_kv_non_tma
65
+ assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
66
+ "Paged KV does not support irregular head dim"
67
+ )
68
+ self.cluster_shape_mn = (1, 1)
69
+ assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported"
70
+
71
+ def _get_smem_layout_atom(self):
72
+ sQ_layout_atom = warpgroup.make_smem_layout_atom(
73
+ sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
74
+ self.dtype,
75
+ )
76
+ sK_layout_atom = sQ_layout_atom
77
+ sV_layout_atom = warpgroup.make_smem_layout_atom(
78
+ sm90_utils_basic.get_smem_layout_atom(
79
+ LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
80
+ ),
81
+ self.dtype,
82
+ )
83
+ sO_layout_atom = sV_layout_atom
84
+ if not self.mma_pv_is_rs:
85
+ sP_layout_atom = warpgroup.make_smem_layout_atom(
86
+ sm90_utils_basic.get_smem_layout_atom(
87
+ LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
88
+ ),
89
+ self.dtype,
90
+ )
91
+ else:
92
+ sP_layout_atom = None
93
+ return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom
94
+
95
+ def _get_tiled_mma(self):
96
+ tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
97
+ self.dtype,
98
+ self.dtype,
99
+ warpgroup.OperandMajorMode.K,
100
+ warpgroup.OperandMajorMode.K,
101
+ Float32,
102
+ atom_layout_mnk=(self.tile_m // 64, 1, 1),
103
+ tiler_mn=(64, self.tile_n),
104
+ )
105
+ tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
106
+ self.dtype,
107
+ self.dtype,
108
+ warpgroup.OperandMajorMode.K,
109
+ warpgroup.OperandMajorMode.MN,
110
+ Float32,
111
+ atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
112
+ tiler_mn=(64, self.tile_hdimv),
113
+ a_source=warpgroup.OperandSource.RMEM
114
+ if self.mma_pv_is_rs
115
+ else warpgroup.OperandSource.SMEM,
116
+ )
117
+ return tiled_mma_qk, tiled_mma_pv
118
+
119
+ def _get_shared_storage_cls(self):
120
+ sQ_struct, sK_struct, sV_struct = [
121
+ cute.struct.Align[
122
+ cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes
123
+ ]
124
+ for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
125
+ ]
126
+ cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
127
+ sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
128
+ cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
129
+ sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
130
+ # 1 stage * 2 for Q pipeline (full + empty), self.num_stages*2 for K, self.num_stages*2 for V,
131
+ mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2]
132
+ mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
133
+ mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
134
+
135
+ @cute.struct
136
+ class SharedStorageQKV:
137
+ mbar_ptr_Q: mbar_ptr_Q_struct
138
+ mbar_ptr_K: mbar_ptr_K_struct
139
+ mbar_ptr_V: mbar_ptr_V_struct
140
+ sV: sV_struct
141
+ sQ: sQ_struct
142
+ sK: sK_struct
143
+ sP: sP_struct
144
+
145
+ @cute.struct
146
+ class SharedStorageSharedQV:
147
+ mbar_ptr_Q: mbar_ptr_Q_struct
148
+ mbar_ptr_K: mbar_ptr_K_struct
149
+ mbar_ptr_V: mbar_ptr_V_struct
150
+ sQ: sQV_struct
151
+ sK: sK_struct
152
+ sP: sP_struct
153
+
154
+ return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV
155
+
156
+ @cute.jit
157
+ def __call__(
158
+ self,
159
+ mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
160
+ mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
161
+ mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
162
+ mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
163
+ mLSE: Optional[cute.Tensor],
164
+ softmax_scale: Float32,
165
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
166
+ mCuSeqlensK: Optional[cute.Tensor] = None,
167
+ mSeqUsedQ: Optional[cute.Tensor] = None,
168
+ mSeqUsedK: Optional[cute.Tensor] = None,
169
+ mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq)
170
+ window_size_left: Int32 | int | None = None,
171
+ window_size_right: Int32 | int | None = None,
172
+ learnable_sink: Optional[cute.Tensor] = None,
173
+ blocksparse_tensors: Optional[BlockSparseTensors] = None,
174
+ aux_tensors: Optional[list] = None,
175
+ # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
176
+ stream: cuda.CUstream = None,
177
+ ):
178
+ """Configures and launches the flash attention kernel.
179
+
180
+ mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:
181
+ (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
182
+ """
183
+
184
+ self._check_type(
185
+ *(
186
+ t.element_type if t is not None else None
187
+ for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
188
+ )
189
+ )
190
+
191
+ self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
192
+
193
+ mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
194
+ QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
195
+ mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
196
+ KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
197
+ mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
198
+ LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
199
+ mLSE = (
200
+ layout_utils.select(mLSE, LSE_layout_transpose)
201
+ if const_expr(mLSE is not None)
202
+ else None
203
+ )
204
+
205
+ tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
206
+ self.num_mma_threads = tiled_mma_qk.size
207
+ self.num_threads_per_warp_group = 128
208
+ self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group
209
+ assert self.num_wg_mma in [1, 2, 3]
210
+ self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1)
211
+ self.num_producer_threads = 32
212
+ self.num_Q_load_threads = self.num_threads_per_warp_group # If not TMA_Q
213
+ self.num_epilogue_threads = self.num_mma_threads
214
+ self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[
215
+ self.num_wg_mma
216
+ ]
217
+ self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
218
+
219
+ self.use_scheduler_barrier = (
220
+ (self.num_wg_mma >= 2 and self.tile_hdim <= 128)
221
+ if const_expr(self.intra_wg_overlap)
222
+ else (self.num_wg_mma == 2)
223
+ )
224
+ self.use_tma_Q = self.arch >= Arch.sm_90 and not (
225
+ self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
226
+ )
227
+ self.use_tma_O = self.use_tma_Q
228
+ # Producer needs more registers when doing cp.async Q or KV loads
229
+ if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)):
230
+ self.num_mma_regs, self.num_producer_regs = 224, 40
231
+ self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap
232
+ self._setup_attributes()
233
+ # TODO: we prob don't need most of what's in _setup_attributes
234
+ self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
235
+ sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
236
+ for mX, shape, stage in [
237
+ (mQ, (self.tile_m, self.tile_hdim), None),
238
+ (mK, (self.tile_n, self.tile_hdim), self.num_stages),
239
+ (mV, (self.tile_n, self.tile_hdimv), self.num_stages),
240
+ (mO, (self.tile_m, self.tile_hdimv), None),
241
+ ]
242
+ ]
243
+ self.sP_layout = None
244
+ if const_expr(not self.mma_pv_is_rs):
245
+ self.sP_layout = sm90_utils.make_smem_layout(
246
+ mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
247
+ )
248
+
249
+ SharedStorage = self._get_shared_storage_cls()
250
+
251
+ mQ_og, mO_og = mQ, mO
252
+ if const_expr(self.pack_gqa):
253
+ nheads_kv = mK.shape[2]
254
+ mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
255
+ mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
256
+ if const_expr(mLSE is not None):
257
+ mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)
258
+
259
+ # TMA
260
+ gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
261
+ gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast
262
+ gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
263
+ self.tma_copy_bytes = {
264
+ name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
265
+ for name, mX, layout in [
266
+ ("Q", mQ, self.sQ_layout),
267
+ ("K", mK, self.sK_layout),
268
+ ("V", mV, self.sV_layout),
269
+ ]
270
+ }
271
+ make_tiled_tma_atom_fn = (
272
+ partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2)
273
+ if const_expr(self.pack_gqa)
274
+ else cpasync.make_tiled_tma_atom
275
+ )
276
+ tma_atom_Q, tma_tensor_Q = None, None
277
+ if const_expr(self.use_tma_Q):
278
+ tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn(
279
+ gmem_tiled_copy_Q,
280
+ mQ_og if const_expr(self.pack_gqa) else mQ,
281
+ self.sQ_layout,
282
+ (self.tile_m, self.tile_hdim), # No mcast
283
+ )
284
+ tma_atom_K, tma_tensor_K = None, None
285
+ tma_atom_V, tma_tensor_V = None, None
286
+ if const_expr(self.use_tma_KV):
287
+ tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
288
+ gmem_tiled_copy_KV,
289
+ mK,
290
+ cute.select(self.sK_layout, mode=[0, 1]),
291
+ (self.tile_n, self.tile_hdim),
292
+ 1, # No mcast for now
293
+ )
294
+ tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
295
+ gmem_tiled_copy_KV,
296
+ mV,
297
+ cute.select(self.sV_layout, mode=[0, 1]),
298
+ (self.tile_n, self.tile_hdimv),
299
+ 1, # No mcast for now
300
+ )
301
+ tma_atom_O, tma_tensor_O = None, None
302
+ if const_expr(self.use_tma_O):
303
+ mO_tma = mO_og if const_expr(self.pack_gqa) else mO
304
+ if const_expr(self.varlen_q):
305
+ mO_tma = copy_utils.create_ragged_tensor_for_tma(
306
+ mO_tma, ragged_dim=0, ptr_shift=True
307
+ )
308
+ tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn(
309
+ gmem_tiled_copy_O,
310
+ mO_tma,
311
+ self.sO_layout,
312
+ (self.tile_m, self.tile_hdimv), # No mcast
313
+ )
314
+ if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
315
+ TileScheduler = SingleTileVarlenScheduler
316
+ else:
317
+ TileScheduler = (
318
+ SingleTileScheduler
319
+ if const_expr(not self.is_causal or self.is_local)
320
+ else SingleTileLPTScheduler
321
+ )
322
+ tile_sched_args = TileSchedulerArguments(
323
+ cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
324
+ cute.size(mQ.shape[2]),
325
+ cute.size(mQ.shape[3])
326
+ if const_expr(mCuSeqlensQ is None)
327
+ else cute.size(mCuSeqlensQ.shape[0] - 1),
328
+ 1, # num_splits
329
+ cute.size(mK.shape[0])
330
+ if const_expr(mPageTable is None)
331
+ else mK.shape[0] * mPageTable.shape[1],
332
+ mQ.shape[1],
333
+ mV.shape[1],
334
+ total_q=cute.size(mQ.shape[0])
335
+ if const_expr(mCuSeqlensQ is not None)
336
+ else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
337
+ tile_shape_mn=(self.tile_m, self.tile_n),
338
+ mCuSeqlensQ=mCuSeqlensQ,
339
+ mSeqUsedQ=mSeqUsedQ,
340
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
341
+ element_size=self.dtype.width // 8,
342
+ is_persistent=False,
343
+ lpt=self.is_causal or self.is_local,
344
+ )
345
+ tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
346
+ grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
347
+ softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(
348
+ softmax_scale, self.score_mod
349
+ )
350
+ window_size_left = Int32(window_size_left) if window_size_left is not None else None
351
+ window_size_right = Int32(window_size_right) if window_size_right is not None else None
352
+ fastdiv_mods = utils.compute_fastdiv_mods(
353
+ mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable
354
+ )
355
+
356
+ self.kernel(
357
+ tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
358
+ tma_tensor_K if const_expr(self.use_tma_KV) else mK,
359
+ tma_tensor_V if const_expr(self.use_tma_KV) else mV,
360
+ tma_tensor_O if const_expr(self.use_tma_O) else mO,
361
+ mLSE,
362
+ mCuSeqlensQ,
363
+ mCuSeqlensK,
364
+ mSeqUsedQ,
365
+ mSeqUsedK,
366
+ mPageTable,
367
+ tma_atom_Q,
368
+ tma_atom_K,
369
+ tma_atom_V,
370
+ tma_atom_O,
371
+ softmax_scale_log2,
372
+ softmax_scale,
373
+ window_size_left,
374
+ window_size_right,
375
+ learnable_sink,
376
+ blocksparse_tensors,
377
+ self.sQ_layout,
378
+ self.sK_layout,
379
+ self.sV_layout,
380
+ self.sO_layout,
381
+ self.sP_layout,
382
+ self.gmem_tiled_copy_Q,
383
+ self.gmem_tiled_copy_K,
384
+ self.gmem_tiled_copy_V,
385
+ self.gmem_tiled_copy_O,
386
+ tiled_mma_qk,
387
+ tiled_mma_pv,
388
+ tile_sched_params,
389
+ TileScheduler,
390
+ SharedStorage,
391
+ aux_tensors,
392
+ fastdiv_mods,
393
+ ).launch(
394
+ grid=grid_dim,
395
+ block=[self.num_threads, 1, 1],
396
+ stream=stream,
397
+ min_blocks_per_mp=1,
398
+ )
399
+
400
+ @cute.kernel
401
+ def kernel(
402
+ self,
403
+ mQ: cute.Tensor,
404
+ mK: cute.Tensor,
405
+ mV: cute.Tensor,
406
+ mO: cute.Tensor,
407
+ mLSE: Optional[cute.Tensor],
408
+ mCuSeqlensQ: Optional[cute.Tensor],
409
+ mCuSeqlensK: Optional[cute.Tensor],
410
+ mSeqUsedQ: Optional[cute.Tensor],
411
+ mSeqUsedK: Optional[cute.Tensor],
412
+ mPageTable: Optional[cute.Tensor],
413
+ tma_atom_Q: Optional[cute.CopyAtom],
414
+ tma_atom_K: Optional[cute.CopyAtom],
415
+ tma_atom_V: Optional[cute.CopyAtom],
416
+ tma_atom_O: Optional[cute.CopyAtom],
417
+ softmax_scale_log2: Float32,
418
+ softmax_scale: Optional[Float32],
419
+ window_size_left: Optional[Int32],
420
+ window_size_right: Optional[Int32],
421
+ learnable_sink: Optional[cute.Tensor],
422
+ blocksparse_tensors: Optional[BlockSparseTensors],
423
+ sQ_layout: cute.ComposedLayout,
424
+ sK_layout: cute.ComposedLayout,
425
+ sV_layout: cute.ComposedLayout,
426
+ sO_layout: cute.ComposedLayout,
427
+ sP_layout: cute.ComposedLayout | None,
428
+ gmem_tiled_copy_Q: cute.TiledCopy,
429
+ gmem_tiled_copy_K: cute.TiledCopy,
430
+ gmem_tiled_copy_V: cute.TiledCopy,
431
+ gmem_tiled_copy_O: cute.TiledCopy,
432
+ tiled_mma_qk: cute.TiledMma,
433
+ tiled_mma_pv: cute.TiledMma,
434
+ tile_sched_params: ParamsBase,
435
+ TileScheduler: cutlass.Constexpr[Callable],
436
+ SharedStorage: cutlass.Constexpr[Callable],
437
+ aux_tensors=Optional[list[cute.Tensor]],
438
+ fastdiv_mods=None,
439
+ ):
440
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
441
+ # Prefetch tma descriptor
442
+ if warp_idx == 0:
443
+ for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
444
+ if const_expr(tma_atom is not None):
445
+ cpasync.prefetch_descriptor(tma_atom)
446
+
447
+ smem = cutlass.utils.SmemAllocator()
448
+ storage = smem.allocate(SharedStorage)
449
+
450
+ # Mbarrier / pipeline init
451
+ mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr()
452
+
453
+ ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
454
+ tma_warp = ThreadCooperativeGroup(1)
455
+ load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group)
456
+ mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE)
457
+ if const_expr(self.use_tma_Q):
458
+ pipeline_q = pipeline_custom.PipelineTmaAsync.create(
459
+ barrier_storage=mbar_ptr_Q,
460
+ num_stages=1,
461
+ producer_group=tma_warp,
462
+ consumer_group=mma_warps,
463
+ tx_count=self.tma_copy_bytes["Q"],
464
+ defer_sync=True,
465
+ )
466
+ else:
467
+ pipeline_q = pipeline_custom.PipelineCpAsync.create(
468
+ barrier_storage=mbar_ptr_Q,
469
+ num_stages=1,
470
+ producer_group=load_threads,
471
+ consumer_group=mma_warps,
472
+ defer_sync=True,
473
+ elect_one_release=True,
474
+ syncwarp_before_release=False,
475
+ )
476
+
477
+ if const_expr(self.use_tma_KV):
478
+ pipeline_k = pipeline_custom.PipelineTmaAsync.create(
479
+ barrier_storage=storage.mbar_ptr_K.data_ptr(),
480
+ num_stages=self.num_stages,
481
+ producer_group=tma_warp,
482
+ consumer_group=mma_warps,
483
+ tx_count=self.tma_copy_bytes["K"],
484
+ defer_sync=True,
485
+ )
486
+ pipeline_v = pipeline_custom.PipelineTmaAsync.create(
487
+ barrier_storage=storage.mbar_ptr_V.data_ptr(),
488
+ num_stages=self.num_stages,
489
+ producer_group=tma_warp,
490
+ consumer_group=mma_warps,
491
+ tx_count=self.tma_copy_bytes["V"],
492
+ defer_sync=True,
493
+ )
494
+ else:
495
+ pipeline_k = pipeline_custom.PipelineCpAsync.create(
496
+ barrier_storage=storage.mbar_ptr_K.data_ptr(),
497
+ num_stages=self.num_stages,
498
+ producer_group=load_threads,
499
+ consumer_group=mma_warps,
500
+ defer_sync=True,
501
+ elect_one_release=True,
502
+ syncwarp_before_release=False,
503
+ )
504
+ pipeline_v = pipeline_custom.PipelineCpAsync.create(
505
+ barrier_storage=storage.mbar_ptr_V.data_ptr(),
506
+ num_stages=self.num_stages,
507
+ producer_group=load_threads,
508
+ consumer_group=mma_warps,
509
+ defer_sync=True,
510
+ elect_one_release=True,
511
+ syncwarp_before_release=False,
512
+ )
513
+
514
+ # Cluster arrive after barrier init
515
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
516
+
517
+ # ///////////////////////////////////////////////////////////////////////////////
518
+ # Get shared memory buffer
519
+ # ///////////////////////////////////////////////////////////////////////////////
520
+ sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
521
+ sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
522
+ if const_expr(not self.Q_in_regs):
523
+ sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
524
+ else:
525
+ sV = storage.sQ.get_tensor(
526
+ sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type
527
+ )
528
+ # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma
529
+ sVt = layout_utils.transpose_view(sV)
530
+ sP = None
531
+ if const_expr(sP_layout is not None):
532
+ sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)
533
+ # reuse sQ's data iterator
534
+ sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)
535
+
536
+ block_info = BlockInfo(
537
+ self.tile_m,
538
+ self.tile_n,
539
+ self.is_causal,
540
+ self.is_local,
541
+ False, # is_split_kv
542
+ window_size_left,
543
+ window_size_right,
544
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
545
+ )
546
+ SeqlenInfoCls = partial(
547
+ SeqlenInfoQK.create,
548
+ seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
549
+ seqlen_k_static=mK.shape[0]
550
+ if const_expr(mPageTable is None)
551
+ else mK.shape[0] * mPageTable.shape[1],
552
+ mCuSeqlensQ=mCuSeqlensQ,
553
+ mCuSeqlensK=mCuSeqlensK,
554
+ mSeqUsedQ=mSeqUsedQ,
555
+ mSeqUsedK=mSeqUsedK,
556
+ # Don't need to pass in tile_mn because we won't access offset_padded
557
+ )
558
+ AttentionMaskCls = partial(
559
+ AttentionMask,
560
+ self.tile_m,
561
+ self.tile_n,
562
+ window_size_left=window_size_left,
563
+ window_size_right=window_size_right,
564
+ qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
565
+ )
566
+ TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
567
+
568
+ # Cluster wait before starting
569
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
570
+
571
+ if warp_idx < 4: # Producer
572
+ cute.arch.setmaxregister_decrease(self.num_producer_regs)
573
+ self.load(
574
+ mQ,
575
+ mK,
576
+ mV,
577
+ sQ,
578
+ sK,
579
+ sV,
580
+ tma_atom_Q,
581
+ tma_atom_K,
582
+ tma_atom_V,
583
+ pipeline_k,
584
+ pipeline_v,
585
+ pipeline_q,
586
+ gmem_tiled_copy_Q,
587
+ mPageTable,
588
+ blocksparse_tensors,
589
+ block_info,
590
+ SeqlenInfoCls,
591
+ TileSchedulerCls,
592
+ )
593
+
594
+ else: # Consumer
595
+ cute.arch.setmaxregister_increase(self.num_mma_regs)
596
+ # ///////////////////////////////////////////////////////////////////////////////
597
+ # Tile MMA compute thread partitions and allocate accumulators
598
+ # ///////////////////////////////////////////////////////////////////////////////
599
+ tidx, _, _ = cute.arch.thread_idx()
600
+ tidx = tidx - 128
601
+ self.mma(
602
+ tiled_mma_qk,
603
+ tiled_mma_pv,
604
+ mO,
605
+ mLSE,
606
+ sQ,
607
+ sK,
608
+ sVt,
609
+ sP,
610
+ sO,
611
+ learnable_sink,
612
+ pipeline_k,
613
+ pipeline_v,
614
+ pipeline_q,
615
+ gmem_tiled_copy_O,
616
+ tma_atom_O,
617
+ tidx,
618
+ softmax_scale_log2,
619
+ softmax_scale,
620
+ block_info,
621
+ SeqlenInfoCls,
622
+ AttentionMaskCls,
623
+ TileSchedulerCls,
624
+ blocksparse_tensors,
625
+ aux_tensors,
626
+ fastdiv_mods,
627
+ )
628
+
629
+ @cute.jit
630
+ def load(
631
+ self,
632
+ mQ: cute.Tensor,
633
+ mK: cute.Tensor,
634
+ mV: cute.Tensor,
635
+ sQ: cute.Tensor,
636
+ sK: cute.Tensor,
637
+ sV: cute.Tensor,
638
+ tma_atom_Q: Optional[cute.CopyAtom],
639
+ tma_atom_K: Optional[cute.CopyAtom],
640
+ tma_atom_V: Optional[cute.CopyAtom],
641
+ pipeline_k: pipeline.PipelineAsync,
642
+ pipeline_v: pipeline.PipelineAsync,
643
+ pipeline_q: pipeline.PipelineAsync,
644
+ gmem_tiled_copy_Q: cute.TiledCopy,
645
+ mPageTable: Optional[cute.Tensor],
646
+ blocksparse_tensors: Optional[BlockSparseTensors],
647
+ block_info: BlockInfo,
648
+ SeqlenInfoCls: Callable,
649
+ TileSchedulerCls: Callable,
650
+ ):
651
+ warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
652
+ tidx, _, _ = cute.arch.thread_idx()
653
+
654
+ # TMA: only warp 0 loads. cp_async: all warps load.
655
+ # When not use_tma_Q, all 128 producer threads participate in Q loading.
656
+ is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV or not self.use_tma_Q)
657
+ # KV loading restricted to warp 0 for TMA, all warps for non-TMA KV
658
+ is_kv_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV)
659
+
660
+ if is_load_warp:
661
+ q_producer_phase = Int32(1)
662
+ kv_producer_state = pipeline.make_pipeline_state(
663
+ pipeline.PipelineUserType.Producer, self.num_stages
664
+ )
665
+ tile_scheduler = TileSchedulerCls()
666
+ work_tile = tile_scheduler.initial_work_tile_info()
667
+ while work_tile.is_valid_tile:
668
+ # if work_tile.is_valid_tile:
669
+ m_block, head_idx, batch_idx, _ = work_tile.tile_idx
670
+ seqlen = SeqlenInfoCls(batch_idx)
671
+ mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
672
+ head_idx_kv = (
673
+ head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
674
+ )
675
+
676
+ load_Q = None
677
+ if const_expr(self.use_tma_Q):
678
+ gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
679
+ load_Q, _, _ = copy_utils.tma_get_copy_fn(
680
+ tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True
681
+ )
682
+
683
+ paged_kv_manager = None
684
+ tma_load_K_fn = None
685
+ tma_load_V_fn = None
686
+ if const_expr(self.use_tma_KV):
687
+ # === TMA path (non-paged and paged with page_size == n_block_size) ===
688
+ if const_expr(mPageTable is not None):
689
+ # Paged TMA: keep page dimension indexable
690
+ mK_cur = mK[None, None, head_idx_kv, None]
691
+ mV_cur = mV[None, None, head_idx_kv, None]
692
+ gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None))
693
+ gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None))
694
+ else:
695
+ # Non-paged TMA
696
+ mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[
697
+ None, None, head_idx_kv
698
+ ]
699
+ mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[
700
+ None, None, head_idx_kv
701
+ ]
702
+ gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))
703
+ gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
704
+ # TODO: mcast
705
+ tma_load_K_fn, _, _ = copy_utils.tma_get_copy_fn(
706
+ tma_atom_K, 0, cute.make_layout(1), gK, sK
707
+ )
708
+ tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k)
709
+ tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn(
710
+ tma_atom_V, 0, cute.make_layout(1), gV, sV
711
+ )
712
+ tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v)
713
+ else:
714
+ # === cp_async path (paged KV with page_size != n_block_size) ===
715
+ paged_kv_manager = PagedKVManager.create(
716
+ mPageTable,
717
+ mK,
718
+ mV,
719
+ FastDivmodDivisor(mK.shape[0]),
720
+ batch_idx,
721
+ head_idx_kv,
722
+ tidx,
723
+ seqlen.seqlen_k,
724
+ 0, # leftpad_k
725
+ self.tile_n,
726
+ self.tile_hdim,
727
+ self.tile_hdimv,
728
+ self.num_threads_per_warp_group,
729
+ mK.element_type,
730
+ arch=self.arch.major * 10 + self.arch.minor,
731
+ )
732
+
733
+ load_K = partial(
734
+ self.load_KV,
735
+ tma_load_K_fn,
736
+ paged_kv_manager,
737
+ sK,
738
+ pipeline_kv=pipeline_k,
739
+ K_or_V="K",
740
+ )
741
+ load_V = partial(
742
+ self.load_KV,
743
+ tma_load_V_fn,
744
+ paged_kv_manager,
745
+ sV,
746
+ pipeline_kv=pipeline_v,
747
+ K_or_V="V",
748
+ )
749
+
750
+ pack_gqa = None
751
+ if const_expr(not self.use_tma_Q):
752
+ pack_gqa = PackGQA(
753
+ self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
754
+ )
755
+
756
+ if const_expr(not self.use_block_sparsity):
757
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
758
+ # if cute.arch.thread_idx()[0] == 0:
759
+ # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
760
+ # Clamp n_block to 0 when n_block_max == 0 (can happen with causal
761
+ # + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1
762
+ # gracefully (fills zeros), but cp.async would crash on
763
+ # out-of-bounds page table access.
764
+ n_block = (
765
+ n_block_max - 1
766
+ if const_expr(self.use_tma_KV)
767
+ else cutlass.max(n_block_max - 1, 0)
768
+ )
769
+ page_idx = (
770
+ mPageTable[batch_idx, n_block]
771
+ if const_expr(mPageTable is not None and self.use_tma_KV)
772
+ else None
773
+ )
774
+
775
+ # First iteration: load K on pipeline_k, Q on pipeline_q
776
+ if is_kv_load_warp:
777
+ pipeline_k.producer_acquire(kv_producer_state)
778
+ if const_expr(not self.use_tma_KV):
779
+ paged_kv_manager.load_page_table(n_block)
780
+ load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)
781
+ if const_expr(self.use_tma_Q):
782
+ if warp_idx_in_wg == 0:
783
+ pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
784
+ load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
785
+ q_producer_phase ^= 1
786
+ else:
787
+ pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
788
+ pack_gqa.load_Q(
789
+ mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q
790
+ )
791
+ cute.arch.cp_async_commit_group()
792
+ pipeline_q.producer_commit_w_index(0)
793
+ q_producer_phase ^= 1
794
+
795
+ if is_kv_load_warp:
796
+ if const_expr(not self.intra_wg_overlap or not self.use_tma_KV):
797
+ pipeline_v.producer_acquire(kv_producer_state)
798
+ load_V(
799
+ block=n_block, producer_state=kv_producer_state, page_idx=page_idx
800
+ )
801
+ kv_producer_state.advance()
802
+ for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
803
+ n_block = n_block_max - 1 - i - 1
804
+ page_idx = (
805
+ mPageTable[batch_idx, n_block]
806
+ if const_expr(mPageTable is not None and self.use_tma_KV)
807
+ else None
808
+ )
809
+ if const_expr(not self.use_tma_KV):
810
+ paged_kv_manager.load_page_table(n_block)
811
+ pipeline_k.producer_acquire(kv_producer_state)
812
+ load_K(
813
+ block=n_block,
814
+ producer_state=kv_producer_state,
815
+ page_idx=page_idx,
816
+ )
817
+ pipeline_v.producer_acquire(kv_producer_state)
818
+ load_V(
819
+ block=n_block,
820
+ producer_state=kv_producer_state,
821
+ page_idx=page_idx,
822
+ )
823
+ kv_producer_state.advance()
824
+ else:
825
+ for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
826
+ n_block_prev = n_block_max - i - 1
827
+ n_block = n_block_prev - 1
828
+ page_idx = (
829
+ mPageTable[batch_idx, n_block]
830
+ if const_expr(mPageTable is not None)
831
+ else None
832
+ )
833
+ page_idx_prev = (
834
+ mPageTable[batch_idx, n_block_prev]
835
+ if const_expr(mPageTable is not None)
836
+ else None
837
+ )
838
+ kv_producer_state_prev = kv_producer_state.clone()
839
+ kv_producer_state.advance()
840
+ pipeline_k.producer_acquire(kv_producer_state)
841
+ load_K(
842
+ block=n_block,
843
+ producer_state=kv_producer_state,
844
+ page_idx=page_idx,
845
+ )
846
+ pipeline_v.producer_acquire(kv_producer_state_prev)
847
+ load_V(
848
+ block=n_block_prev,
849
+ producer_state=kv_producer_state_prev,
850
+ page_idx=page_idx_prev,
851
+ )
852
+ n_block = n_block_min
853
+ page_idx = (
854
+ mPageTable[batch_idx, n_block]
855
+ if const_expr(mPageTable is not None)
856
+ else None
857
+ )
858
+ pipeline_v.producer_acquire(kv_producer_state)
859
+ load_V(
860
+ block=n_block, producer_state=kv_producer_state, page_idx=page_idx
861
+ )
862
+ kv_producer_state.advance()
863
+ else:
864
+ # Block sparsity: use TMA closures directly (not paged)
865
+ # Load Q on pipeline_q, separate from K/V pipeline
866
+ if const_expr(self.use_tma_Q):
867
+ if warp_idx_in_wg == 0:
868
+ pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
869
+ load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
870
+ q_producer_phase ^= 1
871
+ else:
872
+ pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
873
+ pack_gqa.load_Q(
874
+ mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q
875
+ )
876
+ cute.arch.cp_async_commit_group()
877
+ pipeline_q.producer_commit_w_index(0)
878
+ q_producer_phase ^= 1
879
+ if is_kv_load_warp:
880
+ kv_producer_state = produce_block_sparse_loads(
881
+ blocksparse_tensors,
882
+ batch_idx,
883
+ head_idx,
884
+ m_block,
885
+ kv_producer_state,
886
+ tma_load_K_fn,
887
+ tma_load_V_fn,
888
+ pipeline_k,
889
+ pipeline_v,
890
+ self.intra_wg_overlap,
891
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
892
+ self.q_subtile_factor if self.q_subtile_factor is not None else 1,
893
+ )
894
+
895
+ tile_scheduler.prefetch_next_work()
896
+ tile_scheduler.advance_to_next_work()
897
+ work_tile = tile_scheduler.get_current_work()
898
+ # End of persistent scheduler loop
899
+
900
+ # Producer tail is only useful for cluster to avoid early exit of blocks.
901
+ # We only need producer_tail on V since that's the last that's loaded, we don't
902
+ # need it for Q (no cluster) and K.
903
+ if is_kv_load_warp:
904
+ pipeline_v.producer_tail(kv_producer_state)
905
+
906
+ @cute.jit
907
+ def load_KV(
908
+ self,
909
+ tma_load_fn: Optional[Callable],
910
+ paged_kv_manager: Optional[PagedKVManager],
911
+ sX: cute.Tensor,
912
+ block: Int32,
913
+ pipeline_kv: pipeline.PipelineAsync,
914
+ producer_state: pipeline.PipelineState,
915
+ K_or_V: Literal["K", "V"],
916
+ page_idx: Optional[Int32] = None,
917
+ ):
918
+ if const_expr(self.use_tma_KV):
919
+ src_idx = block if const_expr(page_idx is None) else page_idx
920
+ tma_load_fn(src_idx=src_idx, producer_state=producer_state)
921
+ else:
922
+ paged_kv_manager.load_KV(block, sX[None, None, producer_state.index], K_or_V)
923
+ cute.arch.cp_async_commit_group()
924
+ pipeline_kv.producer_commit(producer_state)
925
+
926
+ @cute.jit
927
+ def mma(
928
+ self,
929
+ tiled_mma_qk: cute.TiledMma,
930
+ tiled_mma_pv: cute.TiledMma,
931
+ mO: cute.Tensor,
932
+ mLSE: Optional[cute.Tensor],
933
+ sQ: cute.Tensor,
934
+ sK: cute.Tensor,
935
+ sVt: cute.Tensor,
936
+ sP: Optional[cute.Tensor],
937
+ sO: cute.Tensor,
938
+ learnable_sink: Optional[cute.Tensor],
939
+ pipeline_k: pipeline.PipelineAsync,
940
+ pipeline_v: pipeline.PipelineAsync,
941
+ pipeline_q: pipeline.PipelineAsync,
942
+ gmem_tiled_copy_O: cute.TiledCopy,
943
+ tma_atom_O: Optional[cute.CopyAtom],
944
+ tidx: Int32,
945
+ softmax_scale_log2: Float32,
946
+ softmax_scale: Optional[Float32],
947
+ block_info: BlockInfo,
948
+ SeqlenInfoCls: Callable,
949
+ AttentionMaskCls: Callable,
950
+ TileSchedulerCls: Callable,
951
+ blocksparse_tensors: Optional[BlockSparseTensors],
952
+ aux_tensors: Optional[list],
953
+ fastdiv_mods=None,
954
+ ):
955
+ warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
956
+ warp_group_thread_layout = cute.make_layout(
957
+ self.num_wg_mma, stride=self.num_threads_per_warp_group
958
+ )
959
+ thr_mma_qk = tiled_mma_qk.get_slice(tidx)
960
+ wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
961
+ wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
962
+ _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
963
+ wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
964
+ )
965
+ mma_qk_fn = partial(
966
+ sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
967
+ )
968
+ acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
969
+ wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
970
+ )
971
+ mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)
972
+
973
+ # ///////////////////////////////////////////////////////////////////////////////
974
+ # Smem copy atom tiling
975
+ # ///////////////////////////////////////////////////////////////////////////////
976
+ smem_copy_atom_P = utils.get_smem_store_atom(
977
+ self.arch.major * 10 + self.arch.minor, self.dtype
978
+ )
979
+ smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
980
+ tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
981
+ smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)
982
+
983
+ self.mma_init()
984
+
985
+ q_consumer_phase = Int32(0)
986
+ kv_consumer_state = pipeline.make_pipeline_state(
987
+ pipeline.PipelineUserType.Consumer, self.num_stages
988
+ )
989
+
990
+ tile_scheduler = TileSchedulerCls()
991
+ work_tile = tile_scheduler.initial_work_tile_info()
992
+ softmax = Softmax.create(
993
+ softmax_scale_log2,
994
+ num_rows=acc_O.shape[0][0] * acc_O.shape[1],
995
+ softmax_scale=softmax_scale,
996
+ )
997
+
998
+ # For RescaleOBeforeGemm: persistent scores_scale across iterations
999
+ scores_scale = None
1000
+ if const_expr(self.rescale_O_before_gemm):
1001
+ scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32)
1002
+
1003
+ mma_one_n_block_all = partial(
1004
+ self.mma_one_n_block_intrawg_overlap
1005
+ if const_expr(self.intra_wg_overlap)
1006
+ else self.mma_one_n_block,
1007
+ mma_qk_fn=mma_qk_fn,
1008
+ pipeline_k=pipeline_k,
1009
+ pipeline_v=pipeline_v,
1010
+ acc_O=acc_O,
1011
+ tOrP=tOrP,
1012
+ smem_copy_params=smem_copy_params,
1013
+ check_inf=True,
1014
+ scores_scale=scores_scale,
1015
+ )
1016
+
1017
+ process_first_half_block = partial(
1018
+ self.first_half_block_overlap,
1019
+ mma_qk_fn=mma_qk_fn,
1020
+ pipeline_k=pipeline_k,
1021
+ tOrP=tOrP,
1022
+ smem_copy_params=smem_copy_params,
1023
+ scores_scale=scores_scale,
1024
+ softmax=softmax,
1025
+ acc_O=acc_O,
1026
+ )
1027
+ process_last_half_block = partial(
1028
+ self.last_half_block_overlap,
1029
+ pipeline_v=pipeline_v,
1030
+ mma_pv_fn=mma_pv_fn,
1031
+ scores_scale=scores_scale,
1032
+ softmax=softmax,
1033
+ acc_O=acc_O,
1034
+ )
1035
+ while work_tile.is_valid_tile:
1036
+ # if work_tile.is_valid_tile:
1037
+
1038
+ # shape: (atom_v_m * rest_m)
1039
+ m_block, head_idx, batch_idx, _ = work_tile.tile_idx
1040
+ seqlen = SeqlenInfoCls(batch_idx)
1041
+
1042
+ # Recompute fastdiv_mods if necessary for varlen with aux_tensors
1043
+ recompute_fastdiv_mods_q = cutlass.const_expr(
1044
+ aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
1045
+ )
1046
+ recompute_fastdiv_mods_k = cutlass.const_expr(
1047
+ aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
1048
+ )
1049
+ if cutlass.const_expr(fastdiv_mods is not None):
1050
+ seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
1051
+ fastdiv_mods = (
1052
+ seqlen_q_divmod
1053
+ if not recompute_fastdiv_mods_q
1054
+ else FastDivmodDivisor(seqlen.seqlen_q),
1055
+ seqlen_k_divmod
1056
+ if not recompute_fastdiv_mods_k
1057
+ else FastDivmodDivisor(seqlen.seqlen_k),
1058
+ )
1059
+
1060
+ mask = AttentionMaskCls(seqlen)
1061
+ mask_fn = partial(
1062
+ mask.apply_mask,
1063
+ batch_idx=batch_idx,
1064
+ head_idx=head_idx,
1065
+ m_block=m_block,
1066
+ thr_mma=thr_mma_qk,
1067
+ mask_causal=self.is_causal,
1068
+ mask_local=self.is_local,
1069
+ aux_tensors=aux_tensors,
1070
+ fastdiv_mods=fastdiv_mods,
1071
+ )
1072
+ score_mod_fn = None
1073
+ if const_expr(self.score_mod is not None):
1074
+ score_mod_fn = partial(
1075
+ self.apply_score_mod,
1076
+ thr_mma_qk,
1077
+ batch_idx,
1078
+ head_idx,
1079
+ m_block,
1080
+ softmax_scale=softmax_scale,
1081
+ aux_tensors=aux_tensors,
1082
+ fastdiv_mods=fastdiv_mods,
1083
+ )
1084
+ mma_one_n_block = partial(
1085
+ mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn
1086
+ )
1087
+ n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
1088
+ pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase)
1089
+ # For performance reason, we separate out two kinds of iterations:
1090
+ # those that need masking on S, and those that don't.
1091
+ # We need masking on S for the very last block when K and V has length not multiple of tile_n.
1092
+ # We also need masking on S if it's causal, for the last several blocks.
1093
+ # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True
1094
+ O_should_accumulate = False
1095
+
1096
+ # ==========================================
1097
+ # MAINLOOP
1098
+ # ==========================================
1099
+ if const_expr(not self.use_block_sparsity):
1100
+ # ==========================================
1101
+ # No block-sparsity (original path)
1102
+ # ==========================================
1103
+ # First iteration with seqlen masking
1104
+ if const_expr(self.intra_wg_overlap):
1105
+ kv_consumer_state = process_first_half_block(
1106
+ n_block=n_block_max - 1,
1107
+ seqlen=seqlen,
1108
+ kv_consumer_state=kv_consumer_state,
1109
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
1110
+ score_mod_fn=score_mod_fn,
1111
+ is_first_block=True,
1112
+ )
1113
+ else:
1114
+ self.warp_scheduler_barrier_sync()
1115
+ kv_consumer_state = mma_one_n_block(
1116
+ kv_consumer_state,
1117
+ n_block=n_block_max - 1,
1118
+ seqlen=seqlen,
1119
+ mma_pv_fn=partial(mma_pv_fn, zero_init=True),
1120
+ is_first_n_block=True,
1121
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
1122
+ )
1123
+ O_should_accumulate = True
1124
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
1125
+ n_block_max -= 1
1126
+ # Next couple of iterations with causal masking
1127
+ if const_expr(self.is_causal or self.is_local):
1128
+ n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
1129
+ seqlen, m_block, n_block_min
1130
+ )
1131
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
1132
+ for n_tile in cutlass.range(
1133
+ n_block_max - n_block_min_causal_local_mask, unroll=1
1134
+ ):
1135
+ kv_consumer_state = mma_one_n_block(
1136
+ kv_consumer_state,
1137
+ n_block=n_block_max - 1 - n_tile,
1138
+ seqlen=seqlen,
1139
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
1140
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
1141
+ )
1142
+ O_should_accumulate = True
1143
+ n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)
1144
+ # The remaining iterations have no masking
1145
+ n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
1146
+ seqlen, m_block, n_block_min
1147
+ )
1148
+ # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
1149
+ for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
1150
+ kv_consumer_state = mma_one_n_block(
1151
+ kv_consumer_state,
1152
+ n_block=n_block_max - 1 - n_tile,
1153
+ seqlen=seqlen,
1154
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
1155
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
1156
+ )
1157
+ O_should_accumulate = True
1158
+ # Separate iterations with local masking on the left
1159
+ if const_expr(self.is_local and block_info.window_size_left is not None):
1160
+ n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
1161
+ for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
1162
+ kv_consumer_state = mma_one_n_block(
1163
+ kv_consumer_state,
1164
+ n_block=n_block_max - 1 - n_tile,
1165
+ seqlen=seqlen,
1166
+ mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
1167
+ mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
1168
+ )
1169
+ O_should_accumulate = True
1170
+ # Release Q pipeline so the producer can load the next tile's Q
1171
+ pipeline_q.consumer_release_w_index(0)
1172
+ # Last "half" iteration
1173
+ if const_expr(self.intra_wg_overlap):
1174
+ kv_consumer_state = process_last_half_block(
1175
+ kv_consumer_state=kv_consumer_state,
1176
+ zero_init=not O_should_accumulate,
1177
+ )
1178
+ O_should_accumulate = True
1179
+ else:
1180
+ self.warp_scheduler_barrier_arrive()
1181
+
1182
+ else:
1183
+ # ==========================================
1184
+ # Block sparsity
1185
+ # ==========================================
1186
+ kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(
1187
+ blocksparse_tensors,
1188
+ batch_idx,
1189
+ head_idx,
1190
+ m_block,
1191
+ seqlen,
1192
+ kv_consumer_state,
1193
+ mma_pv_fn,
1194
+ mma_one_n_block,
1195
+ process_first_half_block,
1196
+ process_last_half_block,
1197
+ mask_fn,
1198
+ score_mod_fn,
1199
+ O_should_accumulate,
1200
+ self.mask_mod,
1201
+ fastdiv_mods,
1202
+ self.intra_wg_overlap,
1203
+ self.warp_scheduler_barrier_sync,
1204
+ self.warp_scheduler_barrier_arrive,
1205
+ self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1206
+ self.q_subtile_factor if self.q_subtile_factor is not None else 1,
1207
+ )
1208
+
1209
+ # Release Q pipeline so the producer can load the next tile's Q
1210
+ pipeline_q.consumer_release_w_index(0)
1211
+
1212
+ # Handle empty case (when no blocks to process)
1213
+ if not processed_any:
1214
+ softmax.reset()
1215
+ acc_O.fill(0.0)
1216
+
1217
+ q_consumer_phase ^= 1
1218
+
1219
+ sink_val = None
1220
+ if const_expr(learnable_sink is not None):
1221
+ if const_expr(not self.pack_gqa):
1222
+ sink_val = Float32(learnable_sink[head_idx])
1223
+ else: # Each thread might have a different sink value due to different q_head
1224
+ sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32)
1225
+ cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
1226
+ tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
1227
+ for r in cutlass.range(cute.size(sink_val), unroll_full=True):
1228
+ row = m_block * self.tile_m + tScS_mn[r][0]
1229
+ q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
1230
+ sink_val[r] = Float32(learnable_sink[q_head_idx])
1231
+
1232
+ # normalize acc_O by row_sum and calculate the lse
1233
+ row_scale = softmax.finalize(sink_val=sink_val)
1234
+ softmax.rescale_O(acc_O, row_scale)
1235
+
1236
+ # ///////////////////////////////////////////////////////////////////////////////
1237
+ # Epilogue
1238
+ # ///////////////////////////////////////////////////////////////////////////////
1239
+ self.epilogue(
1240
+ acc_O,
1241
+ softmax.row_sum,
1242
+ mO,
1243
+ mLSE,
1244
+ sO,
1245
+ seqlen,
1246
+ gmem_tiled_copy_O,
1247
+ tma_atom_O,
1248
+ tiled_mma_pv,
1249
+ tidx,
1250
+ m_block,
1251
+ head_idx,
1252
+ batch_idx,
1253
+ )
1254
+
1255
+ tile_scheduler.advance_to_next_work()
1256
+ work_tile = tile_scheduler.get_current_work()
1257
+
1258
+ @cute.jit
1259
+ def first_half_block_overlap(
1260
+ self,
1261
+ n_block: Int32,
1262
+ mma_qk_fn: Callable,
1263
+ kv_consumer_state,
1264
+ pipeline_k,
1265
+ tOrP: cute.Tensor,
1266
+ smem_copy_params: SimpleNamespace,
1267
+ softmax: Softmax,
1268
+ seqlen: SeqlenInfoQK,
1269
+ scores_scale: Optional[cute.Tensor] = None,
1270
+ acc_O: Optional[cute.Tensor] = None,
1271
+ mask_fn: Callable = None,
1272
+ score_mod_fn: Optional[Callable] = None,
1273
+ is_first_block: bool = False,
1274
+ ):
1275
+ """Processes the first half block when using intra-warpgroup-overlap"""
1276
+
1277
+ pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))
1278
+ acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)
1279
+ pipeline_k.consumer_release(kv_consumer_state)
1280
+
1281
+ # Apply score modification if present
1282
+ if const_expr(score_mod_fn is not None):
1283
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
1284
+
1285
+ # Apply mask; mask_seqlen always True for first block
1286
+ # Caveat: if full block further right than mask block, seqlen masking is redundant;
1287
+ # however, masking is being applied anyway, so essentially no perf hit
1288
+ mask_fn(acc_S, n_block=n_block, mask_seqlen=True)
1289
+
1290
+ row_scale = softmax.online_softmax(acc_S, is_first=is_first_block)
1291
+
1292
+ tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
1293
+ tOrP_cur = (
1294
+ tOrP
1295
+ if const_expr(self.mma_pv_is_rs)
1296
+ else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
1297
+ )
1298
+ tOrP_cur.store(tOrP_acc.load().to(self.dtype))
1299
+
1300
+ if const_expr(not self.mma_pv_is_rs):
1301
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
1302
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
1303
+ # Fence and barrier to make smem store visible to WGMMA
1304
+ cute.arch.fence_view_async_shared()
1305
+ cute.arch.sync_warp()
1306
+
1307
+ # For RescaleOBeforeGemm: initialize acc_O
1308
+ if const_expr(self.rescale_O_before_gemm):
1309
+ acc_O.fill(0.0)
1310
+ scores_scale.store(row_scale.load())
1311
+
1312
+ return kv_consumer_state
1313
+
1314
+ @cute.jit
1315
+ def last_half_block_overlap(
1316
+ self,
1317
+ kv_consumer_state,
1318
+ pipeline_v,
1319
+ mma_pv_fn: Callable,
1320
+ zero_init: bool,
1321
+ scores_scale: Optional[cute.Tensor] = None,
1322
+ softmax: Optional[Softmax] = None,
1323
+ acc_O: Optional[cute.Tensor] = None,
1324
+ ):
1325
+ """Processes the final PV GEMM when using intra-warpgroup-overlap"""
1326
+
1327
+ # For RescaleOBeforeGemm: rescale O before the final PV GEMM
1328
+ if const_expr(self.rescale_O_before_gemm):
1329
+ softmax.rescale_O(acc_O, scores_scale)
1330
+
1331
+ pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))
1332
+ mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)
1333
+ pipeline_v.consumer_release(kv_consumer_state)
1334
+ kv_consumer_state.advance()
1335
+ return kv_consumer_state
1336
+
1337
+ @cute.jit
1338
+ def mma_one_n_block(
1339
+ self,
1340
+ smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,
1341
+ n_block: Int32,
1342
+ mma_qk_fn: Callable,
1343
+ mma_pv_fn: Callable,
1344
+ pipeline_k: pipeline.PipelineAsync,
1345
+ pipeline_v: pipeline.PipelineAsync,
1346
+ acc_O: cute.Tensor,
1347
+ tOrP: cute.Tensor,
1348
+ smem_copy_params: SimpleNamespace,
1349
+ softmax: Softmax,
1350
+ seqlen: SeqlenInfoQK,
1351
+ scores_scale: Optional[cute.Tensor] = None, # not used
1352
+ score_mod_fn: Optional[Callable] = None,
1353
+ mask_fn: Optional[Callable] = None,
1354
+ is_first_n_block: cutlass.Constexpr = False,
1355
+ check_inf: cutlass.Constexpr = True,
1356
+ ):
1357
+ pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
1358
+ # S = Q @ K.T
1359
+ acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
1360
+ self.warp_scheduler_barrier_arrive()
1361
+ warpgroup.wait_group(0)
1362
+ pipeline_k.consumer_release(smem_pipe_read)
1363
+
1364
+ # handle score mods and masking
1365
+ if const_expr(score_mod_fn is not None):
1366
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
1367
+ if const_expr(mask_fn is not None):
1368
+ mask_fn(acc_S=acc_S, n_block=n_block)
1369
+
1370
+ row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
1371
+ # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
1372
+ tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
1373
+ tOrP_cur = (
1374
+ tOrP
1375
+ if const_expr(self.mma_pv_is_rs)
1376
+ else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
1377
+ )
1378
+ # tOrP.store(tOrP_acc.load().to(self.dtype))
1379
+ # the "to(self.dtype)" conversion fails to vectorize for block sizes other
1380
+ # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
1381
+ # 2 elements. So we just call ptx directly.
1382
+ utils.cvt_f16(tOrP_acc, tOrP_cur)
1383
+ if const_expr(not self.mma_pv_is_rs):
1384
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
1385
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
1386
+ softmax.rescale_O(acc_O, row_scale)
1387
+ if const_expr(not self.mma_pv_is_rs):
1388
+ # Fence and barrier to make sure smem store is visible to WGMMA
1389
+ cute.arch.fence_view_async_shared()
1390
+ cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
1391
+ pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
1392
+ self.warp_scheduler_barrier_sync()
1393
+ # O += P @ V
1394
+ mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)
1395
+ pipeline_v.consumer_release(smem_pipe_read)
1396
+ smem_pipe_read.advance()
1397
+ return smem_pipe_read
1398
+
1399
+ @cute.jit
1400
+ def mma_one_n_block_intrawg_overlap(
1401
+ self,
1402
+ smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,
1403
+ n_block: Int32,
1404
+ mma_qk_fn: Callable,
1405
+ mma_pv_fn: Callable,
1406
+ pipeline_k: pipeline.PipelineAsync,
1407
+ pipeline_v: pipeline.PipelineAsync,
1408
+ acc_O: cute.Tensor,
1409
+ tOrP: cute.Tensor,
1410
+ smem_copy_params: SimpleNamespace,
1411
+ softmax: Softmax,
1412
+ seqlen: SeqlenInfoQK,
1413
+ scores_scale: Optional[cute.Tensor] = None,
1414
+ score_mod_fn: Optional[Callable] = None,
1415
+ mask_fn: Optional[Callable] = None,
1416
+ check_inf: cutlass.Constexpr = True,
1417
+ ):
1418
+ smem_pipe_read_v = smem_pipe_read.clone()
1419
+ smem_pipe_read.advance()
1420
+ pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))
1421
+ self.warp_scheduler_barrier_sync()
1422
+ # S = Q @ K.T
1423
+ acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)
1424
+ # RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM
1425
+ if const_expr(self.rescale_O_before_gemm):
1426
+ softmax.rescale_O(acc_O, scores_scale)
1427
+ pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))
1428
+ # O += P @ V
1429
+ mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)
1430
+ self.warp_scheduler_barrier_arrive()
1431
+ warpgroup.wait_group(1)
1432
+ pipeline_k.consumer_release(smem_pipe_read)
1433
+
1434
+ # handle score mods and masking
1435
+ if const_expr(score_mod_fn is not None):
1436
+ score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
1437
+ if const_expr(mask_fn is not None):
1438
+ mask_fn(acc_S=acc_S, n_block=n_block)
1439
+ # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
1440
+
1441
+ row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
1442
+ warpgroup.wait_group(0)
1443
+ pipeline_v.consumer_release(smem_pipe_read_v)
1444
+ tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
1445
+ tOrP_cur = (
1446
+ tOrP
1447
+ if const_expr(self.mma_pv_is_rs)
1448
+ else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)
1449
+ )
1450
+ # tOrP_cur.store(tOrP_acc.load().to(self.dtype))
1451
+ # the "to(self.dtype)" conversion fails to vectorize for block sizes other
1452
+ # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of
1453
+ # 2 elements. So we just call ptx directly.
1454
+ utils.cvt_f16(tOrP_acc, tOrP_cur)
1455
+ if const_expr(not self.mma_pv_is_rs):
1456
+ tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
1457
+ cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
1458
+ if const_expr(not self.rescale_O_before_gemm):
1459
+ softmax.rescale_O(acc_O, row_scale)
1460
+ if const_expr(self.rescale_O_before_gemm):
1461
+ scores_scale.store(row_scale.load())
1462
+ if const_expr(not self.mma_pv_is_rs):
1463
+ # Fence and barrier to make sure smem store is visible to WGMMA
1464
+ cute.arch.fence_view_async_shared()
1465
+ cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
1466
+ return smem_pipe_read
1467
+
1468
+ @cute.jit
1469
+ def mma_init(self):
1470
+ warp_group_idx = utils.canonical_warp_group_idx(sync=False)
1471
+ if const_expr(self.use_scheduler_barrier):
1472
+ if warp_group_idx == 1:
1473
+ cute.arch.barrier_arrive(
1474
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),
1475
+ number_of_threads=2 * self.num_threads_per_warp_group,
1476
+ )
1477
+
1478
+ @cute.jit
1479
+ def apply_score_mod(
1480
+ self,
1481
+ thr_mma_qk,
1482
+ batch_idx,
1483
+ head_idx,
1484
+ m_block,
1485
+ acc_S,
1486
+ n_block,
1487
+ softmax_scale,
1488
+ seqlen,
1489
+ aux_tensors: Optional[list] = None,
1490
+ fastdiv_mods=None,
1491
+ ):
1492
+ # Prepare index tensor
1493
+ cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
1494
+ cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
1495
+ tScS = thr_mma_qk.partition_C(cS)
1496
+
1497
+ apply_score_mod_inner(
1498
+ acc_S,
1499
+ tScS,
1500
+ self.score_mod,
1501
+ batch_idx,
1502
+ head_idx,
1503
+ softmax_scale,
1504
+ self.vec_size,
1505
+ self.qk_acc_dtype,
1506
+ aux_tensors,
1507
+ fastdiv_mods,
1508
+ seqlen_info=seqlen,
1509
+ constant_q_idx=None,
1510
+ qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
1511
+ )
1512
+
1513
+ def warp_scheduler_barrier_sync(self):
1514
+ if const_expr(self.use_scheduler_barrier):
1515
+ cute.arch.barrier(
1516
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)
1517
+ - 1
1518
+ + utils.canonical_warp_group_idx(sync=False),
1519
+ number_of_threads=2 * self.num_threads_per_warp_group,
1520
+ )
1521
+
1522
+ def warp_scheduler_barrier_arrive(self):
1523
+ if const_expr(self.use_scheduler_barrier):
1524
+ assert self.num_wg_mma in [2, 3]
1525
+ cur_wg = utils.canonical_warp_group_idx(sync=False) - 1
1526
+ if const_expr(self.num_wg_mma == 2):
1527
+ next_wg = 1 - cur_wg
1528
+ else:
1529
+ t = cur_wg + 1
1530
+ next_wg = t % self.num_wg_mma
1531
+ cute.arch.barrier_arrive(
1532
+ barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
1533
+ number_of_threads=2 * self.num_threads_per_warp_group,
1534
+ )
build/torch-cuda/interface.py CHANGED
@@ -21,6 +21,7 @@
21
 
22
  import os
23
  import math
 
24
  from functools import lru_cache
25
  from typing import Optional, Tuple, Callable
26
 
@@ -31,6 +32,8 @@ import cuda.bindings.driver as cuda
31
 
32
  import cutlass
33
  import cutlass.cute as cute
 
 
34
  from .cache_utils import get_jit_cache
35
  from .testing import is_fake_mode
36
 
@@ -43,30 +46,201 @@ if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
43
 
44
 
45
  from . import utils
 
46
  from .cute_dsl_utils import (
47
  to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
48
  )
49
- from .flash_fwd import FlashAttentionForwardSm90
 
50
  from .flash_fwd_sm100 import FlashAttentionForwardSm100
 
51
  from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
52
  from .flash_bwd import FlashAttentionBackwardSm80
53
  from .flash_bwd_sm90 import FlashAttentionBackwardSm90
54
  from .flash_bwd_sm100 import FlashAttentionBackwardSm100
 
55
  from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
56
  from .flash_fwd_combine import FlashAttentionForwardCombine
57
 
58
  from .block_sparsity import (
59
  BlockSparseTensorsTorch,
 
60
  to_cute_block_sparse_tensors,
61
  normalize_block_sparse_config,
62
  normalize_block_sparse_config_bwd,
63
  )
64
 
 
 
 
 
 
 
 
 
 
 
65
  @lru_cache(maxsize=None)
66
  def _get_device_arch():
67
- """Cached device arch check."""
 
 
 
 
 
 
 
 
 
 
 
 
68
  major, minor = torch.cuda.get_device_capability()
69
- return major * 10 + minor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def maybe_contiguous(x):
72
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
@@ -76,7 +250,8 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
76
  assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
77
  assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
78
  assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
79
- assert t.is_cuda, f"{name} must be on CUDA"
 
80
 
81
 
82
  torch2cute_dtype_map = {
@@ -96,6 +271,29 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
96
  return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def _flash_attn_fwd(
100
  q: torch.Tensor,
101
  k: torch.Tensor,
@@ -113,11 +311,9 @@ def _flash_attn_fwd(
113
  window_size_left: Optional[int] = None,
114
  window_size_right: Optional[int] = None,
115
  learnable_sink: Optional[torch.Tensor] = None,
116
- # m_block_size: int = 128,
117
- # n_block_size: int = 64,
118
- # num_threads: int = 128,
119
- m_block_size: int = 128,
120
- n_block_size: int = 128,
121
  num_threads: int = 384,
122
  num_splits: int = 1,
123
  pack_gqa: Optional[bool] = None,
@@ -138,7 +334,7 @@ def _flash_attn_fwd(
138
  mask_mod: A callable that takes token position information and selectively masks
139
  block_sparse_tensors: A tuple of tensors used for block sparsity.
140
  return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
141
- Note: the returned LSE currently does not support taking gradient.
142
  out: Optional pre-allocated output tensor. If None, will be allocated internally.
143
  lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
144
  aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
@@ -203,25 +399,27 @@ def _flash_attn_fwd(
203
  assert learnable_sink.shape == (num_head,)
204
  assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
205
 
206
- assert all(
207
- t is None or t.is_cuda
208
- for t in (
209
- q,
210
- k,
211
- v,
212
- cu_seqlens_q,
213
- cu_seqlens_k,
214
- seqused_q,
215
- seqused_k,
216
- page_table,
217
- learnable_sink,
218
- )
219
- ), "inputs must be on CUDA device"
 
 
 
220
  assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
221
- assert head_dim <= 256, "head_dim must be less than or equal to 256"
222
  alignment = 16 // q.element_size()
223
- assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
224
- assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
225
  if softmax_scale is None:
226
  softmax_scale = 1.0 / math.sqrt(head_dim)
227
  if softcap == 0.0:
@@ -253,43 +451,47 @@ def _flash_attn_fwd(
253
  _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
254
 
255
  dtype = torch2cute_dtype_map[q.dtype]
256
- arch = _get_device_arch() if _arch is None else _arch
 
 
 
 
257
 
258
- assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
 
259
 
260
- use_block_sparsity = block_sparse_tensors is not None
261
 
262
- if mask_mod is None:
263
- if causal:
264
- window_size_right = 0
265
- if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
266
- window_size_left = None
267
- window_size_right = None
268
- local = window_size_left is not None or window_size_right is not None
269
- if window_size_left is not None or window_size_right is not None:
270
- if window_size_left is None and window_size_right == 0:
271
- causal, local = True, False
272
- window_size_right = None
 
273
  else:
274
- causal, local = False, True
 
 
 
 
 
275
  else:
276
- causal, local = False, False
277
-
278
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
 
 
279
 
280
- if arch // 10 == 9: # TODO: tune block size according to hdim.
281
- if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
282
- n_block_size = 192
283
-
284
- if arch // 10 in [10, 11]:
285
- if (
286
- pack_gqa
287
- and (128 % qhead_per_kvhead != 0)
288
- ):
289
- pack_gqa = False
290
- # TODO: fix GQA + SplitKV + non-varlen
291
- if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
292
- pack_gqa = False
293
 
294
  if max_seqlen_q is None:
295
  max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
@@ -297,28 +499,50 @@ def _flash_attn_fwd(
297
  max_seqlen_k = seqlen_k
298
  seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
299
  if arch // 10 == 10:
300
- q_stage = 2 if seqlen_q_packgqa > m_block_size else 1
301
  else:
302
  q_stage = 1
303
 
 
 
 
 
 
 
304
  if num_splits < 1:
305
- m_block_size_effective = q_stage * m_block_size
306
- seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size))
307
- num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size
308
- num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
309
- total_mblocks = batch_size * num_head_kv * num_m_blocks
310
- num_splits = num_splits_heuristic(
311
- total_mblocks,
312
- torch.cuda.get_device_properties(device).multi_processor_count,
313
- num_n_blocks,
314
- 128,
315
- )
316
 
317
  is_split_kv = num_splits > 1
318
  if is_split_kv:
319
  out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
320
  lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  # hash score and mask mods for compile cache
323
  score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
324
  mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
@@ -370,14 +594,14 @@ def _flash_attn_fwd(
370
  num_head=num_head,
371
  seqlen_q=seqlen_q,
372
  seqlen_k=seqlen_k,
373
- block_size=(m_block_size, n_block_size),
374
  q_stage=q_stage,
375
  )
376
- if aux_tensors is not None:
377
  aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
378
  else:
379
  aux_tensor_metadata = None
380
-
381
  compile_key = (
382
  dtype,
383
  head_dim,
@@ -398,15 +622,20 @@ def _flash_attn_fwd(
398
  window_size_left is not None,
399
  window_size_right is not None,
400
  learnable_sink is not None,
401
- m_block_size,
402
- n_block_size,
403
  q_stage,
404
  num_threads,
405
  is_split_kv,
406
  pack_gqa,
407
  arch,
408
- page_size not in [None, 128], # paged KV non-TMA
 
409
  q_subtile_factor,
 
 
 
 
410
  )
411
  if compile_key not in _flash_attn_fwd.compile_cache:
412
  (
@@ -445,10 +674,28 @@ def _flash_attn_fwd(
445
  if aux_tensors is not None:
446
  cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]
447
 
448
- if arch // 10 == 9:
449
- assert page_table is None, "paged KV not supported on SM 9.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  assert not is_split_kv, "SplitKV not supported on SM 9.0"
451
- # fa_fwd = FlashAttentionForwardSm80(
452
  fa_fwd = FlashAttentionForwardSm90(
453
  dtype,
454
  head_dim,
@@ -457,33 +704,21 @@ def _flash_attn_fwd(
457
  is_causal=causal,
458
  is_local=local,
459
  pack_gqa=pack_gqa,
460
- tile_m=m_block_size,
461
- tile_n=n_block_size,
462
  # num_stages=1,
463
  num_stages=2,
464
  num_threads=num_threads,
465
  Q_in_regs=False,
466
- intra_wg_overlap=True,
467
- mma_pv_is_rs=True,
468
  mask_mod=mask_mod,
469
  score_mod=score_mod,
470
  has_aux_tensors=aux_tensors is not None,
471
  q_subtile_factor=q_subtile_factor,
 
472
  )
473
  elif arch // 10 in [10, 11]:
474
- head_dim_padded = int(math.ceil(head_dim / 16) * 16)
475
- head_dim_v_padded = int(math.ceil(head_dim / 16) * 16)
476
- use_2cta_instrs = (
477
- not causal
478
- and not local
479
- and not is_split_kv
480
- and cu_seqlens_q is None
481
- and seqused_q is None
482
- and not use_block_sparsity
483
- and page_size in [None, 128]
484
- and head_dim_padded == 128
485
- and head_dim_v_padded == 128
486
- )
487
  fa_fwd = FlashAttentionForwardSm100(
488
  head_dim,
489
  head_dim_v,
@@ -492,8 +727,8 @@ def _flash_attn_fwd(
492
  is_local=local,
493
  is_split_kv=is_split_kv,
494
  pack_gqa=pack_gqa,
495
- m_block_size=m_block_size,
496
- n_block_size=n_block_size,
497
  q_stage=q_stage,
498
  is_persistent=not causal
499
  and not local
@@ -503,14 +738,37 @@ def _flash_attn_fwd(
503
  score_mod=score_mod,
504
  mask_mod=mask_mod,
505
  has_aux_tensors=aux_tensors is not None,
506
- paged_kv_non_tma=page_size not in [None, 128],
507
  is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
508
  q_subtile_factor=q_subtile_factor,
509
  use_2cta_instrs=use_2cta_instrs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  )
511
  else:
512
  raise ValueError(
513
- f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x"
514
  )
515
  # TODO: check @can_implement
516
  _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
@@ -521,7 +779,6 @@ def _flash_attn_fwd(
521
  o_tensor,
522
  lse_tensor,
523
  softmax_scale,
524
- current_stream,
525
  cu_seqlens_q_tensor,
526
  cu_seqlens_k_tensor,
527
  seqused_q_tensor,
@@ -532,6 +789,7 @@ def _flash_attn_fwd(
532
  learnable_sink_tensor,
533
  sparse_tensors,
534
  cute_aux_tensors,
 
535
  options="--enable-tvm-ffi",
536
  )
537
 
@@ -547,7 +805,6 @@ def _flash_attn_fwd(
547
  out.detach() if not is_split_kv else out_partial,
548
  lse_partial if is_split_kv else lse,
549
  softmax_scale,
550
- current_stream,
551
  cu_seqlens_q,
552
  cu_seqlens_k,
553
  seqused_q,
@@ -574,6 +831,140 @@ def _flash_attn_fwd(
574
  _flash_attn_fwd.compile_cache = get_jit_cache("fwd")
575
 
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  def _flash_attn_bwd(
578
  q: torch.Tensor,
579
  k: torch.Tensor,
@@ -614,47 +1005,74 @@ def _flash_attn_bwd(
614
  mask_mod: Optional[Callable] = None,
615
  aux_tensors: Optional[list[torch.Tensor]] = None,
616
  block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
 
617
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
618
  arch = _get_device_arch()
619
- assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
 
 
 
620
 
621
  num_head, head_dim = q.shape[-2:]
 
622
 
623
- if causal:
624
- window_size_right = 0
625
- if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
626
- window_size_left = None
627
- window_size_right = None
628
- local = window_size_left is not None or window_size_right is not None
629
- if local:
630
- if window_size_left is None and window_size_right == 0:
631
- causal, local = True, False
632
- window_size_right = None
633
- else:
634
- causal, local = False, True
635
 
636
- if arch // 10 == 9:
637
- m_block_size = 80 if not causal else 64
638
- n_block_size = 128
639
- num_stages_Q = 2
640
- num_stages_dO = 2
641
- num_stages_PdS = 2
642
- SdP_swapAB = True
 
 
 
 
643
  dKV_swapAB = False
644
- dQ_swapAB = not causal
645
- AtomLayoutMSdP = 1
646
- AtomLayoutNdKV = 2
647
- AtomLayoutMdQ = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648
  cluster_size = 1
649
  use_2cta_instrs = False
650
- assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
651
  is_varlen = (
652
  cu_seqlens_q is not None
653
  or cu_seqlens_k is not None
654
  or seqused_q is not None
655
  or seqused_k is not None
656
  )
657
- assert not is_varlen, "varlen backward is not yet supported on sm90"
658
  else:
659
  m_block_size = 128
660
  n_block_size = 128
@@ -662,15 +1080,17 @@ def _flash_attn_bwd(
662
  dKV_swapAB = False
663
  AtomLayoutMdQ = 1
664
  AtomLayoutNdKV = 1
 
665
  disable_2cta = (
666
- local
667
  or score_mod is not None
668
  or score_mod_bwd is not None
669
  or mask_mod is not None
 
670
  )
671
  cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
672
  use_2cta_instrs = cluster_size==2
673
-
674
  q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
675
  maybe_contiguous(t)
676
  for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -692,19 +1112,9 @@ def _flash_attn_bwd(
692
  seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
693
 
694
  num_head_kv = k.shape[-2]
695
- head_dim_v = v.shape[-1]
696
 
697
  use_block_sparsity = block_sparse_tensors is not None
698
-
699
- # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,
700
- # the base block_m of 128 from forward, and block-sparse size for subtiling.
701
- if arch // 10 == 9 and use_block_sparsity:
702
- m_block_size = 64
703
- # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
704
- dQ_swapAB = False
705
-
706
- # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
707
- subtile_factor = 2
708
  seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
709
  seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
710
  num_n_blocks = seqlen_k_rounded // n_block_size
@@ -744,14 +1154,16 @@ def _flash_attn_bwd(
744
  if t is not None:
745
  assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
746
  assert lse.dtype == torch.float32, "lse must be float32"
747
- assert all(
748
- t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
749
- ), "inputs must be on CUDA device"
 
 
 
750
  assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
751
- assert head_dim <= 256, "head_dim must be less than or equal to 256"
752
  alignment = 16 // q.element_size()
753
- assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
754
- assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
755
  if softmax_scale is None:
756
  softmax_scale = 1.0 / math.sqrt(head_dim)
757
  qhead_per_kvhead = num_head // num_head_kv
@@ -759,9 +1171,6 @@ def _flash_attn_bwd(
759
  pack_gqa = qhead_per_kvhead > 1
760
  # pack_gqa backward not yet supported in bwd
761
  pack_gqa = False
762
- if arch // 10 not in [10, 11]:
763
- assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now"
764
-
765
  if score_mod is not None:
766
  assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
767
  assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
@@ -813,6 +1222,9 @@ def _flash_attn_bwd(
813
  dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
814
  lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
815
 
 
 
 
816
  dKV_postprocess = qhead_per_kvhead > 1
817
  if dKV_postprocess:
818
  head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
@@ -850,83 +1262,30 @@ def _flash_attn_bwd(
850
  )
851
 
852
  dtype = torch2cute_dtype_map[q.dtype]
853
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
854
 
855
  if deterministic:
856
- dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device="cuda")
857
  else:
858
  dQ_semaphore = None
859
 
860
  if deterministic and qhead_per_kvhead > 1:
861
- dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
862
- dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda")
863
  else:
864
  dK_semaphore = None
865
  dV_semaphore = None
866
 
867
- # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
868
- compile_key_pre = (
869
- arch,
870
- dtype,
871
- head_dim,
872
- head_dim_v,
873
- m_block_size,
874
- num_threads,
875
- cu_seqlens_q is None,
876
- seqused_q is None,
877
- get_broadcast_dims(out),
878
- get_broadcast_dims(dout),
879
  )
880
- if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
881
- o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
882
- dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
883
- to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
884
- ]
885
- lse_tensor = to_cute_tensor(lse, assumed_align=4)
886
- cu_seqlens_q_tensor, seqused_q_tensor = [
887
- to_cute_tensor(t, assumed_align=4) if t is not None else None
888
- for t in (cu_seqlens_q, seqused_q)
889
- ]
890
- fa_bwd_pre = FlashAttentionBackwardPreprocess(
891
- dtype,
892
- head_dim,
893
- head_dim_v,
894
- arch,
895
- m_block_size,
896
- num_threads=num_threads,
897
- )
898
- # TODO: check @can_implement
899
- _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile(
900
- fa_bwd_pre,
901
- o_tensor,
902
- do_tensor,
903
- dpsum_tensor,
904
- lse_tensor,
905
- lse_log2_tensor,
906
- dq_accum_tensor,
907
- cu_seqlens_q_tensor,
908
- seqused_q_tensor,
909
- current_stream,
910
- options="--enable-tvm-ffi",
911
- )
912
- if not is_fake_mode():
913
- _flash_attn_bwd.compile_cache_pre[compile_key_pre](
914
- out,
915
- dout,
916
- dpsum,
917
- lse,
918
- lse_log2,
919
- dq_accum,
920
- cu_seqlens_q,
921
- seqused_q,
922
- current_stream,
923
- )
924
-
925
- # NB num_threads application for 3 kernels
926
- # There are pre, main, post processing kernels, currenlty num_threads is only actually
927
- # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do
928
- # before cache key gen
929
- num_threads = 384
930
 
931
  # Backward kernel: compute dk, dv, dq_accum.
932
  score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
@@ -953,7 +1312,7 @@ def _flash_attn_bwd(
953
  subtile_factor=subtile_factor,
954
  )
955
 
956
- if arch // 10 == 9:
957
  compile_key = (
958
  arch,
959
  dtype,
@@ -961,6 +1320,8 @@ def _flash_attn_bwd(
961
  head_dim_v,
962
  qhead_per_kvhead,
963
  causal,
 
 
964
  softcap != 0.0,
965
  m_block_size,
966
  n_block_size,
@@ -975,6 +1336,8 @@ def _flash_attn_bwd(
975
  AtomLayoutNdKV,
976
  AtomLayoutMdQ,
977
  V_in_regs,
 
 
978
  cu_seqlens_q is None,
979
  cu_seqlens_k is None,
980
  seqused_q is None,
@@ -1043,51 +1406,56 @@ def _flash_attn_bwd(
1043
  if t is not None else None
1044
  for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
1045
  ]
1046
- fa_bwd_sm80 = FlashAttentionBackwardSm80(
1047
- dtype,
1048
- head_dim,
1049
- head_dim_v,
1050
- qhead_per_kvhead,
1051
- m_block_size,
1052
- n_block_size,
1053
- num_stages_Q,
1054
- num_stages_dO,
1055
- num_threads,
1056
- pack_gqa,
1057
- causal,
1058
- SdP_swapAB,
1059
- dKV_swapAB,
1060
- dQ_swapAB,
1061
- AtomLayoutMSdP,
1062
- AtomLayoutNdKV,
1063
- AtomLayoutMdQ,
1064
- V_in_regs=V_in_regs,
1065
- )
1066
- if arch // 10 == 9:
1067
- fa_bwd_obj = FlashAttentionBackwardSm90(
1068
  dtype,
1069
  head_dim,
1070
  head_dim_v,
1071
  qhead_per_kvhead,
1072
- causal,
1073
  m_block_size,
1074
  n_block_size,
1075
  num_stages_Q,
1076
  num_stages_dO,
1077
- num_stages_PdS,
 
 
1078
  SdP_swapAB,
1079
  dKV_swapAB,
1080
  dQ_swapAB,
1081
  AtomLayoutMSdP,
1082
  AtomLayoutNdKV,
1083
  AtomLayoutMdQ,
1084
- num_threads,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1085
  V_in_regs=V_in_regs,
1086
  score_mod=score_mod,
1087
  score_mod_bwd=score_mod_bwd,
1088
  mask_mod=mask_mod,
1089
  has_aux_tensors=aux_tensors is not None,
1090
  subtile_factor=subtile_factor,
 
1091
  )
1092
  else:
1093
  fa_bwd_obj = FlashAttentionBackwardSm100(
@@ -1126,7 +1494,6 @@ def _flash_attn_bwd(
1126
  dk_tensor if not dKV_postprocess else dk_accum_tensor,
1127
  dv_tensor if not dKV_postprocess else dv_accum_tensor,
1128
  softmax_scale,
1129
- current_stream,
1130
  cu_seqlens_q_tensor,
1131
  cu_seqlens_k_tensor,
1132
  seqused_q_tensor,
@@ -1139,6 +1506,7 @@ def _flash_attn_bwd(
1139
  dV_semaphore_tensor,
1140
  cute_aux_tensors,
1141
  sparse_tensors_compile,
 
1142
  options="--enable-tvm-ffi",
1143
  )
1144
  if not is_fake_mode():
@@ -1153,7 +1521,6 @@ def _flash_attn_bwd(
1153
  dk if not dKV_postprocess else dk_accum,
1154
  dv if not dKV_postprocess else dv_accum,
1155
  softmax_scale,
1156
- current_stream,
1157
  cu_seqlens_q,
1158
  cu_seqlens_k,
1159
  seqused_q,
@@ -1168,157 +1535,45 @@ def _flash_attn_bwd(
1168
  normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
1169
  )
1170
 
1171
- num_threads = 256 if arch // 10 == 9 else 128
1172
- # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
1173
- compile_key_post = (
1174
- arch,
1175
- dtype,
1176
- head_dim,
1177
- m_block_size,
1178
- num_threads,
1179
- AtomLayoutMdQ,
1180
- dQ_swapAB,
1181
- cu_seqlens_q is None,
1182
- seqused_q is None,
1183
- use_2cta_instrs,
1184
- 1, # no cluster for tile_m
1185
- get_broadcast_dims(dq_accum),
1186
- get_broadcast_dims(dq),
1187
  )
1188
- if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1189
- dq_accum_tensor = to_cute_tensor(dq_accum)
1190
- dq_tensor = to_cute_tensor(dq)
1191
- cu_seqlens_q_tensor, seqused_q_tensor = [
1192
- to_cute_tensor(t, assumed_align=4) if t is not None else None
1193
- for t in (cu_seqlens_q, seqused_q)
1194
- ]
1195
- fa_bwd_post = FlashAttentionBackwardPostprocess(
1196
- dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB,
1197
- use_2cta_instrs=use_2cta_instrs,
1198
- )
1199
- # TODO: check @can_implement
1200
- _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1201
- fa_bwd_post,
1202
- dq_accum_tensor,
1203
- dq_tensor,
1204
- softmax_scale,
1205
- cu_seqlens_q_tensor,
1206
- seqused_q_tensor,
1207
- current_stream,
1208
- options="--enable-tvm-ffi",
1209
- )
1210
-
1211
- if not is_fake_mode():
1212
- _flash_attn_bwd.compile_cache_post[compile_key_post](
1213
- dq_accum,
1214
- dq,
1215
- softmax_scale,
1216
- cu_seqlens_q,
1217
- seqused_q,
1218
- current_stream,
1219
- )
1220
 
1221
  if dKV_postprocess:
1222
- # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
1223
- compile_key_post = (
1224
- arch,
1225
- dtype,
1226
- head_dim,
1227
- n_block_size,
1228
- num_threads,
1229
- AtomLayoutNdKV,
1230
- dKV_swapAB,
1231
- cu_seqlens_k is None,
1232
- seqused_k is None,
1233
- False, # even for 2cta, is split along hdim, so always False
1234
- cluster_size, # cluster is for tile_n
1235
- get_broadcast_dims(dk_accum),
1236
- get_broadcast_dims(dk),
1237
  )
1238
- if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1239
- dk_accum_tensor = to_cute_tensor(dk_accum)
1240
- dk_tensor = to_cute_tensor(dk)
1241
- cu_seqlens_k_tensor, seqused_k_tensor = [
1242
- to_cute_tensor(t, assumed_align=4) if t is not None else None
1243
- for t in (cu_seqlens_k, seqused_k)
1244
- ]
1245
- fa_bwd_post = FlashAttentionBackwardPostprocess(
1246
- dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
1247
- cluster_size=cluster_size,
1248
- )
1249
- # TODO: check @can_implement
1250
- _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1251
- fa_bwd_post,
1252
- dk_accum_tensor,
1253
- dk_tensor,
1254
- softmax_scale,
1255
- cu_seqlens_k_tensor,
1256
- seqused_k_tensor,
1257
- current_stream,
1258
- options="--enable-tvm-ffi",
1259
- )
1260
- if not is_fake_mode():
1261
- _flash_attn_bwd.compile_cache_post[compile_key_post](
1262
- dk_accum,
1263
- dk,
1264
- softmax_scale,
1265
- cu_seqlens_k,
1266
- seqused_k,
1267
- current_stream,
1268
- )
1269
- compile_key_post = (
1270
- arch,
1271
- dtype,
1272
- head_dim_v,
1273
- n_block_size,
1274
- num_threads,
1275
- AtomLayoutNdKV,
1276
- dKV_swapAB,
1277
- cu_seqlens_k is None,
1278
- seqused_k is None,
1279
- False,
1280
- cluster_size,
1281
- get_broadcast_dims(dv_accum),
1282
- get_broadcast_dims(dv),
1283
  )
1284
- if compile_key_post not in _flash_attn_bwd.compile_cache_post:
1285
- dv_accum_tensor = to_cute_tensor(dv_accum)
1286
- dv_tensor = to_cute_tensor(dv)
1287
- cu_seqlens_k_tensor, seqused_k_tensor = [
1288
- to_cute_tensor(t, assumed_align=4) if t is not None else None
1289
- for t in (cu_seqlens_k, seqused_k)
1290
- ]
1291
- fa_bwd_post = FlashAttentionBackwardPostprocess(
1292
- dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB,
1293
- cluster_size=cluster_size,
1294
- )
1295
- # TODO: check @can_implement
1296
- _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
1297
- fa_bwd_post,
1298
- dv_accum_tensor,
1299
- dv_tensor,
1300
- cutlass.Float32(1.0),
1301
- cu_seqlens_k_tensor,
1302
- seqused_k_tensor,
1303
- current_stream,
1304
- options="--enable-tvm-ffi",
1305
- )
1306
- if not is_fake_mode():
1307
- _flash_attn_bwd.compile_cache_post[compile_key_post](
1308
- dv_accum,
1309
- dv,
1310
- 1.0,
1311
- cu_seqlens_k,
1312
- seqused_k,
1313
- current_stream,
1314
- )
1315
 
1316
  return dq, dk, dv
1317
 
1318
 
1319
- _flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre")
1320
  _flash_attn_bwd.compile_cache = get_jit_cache("bwd")
1321
- _flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post")
1322
 
1323
 
1324
  class FlashAttnFunc(torch.autograd.Function):
@@ -1376,14 +1631,17 @@ class FlashAttnFunc(torch.autograd.Function):
1376
  ctx.window_size = window_size
1377
  ctx.softcap = softcap
1378
  ctx.deterministic = deterministic
1379
- # LSE gradient is not supported yet
1380
- if lse is not None:
1381
- ctx.mark_non_differentiable(lse)
1382
  return out, lse
1383
 
1384
  @staticmethod
1385
- def backward(ctx, dout, *args):
1386
  q, k, v, out, lse = ctx.saved_tensors
 
 
 
 
1387
  dq, dk, dv = _flash_attn_bwd(
1388
  q,
1389
  k,
@@ -1397,6 +1655,7 @@ class FlashAttnFunc(torch.autograd.Function):
1397
  window_size_left=ctx.window_size[0],
1398
  window_size_right=ctx.window_size[1],
1399
  deterministic=ctx.deterministic,
 
1400
  )
1401
  return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
1402
 
@@ -1458,15 +1717,18 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
1458
  ctx.deterministic = deterministic
1459
  ctx.max_seqlen_q = max_seqlen_q
1460
  ctx.max_seqlen_k = max_seqlen_k
1461
- # LSE gradient is not supported yet
1462
- if lse is not None:
1463
- ctx.mark_non_differentiable(lse)
1464
  return out, lse
1465
 
1466
  @staticmethod
1467
- def backward(ctx, dout, *args):
1468
  q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
1469
  assert ctx.softcap == 0.0
 
 
 
 
1470
  dq, dk, dv = _flash_attn_bwd(
1471
  q,
1472
  k,
@@ -1486,6 +1748,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
1486
  max_seqlen_q=ctx.max_seqlen_q,
1487
  max_seqlen_k=ctx.max_seqlen_k,
1488
  deterministic=ctx.deterministic,
 
1489
  )
1490
 
1491
  return dq, dk, dv, *((None,) * 20)
@@ -1581,6 +1844,63 @@ def flash_attn_varlen_func(
1581
  )
1582
 
1583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1584
  def _flash_attn_fwd_combine(
1585
  out_partial: torch.Tensor,
1586
  lse_partial: torch.Tensor,
@@ -1589,6 +1909,7 @@ def _flash_attn_fwd_combine(
1589
  cu_seqlens: Optional[torch.Tensor] = None,
1590
  seqused: Optional[torch.Tensor] = None,
1591
  num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
 
1592
  semaphore_to_reset: Optional[torch.Tensor] = None,
1593
  ) -> None:
1594
  """Forward combine kernel for split attention computation.
@@ -1612,27 +1933,13 @@ def _flash_attn_fwd_combine(
1612
  Returns:
1613
  None
1614
  """
1615
- # Input validation
1616
- assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
1617
- assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
1618
  assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
1619
  "out_partial must be fp16, bf16, or fp32"
1620
  )
1621
- assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
1622
- assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
1623
- assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
1624
- assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
1625
- assert lse_partial.shape == out_partial.shape[:-1]
1626
-
1627
  # Determine if this is variable length based on dimensions
1628
  is_varlen = out_partial.dim() == 4
1629
-
1630
- # Validate output tensor shapes and types
1631
- assert out.shape == out_partial.shape[1:], "out shape mismatch"
1632
- if lse is not None:
1633
- assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
1634
- assert lse.dtype == torch.float32, "lse must be fp32"
1635
-
1636
  # Validate optional tensors
1637
  for t, name in [
1638
  (cu_seqlens, "cu_seqlens"),
@@ -1640,10 +1947,9 @@ def _flash_attn_fwd_combine(
1640
  (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
1641
  ]:
1642
  if t is not None:
1643
- assert t.dtype == torch.int32, f"{name} must be int32"
1644
- assert t.is_cuda, f"{name} must be on CUDA device"
1645
  assert t.is_contiguous(), f"{name} must be contiguous"
1646
-
1647
  head_dim = out_partial.shape[-1]
1648
  num_splits = out_partial.shape[0]
1649
  assert num_splits <= 256
@@ -1652,101 +1958,37 @@ def _flash_attn_fwd_combine(
1652
  k_block_size = 64 if head_dim <= 64 else 128
1653
  # We want kBlockM to be as small as possible to maximize parallelism.
1654
  # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
1655
- m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
1656
  log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
1657
- if m_block_size == 8:
1658
  # If kBlockM == 8 then the minimum number of splits is 32.
1659
  # TODO: we can deal w this by using 128 threads instead
1660
  log_max_splits = max(log_max_splits, 5)
1661
 
1662
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1663
-
1664
  # Create combine kernel configuration
1665
  dtype = torch2cute_dtype_map[out.dtype]
1666
  dtype_partial = torch2cute_dtype_map[out_partial.dtype]
1667
-
1668
  compile_key = (
1669
  dtype,
1670
  dtype_partial,
1671
  head_dim,
1672
- m_block_size,
1673
  k_block_size,
1674
  log_max_splits,
1675
  cu_seqlens is not None,
1676
  seqused is not None,
1677
  lse is not None,
 
1678
  )
1679
-
1680
  if compile_key not in _flash_attn_fwd_combine.compile_cache:
1681
- out_partial_tensor = to_cute_tensor(
1682
- out_partial, leading_dim=4 if not is_varlen else 3
1683
- )
1684
- lse_partial_tensor = to_cute_tensor(
1685
- lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
1686
- )
1687
- out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
1688
- lse_tensor = (
1689
- to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
1690
- if lse is not None
1691
- else None
1692
- )
1693
-
1694
- optional_tensors = [
1695
- to_cute_tensor(t, assumed_align=4, leading_dim=0)
1696
- if t is not None
1697
- else None
1698
- for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
1699
- ]
1700
- cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
1701
- optional_tensors
1702
- )
1703
- fa_combine = FlashAttentionForwardCombine(
1704
- dtype=dtype,
1705
- dtype_partial=dtype_partial,
1706
- head_dim=head_dim,
1707
- m_block_size=m_block_size,
1708
- k_block_size=k_block_size,
1709
- log_max_splits=log_max_splits,
1710
- )
1711
-
1712
- # Check if implementation is supported
1713
- if not fa_combine.can_implement(
1714
- dtype,
1715
- dtype_partial,
1716
- head_dim,
1717
- m_block_size,
1718
- k_block_size,
1719
- log_max_splits,
1720
- num_threads=256,
1721
- ):
1722
- raise RuntimeError(
1723
- "FlashAttention combine kernel cannot be implemented with given parameters"
1724
- )
1725
-
1726
- _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile(
1727
- fa_combine,
1728
- out_partial_tensor,
1729
- lse_partial_tensor,
1730
- out_tensor,
1731
- lse_tensor,
1732
- cu_seqlens_tensor,
1733
- seqused_tensor,
1734
- num_splits_dynamic_tensor,
1735
- semaphore_tensor,
1736
- current_stream,
1737
- options="--enable-tvm-ffi",
1738
  )
1739
  if not is_fake_mode():
1740
  _flash_attn_fwd_combine.compile_cache[compile_key](
1741
- out_partial,
1742
- lse_partial,
1743
- out,
1744
- lse,
1745
- cu_seqlens,
1746
- seqused,
1747
- num_splits_dynamic_ptr,
1748
  semaphore_to_reset,
1749
- current_stream,
1750
  )
1751
 
1752
 
@@ -1760,6 +2002,7 @@ def flash_attn_combine(
1760
  out_dtype: Optional[torch.dtype] = None,
1761
  cu_seqlens: Optional[torch.Tensor] = None,
1762
  seqused: Optional[torch.Tensor] = None,
 
1763
  return_lse: bool = True,
1764
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1765
  """Flash Attention combine function for split attention computation.
@@ -1779,6 +2022,9 @@ def flash_attn_combine(
1779
  out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input.
1780
  cu_seqlens: Cumulative sequence lengths for variable length sequences
1781
  seqused: Used sequence lengths for each batch
 
 
 
1782
  return_lse: Whether to return the combined LSE tensor. Default is True.
1783
 
1784
  Returns:
@@ -1795,32 +2041,19 @@ def flash_attn_combine(
1795
  """
1796
  # Input validation
1797
  assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
1798
- assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
1799
- assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)"
1800
- assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
1801
-
1802
  # Determine if this is variable length based on dimensions
1803
  is_varlen = out_partial.dim() == 4
1804
-
1805
  if is_varlen:
1806
  # Variable length: (num_splits, total_q, num_heads, head_size)
1807
  num_splits, total_q, num_heads, head_size = out_partial.shape
1808
- assert lse_partial.shape == (num_splits, total_q, num_heads), (
1809
- "lse_partial shape mismatch for varlen"
1810
- )
1811
  batch_size = 1 # Treat as single batch for varlen
1812
  seqlen = total_q
1813
  else:
1814
  # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
1815
  num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
1816
- assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), (
1817
- "lse_partial shape mismatch"
1818
- )
1819
-
1820
  # Determine output dtype
1821
  if out_dtype is None:
1822
  out_dtype = out_partial.dtype
1823
-
1824
  # Create output if not provided
1825
  device = out_partial.device
1826
  if out is None:
@@ -1830,20 +2063,15 @@ def flash_attn_combine(
1830
  out = torch.empty(
1831
  batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
1832
  )
1833
-
1834
  # Create lse output only if requested
1835
  if return_lse:
1836
  if is_varlen:
1837
- lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(
1838
- 0, 1
1839
- )
1840
  else:
1841
- lse = torch.empty(
1842
- batch_size, num_heads, seqlen, dtype=torch.float32, device=device
1843
- ).transpose(1, 2)
1844
  else:
1845
  lse = None
1846
-
1847
  _flash_attn_fwd_combine(
1848
  out_partial,
1849
  lse_partial,
@@ -1851,5 +2079,6 @@ def flash_attn_combine(
1851
  lse,
1852
  cu_seqlens,
1853
  seqused,
 
1854
  )
1855
  return out, lse
 
21
 
22
  import os
23
  import math
24
+ from dataclasses import dataclass
25
  from functools import lru_cache
26
  from typing import Optional, Tuple, Callable
27
 
 
32
 
33
  import cutlass
34
  import cutlass.cute as cute
35
+ from cutlass import Int32, Float32
36
+ from .quack.compile_utils import make_fake_tensor as fake_tensor
37
  from .cache_utils import get_jit_cache
38
  from .testing import is_fake_mode
39
 
 
46
 
47
 
48
  from . import utils
49
+ from . import fa_logging
50
  from .cute_dsl_utils import (
51
  to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,
52
  )
53
+ from .flash_fwd import FlashAttentionForwardSm80
54
+ from .flash_fwd_sm90 import FlashAttentionForwardSm90
55
  from .flash_fwd_sm100 import FlashAttentionForwardSm100
56
+ from .flash_fwd_sm120 import FlashAttentionForwardSm120
57
  from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess
58
  from .flash_bwd import FlashAttentionBackwardSm80
59
  from .flash_bwd_sm90 import FlashAttentionBackwardSm90
60
  from .flash_bwd_sm100 import FlashAttentionBackwardSm100
61
+ from .flash_bwd_sm120 import FlashAttentionBackwardSm120
62
  from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess
63
  from .flash_fwd_combine import FlashAttentionForwardCombine
64
 
65
  from .block_sparsity import (
66
  BlockSparseTensorsTorch,
67
+ get_sparse_q_block_size,
68
  to_cute_block_sparse_tensors,
69
  normalize_block_sparse_config,
70
  normalize_block_sparse_config_bwd,
71
  )
72
 
73
+ def _parse_arch_str(arch_str):
74
+ """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100)."""
75
+ import re
76
+ match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str)
77
+ if not match:
78
+ raise ValueError(f"Invalid arch format: {arch_str}")
79
+ major, minor, _ = match.groups()
80
+ return int(major) * 10 + int(minor)
81
+
82
+
83
  @lru_cache(maxsize=None)
84
  def _get_device_arch():
85
+ """Cached device arch check.
86
+
87
+ Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which
88
+ kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation
89
+ target (CUTE_DSL_ARCH).
90
+
91
+ For CPU-only compilation (no GPU), set both:
92
+ FLASH_ATTENTION_ARCH=sm_80 (kernel selection)
93
+ CUTE_DSL_ARCH=sm_80 (compilation target)
94
+ """
95
+ arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None)
96
+ if arch_override is not None:
97
+ return _parse_arch_str(arch_override)
98
  major, minor = torch.cuda.get_device_capability()
99
+ return major * 10 + int(minor)
100
+
101
+
102
+ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:
103
+ """Validate head dimension constraints based on compute capability."""
104
+ is_deepseek_shape = head_dim == 192 and head_dim_v == 128
105
+ is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
106
+
107
+ is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256
108
+ if compute_capability == 9:
109
+ assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
110
+ f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
111
+ f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}."
112
+ )
113
+ elif compute_capability in [10, 11]:
114
+ assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
115
+ f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
116
+ f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek."
117
+ )
118
+
119
+
120
+ @dataclass(frozen=True)
121
+ class FwdConfig:
122
+ m_block_size: int
123
+ n_block_size: int
124
+ mma_pv_is_rs: bool
125
+ intra_wg_overlap: bool
126
+
127
+
128
+ def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None):
129
+ """Return FwdConfig for SM90 forward.
130
+
131
+ Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted
132
+ for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM).
133
+
134
+ When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the
135
+ optimal tile_m=192 is used when compatible, otherwise we fall back to 128.
136
+ """
137
+ if head_dim <= 64:
138
+ # C++: 192×192 non-causal, 192×128 causal/local.
139
+ # Python: 192×128 RS+OL is consistently best across seqlens.
140
+ if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0:
141
+ return FwdConfig(128, 128, True, True)
142
+ return FwdConfig(192, 128, True, True)
143
+ elif head_dim <= 96:
144
+ # C++: 192×144 noRS+OL for all cases.
145
+ # Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS).
146
+ # noRS+OL is always required. Causal: 192×128 slightly better short seqlen.
147
+ if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0:
148
+ return FwdConfig(128, 128, False, True)
149
+ if is_causal or is_local:
150
+ return FwdConfig(192, 128, False, True)
151
+ else:
152
+ return FwdConfig(192, 144, False, True)
153
+ elif head_dim <= 128:
154
+ return FwdConfig(128, 128, True, True)
155
+ elif head_dim <= 192:
156
+ tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112)
157
+ return FwdConfig(128, tile_n, True, True)
158
+ else: # hdim 256
159
+ tile_n = 64 if is_local else 80
160
+ return FwdConfig(128, tile_n, True, True)
161
+
162
+ @dataclass(frozen=True)
163
+ class BwdConfig:
164
+ m_block_size: int
165
+ n_block_size: int
166
+ num_stages_Q: int
167
+ num_stages_dO: int
168
+ num_stages_PdS: int
169
+ SdP_swapAB: bool
170
+ dKV_swapAB: bool
171
+ dQ_swapAB: bool
172
+ AtomLayoutMSdP: int
173
+ AtomLayoutNdKV: int
174
+ AtomLayoutMdQ: int
175
+ num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128)
176
+ dQ_single_wg: bool = False
177
+
178
+
179
+ def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None):
180
+ """Return BwdConfig for SM90.
181
+
182
+ Configs based on C++ FA3 hopper/flash_bwd_launch_template.h,
183
+ benchmarked on H100 SXM.
184
+ """
185
+ if head_dim <= 64:
186
+ # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2
187
+ return BwdConfig(
188
+ m_block_size=128, n_block_size=128,
189
+ num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
190
+ SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,
191
+ AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2,
192
+ )
193
+ elif head_dim <= 96:
194
+ # C++ FA3: 64, 128, 96, dQ_swapAB=False
195
+ return BwdConfig(
196
+ m_block_size=64, n_block_size=128,
197
+ num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
198
+ SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,
199
+ AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
200
+ dQ_single_wg=True,
201
+ )
202
+ elif head_dim <= 128:
203
+ # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB
204
+ is_causal_or_local = causal or local
205
+ m_block_size = 64 if is_causal_or_local else 80
206
+ if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0:
207
+ m_block_size = 64
208
+ return BwdConfig(
209
+ m_block_size=m_block_size,
210
+ n_block_size=128,
211
+ num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,
212
+ SdP_swapAB=True, dKV_swapAB=False,
213
+ dQ_swapAB=m_block_size % 64 != 0,
214
+ AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
215
+ )
216
+ elif head_dim <= 192:
217
+ hdimv128 = head_dim_v <= 128
218
+ if hdimv128:
219
+ return BwdConfig(
220
+ m_block_size=64, n_block_size=96,
221
+ num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1,
222
+ SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,
223
+ AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
224
+ num_wg=2,
225
+ )
226
+ else:
227
+ return BwdConfig(
228
+ m_block_size=64, n_block_size=96,
229
+ num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1,
230
+ SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,
231
+ AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
232
+ num_wg=2,
233
+ )
234
+ else:
235
+ # hdim 256
236
+ return BwdConfig(
237
+ m_block_size=64, n_block_size=64,
238
+ num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1,
239
+ SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False,
240
+ AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1,
241
+ )
242
+
243
+
244
 
245
  def maybe_contiguous(x):
246
  return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
250
  assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}"
251
  assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}"
252
  assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
253
+ if not is_fake_mode():
254
+ assert t.is_cuda, f"{name} must be on CUDA"
255
 
256
 
257
  torch2cute_dtype_map = {
 
271
  return min(num_SMs // total_mblocks, max_splits, num_n_blocks)
272
 
273
 
274
+ def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None):
275
+ """Resolve causal/local/window settings into canonical form.
276
+
277
+ Returns (causal, local, window_size_left, window_size_right).
278
+ """
279
+ if mask_mod is not None:
280
+ return False, False, window_size_left, window_size_right
281
+ if causal:
282
+ window_size_right = 0
283
+ if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:
284
+ window_size_left = None
285
+ window_size_right = None
286
+ if window_size_left is not None or window_size_right is not None:
287
+ if window_size_left is None and window_size_right == 0:
288
+ causal, local = True, False
289
+ window_size_right = None
290
+ else:
291
+ causal, local = False, True
292
+ else:
293
+ local = False
294
+ return causal, local, window_size_left, window_size_right
295
+
296
+
297
  def _flash_attn_fwd(
298
  q: torch.Tensor,
299
  k: torch.Tensor,
 
311
  window_size_left: Optional[int] = None,
312
  window_size_right: Optional[int] = None,
313
  learnable_sink: Optional[torch.Tensor] = None,
314
+ tile_mn: Optional[Tuple[int, int]] = None,
315
+ mma_pv_is_rs: Optional[bool] = None,
316
+ intra_wg_overlap: Optional[bool] = None,
 
 
317
  num_threads: int = 384,
318
  num_splits: int = 1,
319
  pack_gqa: Optional[bool] = None,
 
334
  mask_mod: A callable that takes token position information and selectively masks
335
  block_sparse_tensors: A tuple of tensors used for block sparsity.
336
  return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
337
+ The returned LSE supports taking gradient.
338
  out: Optional pre-allocated output tensor. If None, will be allocated internally.
339
  lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
340
  aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
 
399
  assert learnable_sink.shape == (num_head,)
400
  assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
401
 
402
+ if not is_fake_mode():
403
+ assert all(
404
+ t is None or t.is_cuda
405
+ for t in (
406
+ q,
407
+ k,
408
+ v,
409
+ cu_seqlens_q,
410
+ cu_seqlens_k,
411
+ seqused_q,
412
+ seqused_k,
413
+ page_table,
414
+ learnable_sink,
415
+ )
416
+ ), "inputs must be on CUDA device"
417
+ arch = _get_device_arch() if _arch is None else _arch
418
+ assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. Supported: 8.x, 9.x, 10.x, 11.x, 12.x"
419
  assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
 
420
  alignment = 16 // q.element_size()
421
+ if arch // 10 not in [8, 12]:
422
+ _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
423
  if softmax_scale is None:
424
  softmax_scale = 1.0 / math.sqrt(head_dim)
425
  if softcap == 0.0:
 
451
  _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
452
 
453
  dtype = torch2cute_dtype_map[q.dtype]
454
+ use_block_sparsity = block_sparse_tensors is not None
455
+
456
+ causal, local, window_size_left, window_size_right = _resolve_causal_local_window(
457
+ causal, window_size_left, window_size_right, mask_mod
458
+ )
459
 
460
+ requested_use_clc_scheduler = utils._get_use_clc_scheduler_default()
461
+ requested_disable_2cta = utils._get_disable_2cta_default()
462
 
463
+ current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
464
 
465
+ # SM80/SM120: uses SM80 MMA, 128 threads (4 warps)
466
+ if arch // 10 in [8, 12]:
467
+ num_threads = 128
468
+
469
+ fwd_cfg = FwdConfig(128, 128, True, True) # default
470
+ if tile_mn is None:
471
+ if arch // 10 == 12:
472
+ # SM120 tile sizes tuned for 99 KB SMEM capacity:
473
+ # D<=64: 128x128 48 KB (good occupancy)
474
+ # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy)
475
+ if head_dim <= 64:
476
+ fwd_cfg = FwdConfig(128, 128, True, True)
477
  else:
478
+ fwd_cfg = FwdConfig(128, 64, True, True)
479
+ elif arch // 10 == 8:
480
+ fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune
481
+ elif arch // 10 == 9:
482
+ sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q)
483
+ fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q)
484
  else:
485
+ fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap)
486
+ tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size
487
+ if mma_pv_is_rs is None:
488
+ mma_pv_is_rs = fwd_cfg.mma_pv_is_rs
489
+ if intra_wg_overlap is None:
490
+ intra_wg_overlap = fwd_cfg.intra_wg_overlap
491
 
492
+ # TODO: fix GQA + SplitKV + non-varlen
493
+ if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
494
+ pack_gqa = False
 
 
 
 
 
 
 
 
 
 
495
 
496
  if max_seqlen_q is None:
497
  max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q
 
499
  max_seqlen_k = seqlen_k
500
  seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
501
  if arch // 10 == 10:
502
+ q_stage = 2 if seqlen_q_packgqa > tile_m else 1
503
  else:
504
  q_stage = 1
505
 
506
+ m_block_size_effective = q_stage * tile_m
507
+ seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m))
508
+ num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective
509
+ total_mblocks = batch_size * num_head_kv * num_m_blocks
510
+ num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n
511
+ num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count
512
  if num_splits < 1:
513
+ num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
514
+
515
+ # SplitKV uses float32 partial output, which doubles the O buffer size
516
+ # in shared memory, causing OOM for diff-headdim (192, 128)
517
+ if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1:
518
+ if num_n_blocks >= 64:
519
+ tile_n = 64
520
+ num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n
521
+ num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)
522
+ else:
523
+ num_splits = 1
524
 
525
  is_split_kv = num_splits > 1
526
  if is_split_kv:
527
  out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
528
  lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)
529
 
530
+ use_2cta_instrs = (
531
+ arch // 10 in [10, 11]
532
+ and not requested_disable_2cta
533
+ and not causal
534
+ and not local
535
+ and not is_split_kv
536
+ and cu_seqlens_q is None
537
+ and seqused_q is None
538
+ and not use_block_sparsity
539
+ and page_size in [None, 128]
540
+ and int(math.ceil(head_dim / 16) * 16) in [128, 192]
541
+ and int(math.ceil(head_dim_v / 16) * 16) == 128
542
+ and seqlen_q_packgqa > 2 * tile_m
543
+ and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
544
+ )
545
+
546
  # hash score and mask mods for compile cache
547
  score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
548
  mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
 
594
  num_head=num_head,
595
  seqlen_q=seqlen_q,
596
  seqlen_k=seqlen_k,
597
+ block_size=(tile_m, tile_n),
598
  q_stage=q_stage,
599
  )
600
+ if aux_tensors is not None:
601
  aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
602
  else:
603
  aux_tensor_metadata = None
604
+
605
  compile_key = (
606
  dtype,
607
  head_dim,
 
622
  window_size_left is not None,
623
  window_size_right is not None,
624
  learnable_sink is not None,
625
+ tile_m,
626
+ tile_n,
627
  q_stage,
628
  num_threads,
629
  is_split_kv,
630
  pack_gqa,
631
  arch,
632
+ page_size not in [None, tile_n], # paged KV non-TMA
633
+ use_2cta_instrs,
634
  q_subtile_factor,
635
+ mma_pv_is_rs,
636
+ intra_wg_overlap,
637
+ requested_use_clc_scheduler,
638
+ fa_logging.get_fa_log_level(),
639
  )
640
  if compile_key not in _flash_attn_fwd.compile_cache:
641
  (
 
674
  if aux_tensors is not None:
675
  cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]
676
 
677
+ if arch // 10 == 8:
678
+ assert page_table is None, "paged KV not supported on SM 8.0"
679
+ assert not is_split_kv, "SplitKV not supported on SM 8.0"
680
+ fa_fwd = FlashAttentionForwardSm80(
681
+ dtype,
682
+ head_dim,
683
+ head_dim_v,
684
+ qhead_per_kvhead,
685
+ is_causal=causal,
686
+ is_local=local,
687
+ pack_gqa=pack_gqa,
688
+ tile_m=tile_m,
689
+ tile_n=tile_n,
690
+ num_stages=1,
691
+ num_threads=num_threads,
692
+ Q_in_regs=False,
693
+ score_mod=score_mod,
694
+ mask_mod=mask_mod,
695
+ has_aux_tensors=aux_tensors is not None,
696
+ )
697
+ elif arch // 10 == 9:
698
  assert not is_split_kv, "SplitKV not supported on SM 9.0"
 
699
  fa_fwd = FlashAttentionForwardSm90(
700
  dtype,
701
  head_dim,
 
704
  is_causal=causal,
705
  is_local=local,
706
  pack_gqa=pack_gqa,
707
+ tile_m=tile_m,
708
+ tile_n=tile_n,
709
  # num_stages=1,
710
  num_stages=2,
711
  num_threads=num_threads,
712
  Q_in_regs=False,
713
+ intra_wg_overlap=intra_wg_overlap,
714
+ mma_pv_is_rs=mma_pv_is_rs,
715
  mask_mod=mask_mod,
716
  score_mod=score_mod,
717
  has_aux_tensors=aux_tensors is not None,
718
  q_subtile_factor=q_subtile_factor,
719
+ paged_kv_non_tma=page_size not in [None, tile_n],
720
  )
721
  elif arch // 10 in [10, 11]:
 
 
 
 
 
 
 
 
 
 
 
 
 
722
  fa_fwd = FlashAttentionForwardSm100(
723
  head_dim,
724
  head_dim_v,
 
727
  is_local=local,
728
  is_split_kv=is_split_kv,
729
  pack_gqa=pack_gqa,
730
+ m_block_size=tile_m,
731
+ n_block_size=tile_n,
732
  q_stage=q_stage,
733
  is_persistent=not causal
734
  and not local
 
738
  score_mod=score_mod,
739
  mask_mod=mask_mod,
740
  has_aux_tensors=aux_tensors is not None,
741
+ paged_kv_non_tma=page_size not in [None, tile_n],
742
  is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,
743
  q_subtile_factor=q_subtile_factor,
744
  use_2cta_instrs=use_2cta_instrs,
745
+ use_clc_scheduler=requested_use_clc_scheduler,
746
+ )
747
+ elif arch // 10 == 12:
748
+ # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity
749
+ assert not use_block_sparsity, "Block sparsity not supported on SM 12.0"
750
+ assert page_table is None, "Paged KV not supported on SM 12.0 in this PR"
751
+ assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR"
752
+ fa_fwd = FlashAttentionForwardSm120(
753
+ dtype,
754
+ head_dim,
755
+ head_dim_v,
756
+ qhead_per_kvhead,
757
+ is_causal=causal,
758
+ is_local=local,
759
+ pack_gqa=pack_gqa,
760
+ tile_m=tile_m,
761
+ tile_n=tile_n,
762
+ num_stages=1,
763
+ num_threads=num_threads,
764
+ Q_in_regs=False,
765
+ score_mod=score_mod,
766
+ mask_mod=mask_mod,
767
+ has_aux_tensors=aux_tensors is not None,
768
  )
769
  else:
770
  raise ValueError(
771
+ f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x"
772
  )
773
  # TODO: check @can_implement
774
  _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
 
779
  o_tensor,
780
  lse_tensor,
781
  softmax_scale,
 
782
  cu_seqlens_q_tensor,
783
  cu_seqlens_k_tensor,
784
  seqused_q_tensor,
 
789
  learnable_sink_tensor,
790
  sparse_tensors,
791
  cute_aux_tensors,
792
+ current_stream,
793
  options="--enable-tvm-ffi",
794
  )
795
 
 
805
  out.detach() if not is_split_kv else out_partial,
806
  lse_partial if is_split_kv else lse,
807
  softmax_scale,
 
808
  cu_seqlens_q,
809
  cu_seqlens_k,
810
  seqused_q,
 
831
  _flash_attn_fwd.compile_cache = get_jit_cache("fwd")
832
 
833
 
834
+ def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k):
835
+ sym = cute.sym_int
836
+ # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8
837
+ # For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4
838
+ div = 128 // dtype.width # 8 for fp16/bf16
839
+ # Shared sym_ints for dimensions that must match across tensors
840
+ b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym()
841
+ h_kv = h_q if not has_gqa else sym()
842
+ seqlen_q_rounded, seqlen_k_rounded = sym(), sym()
843
+ seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym()
844
+ total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym()
845
+ total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym()
846
+ b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,)
847
+ b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,)
848
+ mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
849
+ mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
850
+ mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
851
+ mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
852
+ mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
853
+ mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
854
+ mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
855
+ mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
856
+ if not varlen_q:
857
+ mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1)
858
+ mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
859
+ mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
860
+ dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4)
861
+ else:
862
+ mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1)
863
+ mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
864
+ mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
865
+ dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4)
866
+ if not has_gqa:
867
+ mdKaccum, mdVaccum = None, None
868
+ else:
869
+ if not varlen_k:
870
+ mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4)
871
+ mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4)
872
+ else:
873
+ mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4)
874
+ mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4)
875
+ return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum
876
+
877
+
878
+ def _compile_bwd_preprocess(
879
+ dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse,
880
+ ):
881
+ """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed)."""
882
+ mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
883
+ dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
884
+ )
885
+ batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
886
+ batchp1 = cute.sym_int()
887
+ mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
888
+ mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
889
+ mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None
890
+ fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size)
891
+ return cute.compile(
892
+ fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE,
893
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
894
+ options="--enable-tvm-ffi",
895
+ )
896
+
897
+
898
+ def _bwd_preprocess(
899
+ out, dout, dpsum, lse, lse_log2, dq_accum,
900
+ cu_seqlens_q, seqused_q, dlse,
901
+ dtype, head_dim, head_dim_v, m_block_size,
902
+ ):
903
+ """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum."""
904
+ is_varlen = cu_seqlens_q is not None
905
+ compile_key = (
906
+ dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None,
907
+ )
908
+ if compile_key not in _bwd_preprocess.compile_cache:
909
+ _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key)
910
+ if not is_fake_mode():
911
+ _bwd_preprocess.compile_cache[compile_key](
912
+ out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse
913
+ )
914
+
915
+
916
+ _bwd_preprocess.compile_cache = get_jit_cache("bwd_pre")
917
+
918
+
919
+ def _compile_bwd_postprocess(
920
+ dtype, hdim, block_size, num_threads, atom_layout, swap_ab,
921
+ has_cuseqlens_q, has_seqused_q,
922
+ use_2cta_instrs, cluster_size, arch,
923
+ ):
924
+ """Compile bwd postprocess kernel using cute fake tensors."""
925
+ mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
926
+ dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
927
+ )
928
+ batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
929
+ batchp1 = cute.sym_int()
930
+ mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
931
+ mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
932
+ fa_bwd_post = FlashAttentionBackwardPostprocess(
933
+ dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab,
934
+ use_2cta_instrs=use_2cta_instrs,
935
+ cluster_size=cluster_size,
936
+ )
937
+ return cute.compile(
938
+ fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ,
939
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
940
+ options="--enable-tvm-ffi",
941
+ )
942
+
943
+
944
+ def _bwd_postprocess_convert(
945
+ accum, output, scale,
946
+ cu_seqlens, seqused,
947
+ arch, dtype, hdim, block_size, num_threads,
948
+ atom_layout, swap_ab,
949
+ use_2cta_instrs=False, cluster_size=1,
950
+ ):
951
+ """Backward postprocess: convert float32 accumulator to bf16/fp16 output."""
952
+ compile_key = (
953
+ dtype, hdim, block_size, num_threads, atom_layout, swap_ab,
954
+ cu_seqlens is not None, seqused is not None,
955
+ use_2cta_instrs, cluster_size, arch,
956
+ )
957
+ if compile_key not in _bwd_postprocess_convert.compile_cache:
958
+ _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key)
959
+ if not is_fake_mode():
960
+ _bwd_postprocess_convert.compile_cache[compile_key](
961
+ accum, output, scale, cu_seqlens, seqused,
962
+ )
963
+
964
+
965
+ _bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post")
966
+
967
+
968
  def _flash_attn_bwd(
969
  q: torch.Tensor,
970
  k: torch.Tensor,
 
1005
  mask_mod: Optional[Callable] = None,
1006
  aux_tensors: Optional[list[torch.Tensor]] = None,
1007
  block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
1008
+ dlse: Optional[torch.Tensor] = None,
1009
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1010
  arch = _get_device_arch()
1011
+ assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x"
1012
+ sparse_q = None
1013
+ if block_sparse_tensors is not None and arch // 10 == 9:
1014
+ sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128
1015
 
1016
  num_head, head_dim = q.shape[-2:]
1017
+ head_dim_v = v.shape[-1]
1018
 
1019
+ causal, local, window_size_left, window_size_right = _resolve_causal_local_window(
1020
+ causal, window_size_left, window_size_right
1021
+ )
 
 
 
 
 
 
 
 
 
1022
 
1023
+ if arch // 10 == 12:
1024
+ # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).
1025
+ m_block_size = 64
1026
+ n_block_size = 64
1027
+ if head_dim <= 64:
1028
+ num_stages_Q = 2
1029
+ num_stages_dO = 2
1030
+ else:
1031
+ num_stages_Q = 1
1032
+ num_stages_dO = 1
1033
+ SdP_swapAB = False
1034
  dKV_swapAB = False
1035
+ dQ_swapAB = False
1036
+ AtomLayoutMSdP = 4
1037
+ AtomLayoutNdKV = 4
1038
+ AtomLayoutMdQ = 4
1039
+ V_in_regs = False
1040
+ cluster_size = 1
1041
+ use_2cta_instrs = False
1042
+ num_threads = 128
1043
+ assert not (block_sparse_tensors is not None), "Block sparsity backward not supported on SM 12.0"
1044
+ assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0"
1045
+ assert mask_mod is None, "mask_mod backward not supported on SM 12.0"
1046
+ assert deterministic is False, "deterministic backward not supported on SM 12.0"
1047
+ elif arch // 10 == 9:
1048
+ cfg = _tile_size_bwd_sm90(
1049
+ head_dim,
1050
+ head_dim_v,
1051
+ causal,
1052
+ local,
1053
+ sparse_block_size_q=sparse_q,
1054
+ )
1055
+ m_block_size = cfg.m_block_size
1056
+ n_block_size = cfg.n_block_size
1057
+ num_stages_Q = cfg.num_stages_Q
1058
+ num_stages_dO = cfg.num_stages_dO
1059
+ num_stages_PdS = cfg.num_stages_PdS
1060
+ SdP_swapAB = cfg.SdP_swapAB
1061
+ dKV_swapAB = cfg.dKV_swapAB
1062
+ dQ_swapAB = cfg.dQ_swapAB
1063
+ AtomLayoutMSdP = cfg.AtomLayoutMSdP
1064
+ AtomLayoutNdKV = cfg.AtomLayoutNdKV
1065
+ AtomLayoutMdQ = cfg.AtomLayoutMdQ
1066
+ num_threads = (cfg.num_wg + 1) * 128
1067
+ dQ_single_wg = cfg.dQ_single_wg
1068
  cluster_size = 1
1069
  use_2cta_instrs = False
 
1070
  is_varlen = (
1071
  cu_seqlens_q is not None
1072
  or cu_seqlens_k is not None
1073
  or seqused_q is not None
1074
  or seqused_k is not None
1075
  )
 
1076
  else:
1077
  m_block_size = 128
1078
  n_block_size = 128
 
1080
  dKV_swapAB = False
1081
  AtomLayoutMdQ = 1
1082
  AtomLayoutNdKV = 1
1083
+ requested_disable_2cta = utils._get_disable_2cta_default()
1084
  disable_2cta = (
1085
+ requested_disable_2cta
1086
  or score_mod is not None
1087
  or score_mod_bwd is not None
1088
  or mask_mod is not None
1089
+ or block_sparse_tensors is not None
1090
  )
1091
  cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
1092
  use_2cta_instrs = cluster_size==2
1093
+
1094
  q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
1095
  maybe_contiguous(t)
1096
  for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
 
1112
  seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
1113
 
1114
  num_head_kv = k.shape[-2]
 
1115
 
1116
  use_block_sparsity = block_sparse_tensors is not None
1117
+ subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2
 
 
 
 
 
 
 
 
 
1118
  seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
1119
  seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
1120
  num_n_blocks = seqlen_k_rounded // n_block_size
 
1154
  if t is not None:
1155
  assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
1156
  assert lse.dtype == torch.float32, "lse must be float32"
1157
+ if dlse is not None:
1158
+ dlse = maybe_contiguous(dlse)
1159
+ if not is_fake_mode():
1160
+ assert all(
1161
+ t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
1162
+ ), "inputs must be on CUDA device"
1163
  assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
 
1164
  alignment = 16 // q.element_size()
1165
+ if arch // 10 != 12:
1166
+ _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
1167
  if softmax_scale is None:
1168
  softmax_scale = 1.0 / math.sqrt(head_dim)
1169
  qhead_per_kvhead = num_head // num_head_kv
 
1171
  pack_gqa = qhead_per_kvhead > 1
1172
  # pack_gqa backward not yet supported in bwd
1173
  pack_gqa = False
 
 
 
1174
  if score_mod is not None:
1175
  assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
1176
  assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
 
1222
  dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
1223
  lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
1224
 
1225
+ # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads
1226
+ # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses
1227
+ # ragged TMA tensors for direct store, so no longer needs accum+postprocess.
1228
  dKV_postprocess = qhead_per_kvhead > 1
1229
  if dKV_postprocess:
1230
  head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
 
1262
  )
1263
 
1264
  dtype = torch2cute_dtype_map[q.dtype]
1265
+ current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
1266
 
1267
  if deterministic:
1268
+ dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device)
1269
  else:
1270
  dQ_semaphore = None
1271
 
1272
  if deterministic and qhead_per_kvhead > 1:
1273
+ dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
1274
+ dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)
1275
  else:
1276
  dK_semaphore = None
1277
  dV_semaphore = None
1278
 
1279
+ # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.
1280
+ _bwd_preprocess(
1281
+ out, dout, dpsum, lse, lse_log2, dq_accum,
1282
+ cu_seqlens_q, seqused_q, dlse,
1283
+ dtype, head_dim, head_dim_v, m_block_size,
 
 
 
 
 
 
 
1284
  )
1285
+ # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above,
1286
+ # SM100/SM110 uses default from function signature (384).
1287
+ if arch // 10 not in [9, 12]:
1288
+ num_threads = 384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1289
 
1290
  # Backward kernel: compute dk, dv, dq_accum.
1291
  score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
 
1312
  subtile_factor=subtile_factor,
1313
  )
1314
 
1315
+ if arch // 10 in [8, 9, 12]:
1316
  compile_key = (
1317
  arch,
1318
  dtype,
 
1320
  head_dim_v,
1321
  qhead_per_kvhead,
1322
  causal,
1323
+ window_size_left is not None,
1324
+ window_size_right is not None,
1325
  softcap != 0.0,
1326
  m_block_size,
1327
  n_block_size,
 
1336
  AtomLayoutNdKV,
1337
  AtomLayoutMdQ,
1338
  V_in_regs,
1339
+ dQ_single_wg,
1340
+ deterministic,
1341
  cu_seqlens_q is None,
1342
  cu_seqlens_k is None,
1343
  seqused_q is None,
 
1406
  if t is not None else None
1407
  for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
1408
  ]
1409
+ if arch // 10 in [8, 12]:
1410
+ flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80
1411
+ fa_bwd_obj = flash_bwd_obj_cls(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1412
  dtype,
1413
  head_dim,
1414
  head_dim_v,
1415
  qhead_per_kvhead,
 
1416
  m_block_size,
1417
  n_block_size,
1418
  num_stages_Q,
1419
  num_stages_dO,
1420
+ num_threads,
1421
+ pack_gqa,
1422
+ causal,
1423
  SdP_swapAB,
1424
  dKV_swapAB,
1425
  dQ_swapAB,
1426
  AtomLayoutMSdP,
1427
  AtomLayoutNdKV,
1428
  AtomLayoutMdQ,
1429
+ V_in_regs=V_in_regs,
1430
+ )
1431
+ elif arch // 10 == 9:
1432
+ fa_bwd_obj = FlashAttentionBackwardSm90(
1433
+ dtype,
1434
+ head_dim,
1435
+ head_dim_v,
1436
+ qhead_per_kvhead,
1437
+ causal,
1438
+ is_local=local,
1439
+ deterministic=deterministic,
1440
+ tile_m=m_block_size,
1441
+ tile_n=n_block_size,
1442
+ Q_stage=num_stages_Q,
1443
+ dO_stage=num_stages_dO,
1444
+ PdS_stage=num_stages_PdS,
1445
+ SdP_swapAB=SdP_swapAB,
1446
+ dKV_swapAB=dKV_swapAB,
1447
+ dQ_swapAB=dQ_swapAB,
1448
+ AtomLayoutMSdP=AtomLayoutMSdP,
1449
+ AtomLayoutNdKV=AtomLayoutNdKV,
1450
+ AtomLayoutMdQ=AtomLayoutMdQ,
1451
+ num_threads=num_threads,
1452
  V_in_regs=V_in_regs,
1453
  score_mod=score_mod,
1454
  score_mod_bwd=score_mod_bwd,
1455
  mask_mod=mask_mod,
1456
  has_aux_tensors=aux_tensors is not None,
1457
  subtile_factor=subtile_factor,
1458
+ dQ_single_wg=dQ_single_wg,
1459
  )
1460
  else:
1461
  fa_bwd_obj = FlashAttentionBackwardSm100(
 
1494
  dk_tensor if not dKV_postprocess else dk_accum_tensor,
1495
  dv_tensor if not dKV_postprocess else dv_accum_tensor,
1496
  softmax_scale,
 
1497
  cu_seqlens_q_tensor,
1498
  cu_seqlens_k_tensor,
1499
  seqused_q_tensor,
 
1506
  dV_semaphore_tensor,
1507
  cute_aux_tensors,
1508
  sparse_tensors_compile,
1509
+ current_stream,
1510
  options="--enable-tvm-ffi",
1511
  )
1512
  if not is_fake_mode():
 
1521
  dk if not dKV_postprocess else dk_accum,
1522
  dv if not dKV_postprocess else dv_accum,
1523
  softmax_scale,
 
1524
  cu_seqlens_q,
1525
  cu_seqlens_k,
1526
  seqused_q,
 
1535
  normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
1536
  )
1537
 
1538
+ if arch // 10 == 9:
1539
+ # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg
1540
+ num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128
1541
+ num_threads_post_dKV = cfg.num_wg * 128
1542
+ else:
1543
+ num_threads_post_dQ = 128
1544
+ num_threads_post_dKV = 128
1545
+
1546
+ # Postprocess: convert dq_accum from float32 to dq in bf16/fp16
1547
+ _bwd_postprocess_convert(
1548
+ dq_accum, dq, softmax_scale,
1549
+ cu_seqlens_q, seqused_q,
1550
+ arch, dtype, head_dim, m_block_size, num_threads_post_dQ,
1551
+ AtomLayoutMdQ, dQ_swapAB,
1552
+ use_2cta_instrs=use_2cta_instrs, cluster_size=1,
 
1553
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1554
 
1555
  if dKV_postprocess:
1556
+ # Postprocess: convert dk_accum from float32 to dk in bf16/fp16
1557
+ _bwd_postprocess_convert(
1558
+ dk_accum, dk, softmax_scale,
1559
+ cu_seqlens_k, seqused_k,
1560
+ arch, dtype, head_dim, n_block_size, num_threads_post_dKV,
1561
+ AtomLayoutNdKV, dKV_swapAB,
1562
+ cluster_size=cluster_size,
 
 
 
 
 
 
 
 
1563
  )
1564
+ # Postprocess: convert dv_accum from float32 to dv in bf16/fp16
1565
+ _bwd_postprocess_convert(
1566
+ dv_accum, dv, 1.0,
1567
+ cu_seqlens_k, seqused_k,
1568
+ arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV,
1569
+ AtomLayoutNdKV, dKV_swapAB,
1570
+ cluster_size=cluster_size,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1571
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1572
 
1573
  return dq, dk, dv
1574
 
1575
 
 
1576
  _flash_attn_bwd.compile_cache = get_jit_cache("bwd")
 
1577
 
1578
 
1579
  class FlashAttnFunc(torch.autograd.Function):
 
1631
  ctx.window_size = window_size
1632
  ctx.softcap = softcap
1633
  ctx.deterministic = deterministic
1634
+ ctx.return_lse = return_lse
1635
+ ctx.set_materialize_grads(False)
 
1636
  return out, lse
1637
 
1638
  @staticmethod
1639
+ def backward(ctx, dout, dlse):
1640
  q, k, v, out, lse = ctx.saved_tensors
1641
+ if not ctx.return_lse:
1642
+ dlse = None
1643
+ if dout is None:
1644
+ dout = torch.zeros_like(out)
1645
  dq, dk, dv = _flash_attn_bwd(
1646
  q,
1647
  k,
 
1655
  window_size_left=ctx.window_size[0],
1656
  window_size_right=ctx.window_size[1],
1657
  deterministic=ctx.deterministic,
1658
+ dlse=dlse,
1659
  )
1660
  return dq, dk, dv, *((None,) * 20) # Extra Nones is fine
1661
 
 
1717
  ctx.deterministic = deterministic
1718
  ctx.max_seqlen_q = max_seqlen_q
1719
  ctx.max_seqlen_k = max_seqlen_k
1720
+ ctx.return_lse = return_lse
1721
+ ctx.set_materialize_grads(False)
 
1722
  return out, lse
1723
 
1724
  @staticmethod
1725
+ def backward(ctx, dout, dlse):
1726
  q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
1727
  assert ctx.softcap == 0.0
1728
+ if not ctx.return_lse:
1729
+ dlse = None
1730
+ if dout is None:
1731
+ dout = torch.zeros_like(out)
1732
  dq, dk, dv = _flash_attn_bwd(
1733
  q,
1734
  k,
 
1748
  max_seqlen_q=ctx.max_seqlen_q,
1749
  max_seqlen_k=ctx.max_seqlen_k,
1750
  deterministic=ctx.deterministic,
1751
+ dlse=dlse,
1752
  )
1753
 
1754
  return dq, dk, dv, *((None,) * 20)
 
1844
  )
1845
 
1846
 
1847
+ def _compile_fwd_combine(
1848
+ dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,
1849
+ has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx,
1850
+ ):
1851
+ """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed)."""
1852
+ sym = cute.sym_int
1853
+ div = 128 // dtype_partial.width # 16-byte alignment in elements
1854
+
1855
+ fa_combine = FlashAttentionForwardCombine(
1856
+ dtype=dtype,
1857
+ dtype_partial=dtype_partial,
1858
+ head_dim=head_dim,
1859
+ tile_m=tile_m,
1860
+ k_block_size=k_block_size,
1861
+ log_max_splits=log_max_splits,
1862
+ )
1863
+ if not fa_combine.can_implement(
1864
+ dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,
1865
+ num_threads=256,
1866
+ ):
1867
+ raise RuntimeError(
1868
+ "FlashAttention combine kernel cannot be implemented with given parameters"
1869
+ )
1870
+
1871
+ if has_cu_seqlens:
1872
+ # Varlen: (num_splits, total_q, nheads, headdim)
1873
+ num_splits, total_q, nheads = sym(), sym(), sym()
1874
+ mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div)
1875
+ mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1)
1876
+ mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div)
1877
+ mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None
1878
+ else:
1879
+ # Batched: (num_splits, batch, seqlen, nheads, headdim)
1880
+ num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym()
1881
+ mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div)
1882
+ mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2)
1883
+ mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div)
1884
+ mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None
1885
+ batch = mO_partial.shape[1]
1886
+
1887
+ batch_for_1d = batch if not has_cu_seqlens else sym()
1888
+ batchp1 = sym()
1889
+ mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None
1890
+ mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None
1891
+ mNumSplitsDynamic = None # Not parametrized in compile_key
1892
+ mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None
1893
+ mSemaphore = None # Not parametrized in compile_key
1894
+
1895
+ return cute.compile(
1896
+ fa_combine,
1897
+ mO_partial, mLSE_partial, mO, mLSE,
1898
+ mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore,
1899
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
1900
+ options="--enable-tvm-ffi",
1901
+ )
1902
+
1903
+
1904
  def _flash_attn_fwd_combine(
1905
  out_partial: torch.Tensor,
1906
  lse_partial: torch.Tensor,
 
1909
  cu_seqlens: Optional[torch.Tensor] = None,
1910
  seqused: Optional[torch.Tensor] = None,
1911
  num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
1912
+ varlen_batch_idx: Optional[torch.Tensor] = None,
1913
  semaphore_to_reset: Optional[torch.Tensor] = None,
1914
  ) -> None:
1915
  """Forward combine kernel for split attention computation.
 
1933
  Returns:
1934
  None
1935
  """
 
 
 
1936
  assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
1937
  "out_partial must be fp16, bf16, or fp32"
1938
  )
1939
+ if not is_fake_mode():
1940
+ assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
 
 
 
 
1941
  # Determine if this is variable length based on dimensions
1942
  is_varlen = out_partial.dim() == 4
 
 
 
 
 
 
 
1943
  # Validate optional tensors
1944
  for t, name in [
1945
  (cu_seqlens, "cu_seqlens"),
 
1947
  (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
1948
  ]:
1949
  if t is not None:
1950
+ if not is_fake_mode():
1951
+ assert t.is_cuda, f"{name} must be on CUDA device"
1952
  assert t.is_contiguous(), f"{name} must be contiguous"
 
1953
  head_dim = out_partial.shape[-1]
1954
  num_splits = out_partial.shape[0]
1955
  assert num_splits <= 256
 
1958
  k_block_size = 64 if head_dim <= 64 else 128
1959
  # We want kBlockM to be as small as possible to maximize parallelism.
1960
  # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
1961
+ tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
1962
  log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
1963
+ if tile_m == 8:
1964
  # If kBlockM == 8 then the minimum number of splits is 32.
1965
  # TODO: we can deal w this by using 128 threads instead
1966
  log_max_splits = max(log_max_splits, 5)
1967
 
 
 
1968
  # Create combine kernel configuration
1969
  dtype = torch2cute_dtype_map[out.dtype]
1970
  dtype_partial = torch2cute_dtype_map[out_partial.dtype]
 
1971
  compile_key = (
1972
  dtype,
1973
  dtype_partial,
1974
  head_dim,
1975
+ tile_m,
1976
  k_block_size,
1977
  log_max_splits,
1978
  cu_seqlens is not None,
1979
  seqused is not None,
1980
  lse is not None,
1981
+ varlen_batch_idx is not None,
1982
  )
 
1983
  if compile_key not in _flash_attn_fwd_combine.compile_cache:
1984
+ _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine(
1985
+ *compile_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1986
  )
1987
  if not is_fake_mode():
1988
  _flash_attn_fwd_combine.compile_cache[compile_key](
1989
+ out_partial, lse_partial, out, lse,
1990
+ cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx,
 
 
 
 
 
1991
  semaphore_to_reset,
 
1992
  )
1993
 
1994
 
 
2002
  out_dtype: Optional[torch.dtype] = None,
2003
  cu_seqlens: Optional[torch.Tensor] = None,
2004
  seqused: Optional[torch.Tensor] = None,
2005
+ varlen_batch_idx: Optional[torch.Tensor] = None,
2006
  return_lse: bool = True,
2007
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
2008
  """Flash Attention combine function for split attention computation.
 
2022
  out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input.
2023
  cu_seqlens: Cumulative sequence lengths for variable length sequences
2024
  seqused: Used sequence lengths for each batch
2025
+ varlen_batch_idx: Optional mapping from virtual batch index to real batch index
2026
+ (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers
2027
+ that reorder batch processing for load balancing.
2028
  return_lse: Whether to return the combined LSE tensor. Default is True.
2029
 
2030
  Returns:
 
2041
  """
2042
  # Input validation
2043
  assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
 
 
 
 
2044
  # Determine if this is variable length based on dimensions
2045
  is_varlen = out_partial.dim() == 4
 
2046
  if is_varlen:
2047
  # Variable length: (num_splits, total_q, num_heads, head_size)
2048
  num_splits, total_q, num_heads, head_size = out_partial.shape
 
 
 
2049
  batch_size = 1 # Treat as single batch for varlen
2050
  seqlen = total_q
2051
  else:
2052
  # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)
2053
  num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape
 
 
 
 
2054
  # Determine output dtype
2055
  if out_dtype is None:
2056
  out_dtype = out_partial.dtype
 
2057
  # Create output if not provided
2058
  device = out_partial.device
2059
  if out is None:
 
2063
  out = torch.empty(
2064
  batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device
2065
  )
 
2066
  # Create lse output only if requested
2067
  if return_lse:
2068
  if is_varlen:
2069
+ lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device)
 
 
2070
  else:
2071
+ lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device)
2072
+ lse = lse.transpose(-1, -2)
 
2073
  else:
2074
  lse = None
 
2075
  _flash_attn_fwd_combine(
2076
  out_partial,
2077
  lse_partial,
 
2079
  lse,
2080
  cu_seqlens,
2081
  seqused,
2082
+ varlen_batch_idx=varlen_batch_idx,
2083
  )
2084
  return out, lse
build/torch-cuda/mask.py CHANGED
@@ -1,109 +1,102 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
- from typing import Optional, Callable
4
  from dataclasses import dataclass
5
 
6
  import cutlass
7
  import cutlass.cute as cute
8
- from cutlass import Float32, Int32, const_expr
9
 
10
  from .quack import layout_utils
11
- from . import utils
12
  from .seqlen_info import SeqlenInfoQK
13
 
 
 
 
14
 
15
  @cute.jit
16
- def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None:
17
- # Bit manipulation, compiles down to the R2P instruction
18
- # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using.
19
- # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ...,
20
- # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
21
- if const_expr(arch == 90):
22
- col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2)
23
- else:
24
- col_limit_transformed = col_limit
25
- ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
26
- # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
27
- for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
28
- # Don't need to clamp to 32 since the shr.u32 instruction does that already
29
- col_limit_right_s = max(col_limit_transformed - s * 24, 0)
30
- # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
31
- mask = (1 << col_limit_right_s) - 1
32
- # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
33
- for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
34
- in_bound = cutlass.Boolean(mask & (1 << i))
35
- c = s * 24 + i
36
- if const_expr(rank1):
37
- X[c] = X[c] if in_bound else -Float32.inf
38
- # This is the equivalent of:
39
- # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf
40
- else:
41
- for r in cutlass.range_constexpr(cute.size(X.shape[0])):
42
- X[r, c] = X[r, c] if in_bound else -Float32.inf
43
 
44
 
45
  @cute.jit
46
- def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None:
47
- # Bit manipulation, compiles down to the R2P instruction
48
- # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127
49
- # or 0, 1, ..., 15, 32, ..., 47, 64, ...
50
- # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ...
51
- # Here we hardcode for the case of 2 warp groups.
52
- num_wg = 2
53
- row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min(
54
- row_limit_top % (num_rep * num_wg), num_rep
55
- )
56
- ncol = cute.size(X.shape)
57
- # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31
58
- for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
59
- row_limit_top_s = max(row_limit_top_transformed - s * 24, 0)
60
- # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11
61
- mask = (1 << row_limit_top_s) - 1
62
- # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
63
- for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
64
- out_bound = cutlass.Boolean(mask & (1 << i))
65
- c = s * 24 + i
66
- X[c] = -Float32.inf if out_bound else X[c]
67
- # tidx = cute.arch.thread_idx()[0] % 256
68
- # if tidx == 128:
69
- # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
70
 
71
 
72
  @cute.jit
73
- def mask_r2p_dual_bound(
74
  X: cute.Tensor,
75
- col_limit_left: Int32, # Inclusive lower bound
76
- col_limit_right: Int32, # Exclusive upper bound
77
  ) -> None:
78
- """
79
- Dual-bound masking using two bitmasks for SM100, following mask_r2p.
80
- Masks elements where: NOT (col_limit_left <= col < col_limit_right)
81
 
82
- Uses bit manipulation to create a range mask:
83
- mask_right = (1 << right) - 1 -> bits (right-1)..0 are 1
84
- mask_left = (1 << left) - 1 -> bits (left-1)..0 are 1
85
- mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1
86
  """
87
- ncol = const_expr(cute.size(X.shape))
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
90
- right_s = max(col_limit_right - s * 24, 0)
91
- left_s = max(col_limit_left - s * 24, 0)
92
 
93
- # otherwise cute dsl complains about python int too large to convert into c long
94
- right_s = min(right_s, 24)
95
- left_s = min(left_s, 24)
96
 
97
- # bits (right-1)..left are 1
98
- mask_right = (1 << right_s) - 1
99
- mask_left = (1 << left_s) - 1
100
- mask_range = mask_right & ~mask_left
 
101
 
102
- # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
103
- for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
104
- in_bound = cutlass.Boolean(mask_range & (1 << i))
105
- c = s * 24 + i
106
- X[c] = X[c] if in_bound else -Float32.inf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
  @dataclass(frozen=True)
@@ -161,8 +154,7 @@ class AttentionMask:
161
  seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
162
  if const_expr(not mask_causal and not mask_local and mask_mod is None):
163
  if const_expr(mask_seqlen):
164
- # The compiler now choses not to use R2P
165
- r2p = const_expr(False and not self.swap_AB)
166
  if const_expr(not r2p):
167
  # traverse column index.
168
  for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
@@ -170,7 +162,8 @@ class AttentionMask:
170
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
171
  acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
172
  else:
173
- mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90)
 
174
 
175
  elif const_expr(
176
  not mask_causal and not mask_local and mask_mod is not None
@@ -272,7 +265,12 @@ class AttentionMask:
272
  else acc_S_mn[r, c]
273
  )
274
  else:
275
- mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True)
 
 
 
 
 
276
  else: # Local
277
  local_row_offset_right = (
278
  causal_row_offset + self.window_size_right
@@ -284,6 +282,7 @@ class AttentionMask:
284
  if const_expr(self.window_size_left is not None)
285
  else None
286
  )
 
287
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
288
  if const_expr(self.qhead_per_kvhead_packgqa == 1):
289
  row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
@@ -302,13 +301,22 @@ class AttentionMask:
302
  if const_expr(self.window_size_left is not None)
303
  else 0
304
  )
305
- # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left)
306
- # traverse column index.
307
- for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
308
- col_idx = t0ScS_mn[0, c][1]
309
- # only consider the column index, so the row index sets to 0.
310
- if col_idx >= col_limit_right or col_idx < col_limit_left:
311
- acc_S_mn[r, c] = -Float32.inf
 
 
 
 
 
 
 
 
 
312
  else: # swap_AB
313
  assert self.qhead_per_kvhead_packgqa == 1
314
  thr_row_offset = tScS_mn[0][ROW]
@@ -338,11 +346,18 @@ class AttentionMask:
338
  # column, by setting row limit to be self.tile_m.
339
  row_limit_top = (
340
  self.tile_m
341
- if col0 >= seqlenk_col_limit
342
- else col0 - causal_row_offset - self.window_size_right
 
 
 
 
 
 
 
 
 
343
  )
344
- # TODO: do we need col_limit_sink?
345
- row_limit_bot = col0 - causal_row_offset + self.window_size_left
346
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
347
  row_idx = t0ScS_mn[r, 0][ROW]
348
  acc_S_mn[r, c] = (
@@ -392,7 +407,11 @@ class AttentionMask:
392
  # For some reason the 2 lines above generate really bad SASS
393
  acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
394
  else:
395
- mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True)
 
 
 
 
396
 
397
  elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
398
  # Block sparse case w/ mask_mod
@@ -445,12 +464,12 @@ class AttentionMask:
445
  acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
446
 
447
  else: # Causal or local
448
- causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q
449
  row_idx = tScS_t2r[0][0] + m_block * self.tile_m
450
  if const_expr(self.qhead_per_kvhead_packgqa != 1):
451
  row_idx = row_idx // self.qhead_per_kvhead_packgqa
452
  if const_expr(mask_causal):
453
- col_limit_right = row_idx + causal_row_offset
454
  if const_expr(mask_seqlen):
455
  col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
456
  # if cute.arch.thread_idx()[0] % 32 == 0:
@@ -460,15 +479,19 @@ class AttentionMask:
460
  for i in cutlass.range(ncol, unroll_full=True):
461
  acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
462
  else:
463
- mask_r2p(acc_S, col_limit_right, arch=100, rank1=True)
 
 
 
 
464
  else:
465
  local_row_offset_right = (
466
- causal_row_offset + self.window_size_right
467
  if const_expr(self.window_size_right is not None)
468
  else None
469
  )
470
  local_row_offset_left = (
471
- causal_row_offset - 1 - self.window_size_left
472
  if const_expr(self.window_size_left is not None)
473
  else None
474
  )
@@ -493,8 +516,15 @@ class AttentionMask:
493
  else acc_S[i]
494
  )
495
  else:
496
- # XOR-based R2P dual bound masking
497
- mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)
 
 
 
 
 
 
 
498
 
499
  @cute.jit
500
  def apply_mask_sm100_transposed(
@@ -634,7 +664,13 @@ class AttentionMask:
634
  )
635
  else:
636
  num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
637
- mask_r2p_transposed(acc_S, row_limit_top, num_rep)
 
 
 
 
 
 
638
  else:
639
  if const_expr(self.window_size_right is not None):
640
  row_limit_top = causal_offset - self.window_size_right
@@ -645,9 +681,31 @@ class AttentionMask:
645
  if const_expr(mask_seqlen):
646
  if seqlenk_col_limit <= 0:
647
  row_limit_top = self.tile_m
648
- for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
649
- row_idx = t0ScS_t2r[i][ROW]
650
- local_mask = row_idx < row_limit_top
651
- if const_expr(self.window_size_left is not None):
652
- local_mask |= row_idx > row_limit_bot
653
- acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
+ from typing import Optional, Callable, TypeAlias
4
  from dataclasses import dataclass
5
 
6
  import cutlass
7
  import cutlass.cute as cute
8
+ from cutlass import Float32, Int32, Uint32, const_expr
9
 
10
  from .quack import layout_utils
11
+ from . import utils as utils
12
  from .seqlen_info import SeqlenInfoQK
13
 
14
+ MaskGenFn: TypeAlias = Callable[[int], Uint32]
15
+ MASK_R2P_CHUNK_SIZE: int = 32
16
+
17
 
18
  @cute.jit
19
+ def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
20
+ """32-bit R2P bitmask keeping positions < limit (exclusive upper bound).
21
+
22
+ Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
23
+ Uses inline PTX to avoid shift-by-type-width UB.
24
+ """
25
+ m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)
26
+ return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  @cute.jit
30
+ def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
31
+ """32-bit R2P bitmask keeping positions >= limit (inclusive lower bound).
32
+
33
+ Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
34
+ Uses inline PTX to avoid shift-by-type-width UB.
35
+ """
36
+ n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0)
37
+ return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  @cute.jit
41
+ def mask_r2p_lambda(
42
  X: cute.Tensor,
43
+ mask_gen_fn: cutlass.Constexpr[MaskGenFn],
44
+ rank1: bool = False,
45
  ) -> None:
46
+ """Apply R2P masking with a custom bitmask generator.
 
 
47
 
48
+ mask_gen_fn(chunk_idx: constexpr int) -> Uint32:
49
+ Returns a 32-bit bitmask for the chunk. Bit i set means column
50
+ chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf.
 
51
  """
52
+ ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
53
+ # 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep).
54
+ CHUNK_SIZE = MASK_R2P_CHUNK_SIZE
55
+ for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)):
56
+ mask = mask_gen_fn(s)
57
+ # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
58
+ for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)):
59
+ in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
60
+ c = s * CHUNK_SIZE + i
61
+ if const_expr(rank1):
62
+ X[c] = X[c] if in_bound else -Float32.inf
63
+ else:
64
+ for r in cutlass.range_constexpr(cute.size(X.shape[0])):
65
+ X[r, c] = X[r, c] if in_bound else -Float32.inf
66
 
 
 
 
67
 
68
+ @cute.jit
69
+ def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:
70
+ """Transform SM90 MMA column coordinate to R2P element index.
71
 
72
+ SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ...
73
+ Element indices are contiguous: 0, 1, 2, 3, 4, 5, ...
74
+ This converts a column-space threshold to element-space for r2p_bitmask_below/above.
75
+ """
76
+ return col_limit // 8 * 2 + min(col_limit % 8, 2)
77
 
78
+
79
+ @cute.jit
80
+ def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
81
+ """Convert a row coordinate to an R2P element index in the warp-group interleaved layout.
82
+
83
+ In the SM100 backward pass, 2 warp groups share TMEM. The TMEM load atom
84
+ distributes rows in an interleaved pattern: elements 0..num_rep-1 map to
85
+ rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to
86
+ rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on.
87
+ Row-coordinate thresholds (causal limits, window bounds, uih_len) must be
88
+ converted to element indices before use with r2p_bitmask_above/below.
89
+
90
+ Rows not owned by this thread (in the gap between warp groups) are clamped
91
+ to the boundary element index, which is safe because R2P thresholds are
92
+ monotonic.
93
+
94
+ Example with num_rep=16, num_wg=2:
95
+ row 0 -> elem 0, row 15 -> elem 15,
96
+ row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped),
97
+ row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31.
98
+ """
99
+ return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep)
100
 
101
 
102
  @dataclass(frozen=True)
 
154
  seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
155
  if const_expr(not mask_causal and not mask_local and mask_mod is None):
156
  if const_expr(mask_seqlen):
157
+ r2p = const_expr(not self.swap_AB)
 
158
  if const_expr(not r2p):
159
  # traverse column index.
160
  for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
 
162
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
163
  acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
164
  else:
165
+ seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit)
166
+ mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s))
167
 
168
  elif const_expr(
169
  not mask_causal and not mask_local and mask_mod is not None
 
265
  else acc_S_mn[r, c]
266
  )
267
  else:
268
+ col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right)
269
+ mask_r2p_lambda(
270
+ acc_S_mn[r, None],
271
+ lambda s: r2p_bitmask_below(col_limit_r2p, s),
272
+ rank1=True,
273
+ )
274
  else: # Local
275
  local_row_offset_right = (
276
  causal_row_offset + self.window_size_right
 
282
  if const_expr(self.window_size_left is not None)
283
  else None
284
  )
285
+ r2p_local = const_expr(not self.swap_AB)
286
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
287
  if const_expr(self.qhead_per_kvhead_packgqa == 1):
288
  row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
 
301
  if const_expr(self.window_size_left is not None)
302
  else 0
303
  )
304
+ if const_expr(not r2p_local):
305
+ # traverse column index.
306
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
307
+ col_idx = t0ScS_mn[0, c][1]
308
+ if col_idx >= col_limit_right or col_idx < col_limit_left:
309
+ acc_S_mn[r, c] = -Float32.inf
310
+ else:
311
+ col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right)
312
+ col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left)
313
+
314
+ def mask_gen_fn(s: int) -> Uint32:
315
+ return r2p_bitmask_below(
316
+ col_limit_right_r2p, s
317
+ ) & r2p_bitmask_above(col_limit_left_r2p, s)
318
+
319
+ mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True)
320
  else: # swap_AB
321
  assert self.qhead_per_kvhead_packgqa == 1
322
  thr_row_offset = tScS_mn[0][ROW]
 
346
  # column, by setting row limit to be self.tile_m.
347
  row_limit_top = (
348
  self.tile_m
349
+ if col0 >= seqlenk_col_limit and mask_seqlen
350
+ else (
351
+ col0 - causal_row_offset - self.window_size_right
352
+ if const_expr(self.window_size_right is not None)
353
+ else 0
354
+ )
355
+ )
356
+ row_limit_bot = (
357
+ col0 - causal_row_offset + self.window_size_left
358
+ if const_expr(self.window_size_left is not None)
359
+ else self.tile_m
360
  )
 
 
361
  for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
362
  row_idx = t0ScS_mn[r, 0][ROW]
363
  acc_S_mn[r, c] = (
 
407
  # For some reason the 2 lines above generate really bad SASS
408
  acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
409
  else:
410
+ mask_r2p_lambda(
411
+ acc_S,
412
+ lambda s: r2p_bitmask_below(seqlenk_col_limit, s),
413
+ rank1=True,
414
+ )
415
 
416
  elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
417
  # Block sparse case w/ mask_mod
 
464
  acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
465
 
466
  else: # Causal or local
467
+ causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q
468
  row_idx = tScS_t2r[0][0] + m_block * self.tile_m
469
  if const_expr(self.qhead_per_kvhead_packgqa != 1):
470
  row_idx = row_idx // self.qhead_per_kvhead_packgqa
471
  if const_expr(mask_causal):
472
+ col_limit_right = row_idx + causal_row_offset + 1
473
  if const_expr(mask_seqlen):
474
  col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
475
  # if cute.arch.thread_idx()[0] % 32 == 0:
 
479
  for i in cutlass.range(ncol, unroll_full=True):
480
  acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
481
  else:
482
+ mask_r2p_lambda(
483
+ acc_S,
484
+ lambda s: r2p_bitmask_below(col_limit_right, s),
485
+ rank1=True,
486
+ )
487
  else:
488
  local_row_offset_right = (
489
+ causal_row_offset + 1 + self.window_size_right
490
  if const_expr(self.window_size_right is not None)
491
  else None
492
  )
493
  local_row_offset_left = (
494
+ causal_row_offset - self.window_size_left
495
  if const_expr(self.window_size_left is not None)
496
  else None
497
  )
 
516
  else acc_S[i]
517
  )
518
  else:
519
+ # Dual-bound R2P masking for SM100.
520
+ # Masks elements where: NOT (col_limit_left <= col < col_limit_right)
521
+
522
+ def mask_gen_fn(s: int) -> Uint32:
523
+ return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above(
524
+ col_limit_left, s
525
+ )
526
+
527
+ mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True)
528
 
529
  @cute.jit
530
  def apply_mask_sm100_transposed(
 
664
  )
665
  else:
666
  num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32
667
+ num_wg = 2
668
+ row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
669
+ mask_r2p_lambda(
670
+ acc_S,
671
+ lambda s: r2p_bitmask_above(row_limit, s),
672
+ rank1=True,
673
+ )
674
  else:
675
  if const_expr(self.window_size_right is not None):
676
  row_limit_top = causal_offset - self.window_size_right
 
681
  if const_expr(mask_seqlen):
682
  if seqlenk_col_limit <= 0:
683
  row_limit_top = self.tile_m
684
+ r2p = True
685
+ if const_expr(not r2p):
686
+ for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
687
+ row_idx = t0ScS_t2r[i][ROW]
688
+ local_mask = row_idx < row_limit_top
689
+ if const_expr(self.window_size_left is not None):
690
+ local_mask |= row_idx > row_limit_bot
691
+ acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
692
+ else:
693
+
694
+ def mask_gen_fn(s: int) -> Uint32:
695
+ num_rep = cute.size(tScS_t2r, mode=[0])
696
+ num_wg = 2
697
+
698
+ row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
699
+ mask = r2p_bitmask_above(row_limit, s)
700
+
701
+ if const_expr(self.window_size_left is not None):
702
+ row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg)
703
+ mask = mask & r2p_bitmask_below(row_limit_bottom, s)
704
+
705
+ return mask
706
+
707
+ mask_r2p_lambda(
708
+ acc_S,
709
+ mask_gen_fn,
710
+ rank1=True,
711
+ )
build/torch-cuda/named_barrier.py CHANGED
@@ -12,6 +12,19 @@ class NamedBarrierFwd(enum.IntEnum):
12
  PEmpty = enum.auto()
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class NamedBarrierBwd(enum.IntEnum):
16
  Epilogue = enum.auto()
17
  WarpSchedulerWG1 = enum.auto()
@@ -20,8 +33,10 @@ class NamedBarrierBwd(enum.IntEnum):
20
  PdS = enum.auto()
21
  dQFullWG0 = enum.auto()
22
  dQFullWG1 = enum.auto()
 
23
  dQEmptyWG0 = enum.auto()
24
  dQEmptyWG1 = enum.auto()
 
25
 
26
 
27
  class NamedBarrierBwdSm100(enum.IntEnum):
 
12
  PEmpty = enum.auto()
13
 
14
 
15
+ class NamedBarrierFwdSm100(enum.IntEnum):
16
+ Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
17
+ TmemPtr = enum.auto()
18
+ SoftmaxStatsW0 = enum.auto()
19
+ SoftmaxStatsW1 = enum.auto()
20
+ SoftmaxStatsW2 = enum.auto()
21
+ SoftmaxStatsW3 = enum.auto()
22
+ SoftmaxStatsW4 = enum.auto()
23
+ SoftmaxStatsW5 = enum.auto()
24
+ SoftmaxStatsW6 = enum.auto()
25
+ SoftmaxStatsW7 = enum.auto()
26
+
27
+
28
  class NamedBarrierBwd(enum.IntEnum):
29
  Epilogue = enum.auto()
30
  WarpSchedulerWG1 = enum.auto()
 
33
  PdS = enum.auto()
34
  dQFullWG0 = enum.auto()
35
  dQFullWG1 = enum.auto()
36
+ dQFullWG2 = enum.auto()
37
  dQEmptyWG0 = enum.auto()
38
  dQEmptyWG1 = enum.auto()
39
+ dQEmptyWG2 = enum.auto()
40
 
41
 
42
  class NamedBarrierBwdSm100(enum.IntEnum):
build/torch-cuda/pack_gqa.py CHANGED
@@ -1,25 +1,123 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
 
 
3
 
4
  import cutlass
5
  import cutlass.cute as cute
 
 
6
 
7
  from .quack import layout_utils
8
- from . import utils
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class PackGQA:
12
- def __init__(
13
- self,
14
- m_block_size: cutlass.Constexpr[int],
15
- head_dim_padded: cutlass.Constexpr[int],
16
- check_hdim_oob: cutlass.Constexpr[bool],
17
- qhead_per_kvhead: cutlass.Constexpr[bool],
18
- ):
19
- self.m_block_size = m_block_size
20
- self.head_dim_padded = head_dim_padded
21
- self.check_hdim_oob = check_hdim_oob
22
- self.qhead_per_kvhead = qhead_per_kvhead
23
 
24
  @cute.jit
25
  def compute_ptr(
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
+ from dataclasses import dataclass
4
+ from typing import Union, Tuple
5
 
6
  import cutlass
7
  import cutlass.cute as cute
8
+ from cutlass.cute.nvgpu import cpasync
9
+
10
 
11
  from .quack import layout_utils
12
+ from . import utils as utils
13
+
14
+
15
+ def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):
16
+ """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0).
17
+
18
+ The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
19
+ are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
20
+ as-is (e.g. batch).
21
+
22
+ For Q/O tensors (head_idx=2):
23
+ (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...)
24
+ For LSE tensors (head_idx=1):
25
+ (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...)
26
+ """
27
+ head_stride = T.stride[head_idx]
28
+ shape_packed = (
29
+ (qhead_per_kvhead, T.shape[0]),
30
+ *[T.shape[i] for i in range(1, head_idx)],
31
+ nheads_kv,
32
+ *[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
33
+ )
34
+ stride_packed = (
35
+ (head_stride, T.stride[0]),
36
+ *[T.stride[i] for i in range(1, head_idx)],
37
+ head_stride * qhead_per_kvhead,
38
+ *[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
39
+ )
40
+ return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))
41
+
42
+
43
+ def make_packgqa_tiled_tma_atom(
44
+ op: cute.atom.CopyOp,
45
+ gmem_tensor: cute.Tensor,
46
+ smem_layout: Union[cute.Layout, cute.ComposedLayout],
47
+ cta_tiler: Tuple[int, int],
48
+ qhead_per_kvhead: int,
49
+ head_idx: int,
50
+ ):
51
+ # This packing and unpacking of the layout is so that we keep the same TMA dimension as usual.
52
+ # e.g. for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to
53
+ # ((nheads, seqlen), d, b).
54
+ # If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA.
55
+ # Pack headdim and seqlen dim into 1: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b)
56
+ gmem_tensor = layout_utils.select(
57
+ gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))]
58
+ )
59
+ gmem_tensor = cute.group_modes(gmem_tensor, 0, 2)
60
+ assert cta_tiler[0] % qhead_per_kvhead == 0, (
61
+ "CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead"
62
+ )
63
+ tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
64
+ op,
65
+ gmem_tensor,
66
+ smem_layout,
67
+ ((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]), # No mcast
68
+ )
69
+ # Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b)
70
+ T = tma_tensor
71
+ shape_packed = (
72
+ (qhead_per_kvhead, T.shape[0][1]),
73
+ *[T.shape[i] for i in range(1, head_idx)],
74
+ T.shape[0][0] // qhead_per_kvhead,
75
+ *[T.shape[i] for i in range(head_idx, len(T.shape))],
76
+ )
77
+ stride_packed = (
78
+ *[T.stride[i] for i in range(head_idx)],
79
+ T.stride[0][0] * qhead_per_kvhead,
80
+ *[T.stride[i] for i in range(head_idx, len(T.shape))],
81
+ )
82
+ tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))
83
+ return tma_atom, tma_tensor
84
 
85
 
86
+ def unpack_gqa_layout(T, qhead_per_kvhead, head_idx):
87
+ """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0).
88
+
89
+ The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1)
90
+ are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept
91
+ as-is (e.g. batch).
92
+
93
+ For Q/O tensors (head_idx=2):
94
+ ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...)
95
+ For LSE tensors (head_idx=1):
96
+ ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)
97
+ """
98
+ seqlen_stride = T.stride[0][1]
99
+ head_stride = T.stride[0][0]
100
+ shape_unpacked = (
101
+ T.shape[0][1],
102
+ *[T.shape[i] for i in range(1, head_idx)],
103
+ T.shape[head_idx] * qhead_per_kvhead,
104
+ *[T.shape[i] for i in range(head_idx + 1, len(T.shape))],
105
+ )
106
+ stride_unpacked = (
107
+ seqlen_stride,
108
+ *[T.stride[i] for i in range(1, head_idx)],
109
+ head_stride,
110
+ *[T.stride[i] for i in range(head_idx + 1, len(T.shape))],
111
+ )
112
+ return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked))
113
+
114
+
115
+ @dataclass
116
  class PackGQA:
117
+ m_block_size: cutlass.Constexpr[int]
118
+ head_dim_padded: cutlass.Constexpr[int]
119
+ check_hdim_oob: cutlass.Constexpr[bool]
120
+ qhead_per_kvhead: cutlass.Constexpr[bool]
 
 
 
 
 
 
 
121
 
122
  @cute.jit
123
  def compute_ptr(
build/torch-cuda/paged_kv.py CHANGED
@@ -28,6 +28,9 @@ class PagedKVManager(ParamsBase):
28
  head_dim_padded: cutlass.Constexpr[Int32]
29
  head_dim_v_padded: cutlass.Constexpr[Int32]
30
 
 
 
 
31
  gmem_threads_per_row: cutlass.Constexpr[Int32]
32
  page_entry_per_thread: Int32
33
  async_copy_elems: Int32
@@ -55,7 +58,11 @@ class PagedKVManager(ParamsBase):
55
  head_dim_v_padded: cutlass.Constexpr[Int32],
56
  num_threads: cutlass.Constexpr[Int32],
57
  dtype: Type[cutlass.Numeric],
 
58
  ):
 
 
 
59
  universal_copy_bits = 128
60
  async_copy_elems = universal_copy_bits // dtype.width
61
  dtype_bytes = dtype.width // 8
@@ -97,7 +104,8 @@ class PagedKVManager(ParamsBase):
97
  else:
98
  cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
99
  tVcV = gmem_thr_copy_KV.partition_S(cV)
100
- tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0])
 
101
 
102
  return PagedKVManager(
103
  mPageTable,
@@ -111,6 +119,8 @@ class PagedKVManager(ParamsBase):
111
  num_threads,
112
  head_dim_padded,
113
  head_dim_v_padded,
 
 
114
  gmem_threads_per_row,
115
  page_entry_per_thread,
116
  async_copy_elems,
@@ -146,13 +156,17 @@ class PagedKVManager(ParamsBase):
146
  @cute.jit
147
  def compute_X_ptr(self, K_or_V: str):
148
  tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
 
 
 
 
149
  for i in cutlass.range(self.page_entry_per_thread, unroll=1):
150
  page = self.tPrPage[i]
151
  page_offset = self.tPrPageOffset[i]
152
- if const_expr(K_or_V == "K"):
153
- tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint()
154
  else:
155
- tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint()
156
  return tPrXPtr
157
 
158
  @cute.jit
@@ -161,18 +175,24 @@ class PagedKVManager(ParamsBase):
161
 
162
  tPrXPtr = self.compute_X_ptr(K_or_V)
163
 
164
- # Finesse sX layout to be (M, N).
165
- sX_pi = cute.make_tensor(
166
- sX.iterator,
167
- cute.make_layout(
168
- (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
169
- stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
170
- ),
171
- )
 
 
 
 
 
 
172
 
173
- if const_expr(K_or_V == "V"):
174
- # Need to transpose V
175
- sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))
176
 
177
  head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
178
  cX = cute.make_identity_tensor((self.n_block_size, head_dim))
 
28
  head_dim_padded: cutlass.Constexpr[Int32]
29
  head_dim_v_padded: cutlass.Constexpr[Int32]
30
 
31
+ arch: cutlass.Constexpr[Int32]
32
+ v_gmem_transposed: cutlass.Constexpr[bool]
33
+
34
  gmem_threads_per_row: cutlass.Constexpr[Int32]
35
  page_entry_per_thread: Int32
36
  async_copy_elems: Int32
 
58
  head_dim_v_padded: cutlass.Constexpr[Int32],
59
  num_threads: cutlass.Constexpr[Int32],
60
  dtype: Type[cutlass.Numeric],
61
+ arch: cutlass.Constexpr[int] = 100,
62
  ):
63
+ # SM100 transposes V in gmem to (dv, page_size, num_pages);
64
+ # SM90 keeps V as (page_size, dv, num_pages), same layout as K.
65
+ v_gmem_transposed = arch != 90
66
  universal_copy_bits = 128
67
  async_copy_elems = universal_copy_bits // dtype.width
68
  dtype_bytes = dtype.width // 8
 
104
  else:
105
  cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))
106
  tVcV = gmem_thr_copy_KV.partition_S(cV)
107
+ # When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K)
108
+ tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1])
109
 
110
  return PagedKVManager(
111
  mPageTable,
 
119
  num_threads,
120
  head_dim_padded,
121
  head_dim_v_padded,
122
+ arch,
123
+ v_gmem_transposed,
124
  gmem_threads_per_row,
125
  page_entry_per_thread,
126
  async_copy_elems,
 
156
  @cute.jit
157
  def compute_X_ptr(self, K_or_V: str):
158
  tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
159
+ mX = self.mK_paged if const_expr(K_or_V == "K") else self.mV_paged
160
+ # K is always (page_size, d, num_pages). V matches K when not transposed,
161
+ # but is (dv, page_size, num_pages) when transposed (SM100).
162
+ transposed = const_expr(K_or_V == "V" and self.v_gmem_transposed)
163
  for i in cutlass.range(self.page_entry_per_thread, unroll=1):
164
  page = self.tPrPage[i]
165
  page_offset = self.tPrPageOffset[i]
166
+ if const_expr(transposed):
167
+ tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint()
168
  else:
169
+ tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint()
170
  return tPrXPtr
171
 
172
  @cute.jit
 
175
 
176
  tPrXPtr = self.compute_X_ptr(K_or_V)
177
 
178
+ if const_expr(self.arch == 90):
179
+ # SM90: sX is already stage-sliced by caller (sK[None, None, stage]).
180
+ # Flatten hierarchical modes to get (n_block_size, head_dim).
181
+ sX_pi = cute.group_modes(sX, 0, 1)
182
+ # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA)
183
+ else:
184
+ # SM100: Finesse sX layout to be (M, N).
185
+ sX_pi = cute.make_tensor(
186
+ sX.iterator,
187
+ cute.make_layout(
188
+ (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),
189
+ stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),
190
+ ),
191
+ )
192
 
193
+ if const_expr(K_or_V == "V"):
194
+ # Transpose smem V to match transposed gmem layout
195
+ sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))
196
 
197
  head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded
198
  cX = cute.make_identity_tensor((self.n_block_size, head_dim))
build/torch-cuda/pipeline.py CHANGED
@@ -1,6 +1,5 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
- # import math
4
  from typing import Optional
5
  from dataclasses import dataclass
6
 
@@ -11,12 +10,31 @@ from cutlass.pipeline import PipelineState
11
  from cutlass.pipeline import PipelineUserType
12
  from cutlass.pipeline import NamedBarrier as NamedBarrierOg
13
  from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
 
14
  from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
15
  from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
16
  from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
17
  from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class PipelineStateSimple:
21
  """
22
  Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
@@ -25,9 +43,6 @@ class PipelineStateSimple:
25
  """
26
 
27
  def __init__(self, stages: int, phase_index: Int32):
28
- # assert stages < 2**16
29
- # self._log_stages = int(math.log2(stages))
30
- # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2."
31
  self._stages = stages
32
  self._phase_index = phase_index
33
 
@@ -36,13 +51,10 @@ class PipelineStateSimple:
36
 
37
  @property
38
  def stages(self) -> int:
39
- # return 1 << self._log_stages
40
  return self._stages
41
 
42
  @property
43
  def index(self) -> Int32:
44
- # return self._phase_index & 0xFFFF
45
- # return self._phase_index & ((1 << self._log_stages) - 1)
46
  if const_expr(self._stages == 1):
47
  return Int32(0)
48
  else:
@@ -50,11 +62,8 @@ class PipelineStateSimple:
50
 
51
  @property
52
  def phase(self) -> Int32:
53
- # return self._phase_index >> 16
54
  # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
55
  # take modulo 2. But in practice just passing the phase in without modulo works fine.
56
- # return (self._phase_index >> self._log_stages) % 2
57
- # return self._phase_index >> self._log_stages
58
  if const_expr(self._stages == 1):
59
  return self._phase_index
60
  else:
@@ -66,21 +75,6 @@ class PipelineStateSimple:
66
  else:
67
  self._phase_index += 1
68
 
69
- # def then_body(phase_index):
70
- # # XOR the phase bit and set the index to 0
71
- # return (phase_index & 0xFFFF0000) ^ (1 << 16)
72
-
73
- # def else_body(phase_index):
74
- # return phase_index
75
-
76
- # self._phase_index = if_generate(
77
- # (self._phase_index & 0xFFFF) == self.stages,
78
- # then_body,
79
- # else_body,
80
- # [self._phase_index],
81
- # [Int32],
82
- # )
83
-
84
  def __extract_mlir_values__(self):
85
  phase_index = self._phase_index
86
  return [phase_index.ir_value()]
@@ -94,7 +88,6 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
94
  Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
95
  """
96
  if type is PipelineUserType.Producer:
97
- # return PipelineStateSimple(stages, Int32(1 << 16))
98
  return PipelineStateSimple(stages, Int32(stages))
99
  elif type is PipelineUserType.Consumer:
100
  return PipelineStateSimple(stages, Int32(0))
@@ -102,14 +95,73 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
102
  assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  @dataclass(frozen=True)
106
  class NamedBarrier(NamedBarrierOg):
107
- @staticmethod
108
- def create(*args, **kwargs):
109
- obj = NamedBarrierOg.create(*args, **kwargs)
110
- # Can't assign to __class__ directly since the dataclass is frozen
111
- object.__setattr__(obj, "__class__", NamedBarrier)
112
- return obj
113
 
114
  @dsl_user_op
115
  def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
@@ -134,72 +186,121 @@ class NamedBarrier(NamedBarrierOg):
134
  )
135
 
136
 
 
 
 
 
 
 
137
  @dataclass(frozen=True)
138
- class PipelineAsync(PipelineAsyncOg):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @staticmethod
140
- def create(*args, **kwargs):
 
 
 
 
 
 
 
141
  obj = PipelineAsyncOg.create(*args, **kwargs)
142
- # Can't assign to __class__ directly since the dataclass is frozen
143
- # obj.__class__ = PipelineAsync
144
  object.__setattr__(obj, "__class__", PipelineAsync)
 
 
 
 
145
  return obj
146
 
147
  @dsl_user_op
148
- def producer_acquire_w_index_phase(
149
- self,
150
- index: Int32,
151
- phase: Int32,
152
- try_acquire_token: Optional[Boolean] = None,
153
- *,
154
- loc=None,
155
- ip=None,
156
- ):
157
- if_generate(
158
- try_acquire_token is None or try_acquire_token == 0,
159
- lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
160
- loc=loc,
161
- ip=ip,
162
  )
163
 
164
  @dsl_user_op
165
- def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
166
- self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)
167
-
168
- @dsl_user_op
169
- def consumer_wait_w_index_phase(
170
- self,
171
- index: Int32,
172
- phase: Int32,
173
- try_wait_token: Optional[Boolean] = None,
174
- *,
175
- loc=None,
176
- ip=None,
177
- ):
178
- if_generate(
179
- try_wait_token is None or try_wait_token == 0,
180
- lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
181
- loc=loc,
182
- ip=ip,
183
  )
184
 
185
- @dsl_user_op
186
- def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
187
- self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
 
 
188
 
189
 
190
  @dataclass(frozen=True)
191
- class PipelineTmaAsync(PipelineTmaAsyncOg):
192
- """
193
- Override producer_acquire to take in extra_tx_count parameter.
194
- """
195
 
196
  @staticmethod
197
- def create(*args, **kwargs):
198
- obj = PipelineTmaAsyncOg.create(*args, **kwargs)
199
- # Can't assign to __class__ directly since the dataclass is frozen
200
- object.__setattr__(obj, "__class__", PipelineTmaAsync)
 
 
 
 
 
 
201
  return obj
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  @dsl_user_op
204
  def producer_acquire(
205
  self,
@@ -226,19 +327,15 @@ class PipelineTmaAsync(PipelineTmaAsyncOg):
226
  self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
227
 
228
 
229
- @dataclass(frozen=True)
230
- class PipelineTmaUmma(PipelineTmaUmmaOg):
231
- """
232
- Override producer_acquire to take in extra_tx_count parameter.
233
- """
234
 
235
- @staticmethod
236
- def create(*args, **kwargs):
237
- obj = PipelineTmaUmmaOg.create(*args, **kwargs)
238
- # Can't assign to __class__ directly since the dataclass is frozen
239
- # obj.__class__ = PipelineTmaUmma
240
- object.__setattr__(obj, "__class__", PipelineTmaUmma)
241
- return obj
242
 
243
  @dsl_user_op
244
  def producer_acquire(
@@ -279,162 +376,27 @@ class PipelineTmaUmma(PipelineTmaUmmaOg):
279
  ip=ip,
280
  )
281
 
282
- @dsl_user_op
283
- def producer_acquire_w_index_phase(
284
- self,
285
- index: Int32,
286
- phase: Int32,
287
- try_acquire_token: Optional[Boolean] = None,
288
- *,
289
- loc=None,
290
- ip=None,
291
- ):
292
- """
293
- TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.
294
- """
295
- if_generate(
296
- try_acquire_token is None or try_acquire_token == 0,
297
- lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
298
- loc=loc,
299
- ip=ip,
300
- )
301
- if_generate(
302
- self.is_leader_cta,
303
- lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),
304
- loc=loc,
305
- ip=ip,
306
- )
307
 
308
- @dsl_user_op
309
- def consumer_wait_w_index_phase(
310
- self,
311
- index: Int32,
312
- phase: Int32,
313
- try_wait_token: Optional[Boolean] = None,
314
- *,
315
- loc=None,
316
- ip=None,
317
- ):
318
- if_generate(
319
- try_wait_token is None or try_wait_token == 0,
320
- lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
321
- loc=loc,
322
- ip=ip,
323
- )
324
 
325
- @dsl_user_op
326
- def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
327
- """
328
- UMMA consumer release buffer empty, cta_group needs to be provided.
329
- """
330
- self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
331
 
332
 
333
  @dataclass(frozen=True)
334
- class PipelineUmmaAsync(PipelineUmmaAsyncOg):
335
- @staticmethod
336
- def create(*args, **kwargs):
337
- obj = PipelineUmmaAsyncOg.create(*args, **kwargs)
338
- # Can't assign to __class__ directly since the dataclass is frozen
339
- object.__setattr__(obj, "__class__", PipelineUmmaAsync)
340
- return obj
341
 
342
- @dsl_user_op
343
- def producer_acquire_w_index_phase(
344
- self,
345
- index: Int32,
346
- phase: Int32,
347
- try_acquire_token: Optional[Boolean] = None,
348
- *,
349
- loc=None,
350
- ip=None,
351
- ):
352
- if_generate(
353
- try_acquire_token is None or try_acquire_token == 0,
354
- lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
355
- loc=loc,
356
- ip=ip,
357
- )
358
 
359
- @dsl_user_op
360
- def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
361
- """
362
- UMMA producer commit buffer full, cta_group needs to be provided.
363
- """
364
- self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)
365
 
366
- @dsl_user_op
367
- def consumer_wait_w_index_phase(
368
- self,
369
- index: Int32,
370
- phase: Int32,
371
- try_wait_token: Optional[Boolean] = None,
372
- *,
373
- loc=None,
374
- ip=None,
375
- ):
376
- if_generate(
377
- try_wait_token is None or try_wait_token == 0,
378
- lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
379
- loc=loc,
380
- ip=ip,
381
- )
382
 
383
- @dsl_user_op
384
- def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
385
- self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)
386
 
387
 
388
  @dataclass(frozen=True)
389
- class PipelineAsyncUmma(PipelineAsyncUmmaOg):
390
- @staticmethod
391
- def create(*args, **kwargs):
392
- obj = PipelineAsyncUmmaOg.create(*args, **kwargs)
393
- # Can't assign to __class__ directly since the dataclass is frozen
394
- object.__setattr__(obj, "__class__", PipelineAsyncUmma)
395
- return obj
396
 
397
- @dsl_user_op
398
- def producer_acquire_w_index_phase(
399
- self,
400
- index: Int32,
401
- phase: Int32,
402
- try_acquire_token: Optional[Boolean] = None,
403
- *,
404
- loc=None,
405
- ip=None,
406
- ):
407
- if_generate(
408
- try_acquire_token is None or try_acquire_token == 0,
409
- lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),
410
- loc=loc,
411
- ip=ip,
412
- )
413
-
414
- @dsl_user_op
415
- def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
416
- self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)
417
-
418
- @dsl_user_op
419
- def consumer_wait_w_index_phase(
420
- self,
421
- index: Int32,
422
- phase: Int32,
423
- try_wait_token: Optional[Boolean] = None,
424
- *,
425
- loc=None,
426
- ip=None,
427
- ):
428
- if_generate(
429
- try_wait_token is None or try_wait_token == 0,
430
- lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),
431
- loc=loc,
432
- ip=ip,
433
- )
434
 
435
- @dsl_user_op
436
- def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
437
- """
438
- UMMA consumer release buffer empty, cta_group needs to be provided.
439
- """
440
- self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
 
3
  from typing import Optional
4
  from dataclasses import dataclass
5
 
 
10
  from cutlass.pipeline import PipelineUserType
11
  from cutlass.pipeline import NamedBarrier as NamedBarrierOg
12
  from cutlass.pipeline import PipelineAsync as PipelineAsyncOg
13
+ from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg
14
  from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg
15
  from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg
16
  from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg
17
  from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg
18
 
19
 
20
+ def _override_create(parent_cls, child_cls):
21
+ """Create a static factory that constructs parent_cls then re-classes to child_cls."""
22
+
23
+ @staticmethod
24
+ def create(*args, **kwargs):
25
+ obj = parent_cls.create(*args, **kwargs)
26
+ # Can't assign to __class__ directly since the dataclass is frozen
27
+ object.__setattr__(obj, "__class__", child_cls)
28
+ return obj
29
+
30
+ return create
31
+
32
+
33
+ def _make_state(index: Int32, phase: Int32) -> PipelineState:
34
+ """Construct a PipelineState from index and phase (count/stages unused by callers)."""
35
+ return PipelineState(stages=0, count=Int32(0), index=index, phase=phase)
36
+
37
+
38
  class PipelineStateSimple:
39
  """
40
  Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
 
43
  """
44
 
45
  def __init__(self, stages: int, phase_index: Int32):
 
 
 
46
  self._stages = stages
47
  self._phase_index = phase_index
48
 
 
51
 
52
  @property
53
  def stages(self) -> int:
 
54
  return self._stages
55
 
56
  @property
57
  def index(self) -> Int32:
 
 
58
  if const_expr(self._stages == 1):
59
  return Int32(0)
60
  else:
 
62
 
63
  @property
64
  def phase(self) -> Int32:
 
65
  # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to
66
  # take modulo 2. But in practice just passing the phase in without modulo works fine.
 
 
67
  if const_expr(self._stages == 1):
68
  return self._phase_index
69
  else:
 
75
  else:
76
  self._phase_index += 1
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def __extract_mlir_values__(self):
79
  phase_index = self._phase_index
80
  return [phase_index.ir_value()]
 
88
  Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
89
  """
90
  if type is PipelineUserType.Producer:
 
91
  return PipelineStateSimple(stages, Int32(stages))
92
  elif type is PipelineUserType.Consumer:
93
  return PipelineStateSimple(stages, Int32(0))
 
95
  assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
96
 
97
 
98
+ # ── Shared helpers ───────────────────────────────────────────────────────────
99
+
100
+
101
+ def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip):
102
+ """Optionally wrap a parent pipeline method call in sync_warp + elect_one."""
103
+ if const_expr(elect_one):
104
+ if const_expr(syncwarp):
105
+ cute.arch.sync_warp()
106
+ with cute.arch.elect_one():
107
+ parent_method(self, state, loc=loc, ip=ip)
108
+ else:
109
+ parent_method(self, state, loc=loc, ip=ip)
110
+
111
+
112
+ # ── Mixin: _w_index / _w_index_phase variants that delegate to parent ───────
113
+ # Each parent class has PipelineState-based methods (producer_acquire, producer_commit,
114
+ # consumer_wait, consumer_release). The _w_index_phase variants just construct a
115
+ # PipelineState from (index, phase) and delegate.
116
+
117
+
118
+ class _PipelineIndexPhaseMixin:
119
+ """Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents."""
120
+
121
+ @dsl_user_op
122
+ def producer_acquire_w_index_phase(
123
+ self,
124
+ index: Int32,
125
+ phase: Int32,
126
+ try_acquire_token: Optional[Boolean] = None,
127
+ *,
128
+ loc=None,
129
+ ip=None,
130
+ ):
131
+ state = _make_state(index, phase)
132
+ # Call the parent's producer_acquire (which takes PipelineState)
133
+ self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip)
134
+
135
+ @dsl_user_op
136
+ def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):
137
+ state = _make_state(index, Int32(0))
138
+ self.producer_commit(state, loc=loc, ip=ip)
139
+
140
+ @dsl_user_op
141
+ def consumer_wait_w_index_phase(
142
+ self,
143
+ index: Int32,
144
+ phase: Int32,
145
+ try_wait_token: Optional[Boolean] = None,
146
+ *,
147
+ loc=None,
148
+ ip=None,
149
+ ):
150
+ state = _make_state(index, phase)
151
+ self.consumer_wait(state, try_wait_token, loc=loc, ip=ip)
152
+
153
+ @dsl_user_op
154
+ def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):
155
+ state = _make_state(index, Int32(0))
156
+ self.consumer_release(state, loc=loc, ip=ip)
157
+
158
+
159
+ # ── NamedBarrier ─────────────────────────────────────────────────────────────
160
+
161
+
162
  @dataclass(frozen=True)
163
  class NamedBarrier(NamedBarrierOg):
164
+ create = _override_create(NamedBarrierOg, None) # patched below
 
 
 
 
 
165
 
166
  @dsl_user_op
167
  def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:
 
186
  )
187
 
188
 
189
+ NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier)
190
+
191
+
192
+ # ── PipelineAsync ────────────────────────────────────────────────────────────
193
+
194
+
195
  @dataclass(frozen=True)
196
+ class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg):
197
+ """
198
+ PipelineAsync with optional elect_one for producer_commit and consumer_release.
199
+
200
+ When elect_one_*=True (set at create time), only one elected thread per warp
201
+ signals the barrier arrive. This is useful when the mask count is set to 1 per warp.
202
+
203
+ Args (to create):
204
+ elect_one_commit: If True, only elected thread signals producer_commit.
205
+ syncwarp_before_commit: If True (default), issue syncwarp before elect_one.
206
+ elect_one_release: If True, only elected thread signals consumer_release.
207
+ syncwarp_before_release: If True (default), issue syncwarp before elect_one.
208
+ Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group).
209
+ """
210
+
211
+ _elect_one_commit: bool = False
212
+ _syncwarp_before_commit: bool = True
213
+ _elect_one_release: bool = False
214
+ _syncwarp_before_release: bool = True
215
+
216
  @staticmethod
217
+ def create(
218
+ *args,
219
+ elect_one_commit: bool = False,
220
+ syncwarp_before_commit: bool = True,
221
+ elect_one_release: bool = False,
222
+ syncwarp_before_release: bool = True,
223
+ **kwargs,
224
+ ):
225
  obj = PipelineAsyncOg.create(*args, **kwargs)
 
 
226
  object.__setattr__(obj, "__class__", PipelineAsync)
227
+ object.__setattr__(obj, "_elect_one_commit", elect_one_commit)
228
+ object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit)
229
+ object.__setattr__(obj, "_elect_one_release", elect_one_release)
230
+ object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
231
  return obj
232
 
233
  @dsl_user_op
234
+ def producer_commit(self, state: PipelineState, *, loc=None, ip=None):
235
+ _call_with_elect_one(
236
+ PipelineAsyncOg.producer_commit,
237
+ self,
238
+ state,
239
+ self._elect_one_commit,
240
+ self._syncwarp_before_commit,
241
+ loc,
242
+ ip,
 
 
 
 
 
243
  )
244
 
245
  @dsl_user_op
246
+ def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
247
+ _call_with_elect_one(
248
+ PipelineAsyncOg.consumer_release,
249
+ self,
250
+ state,
251
+ self._elect_one_release,
252
+ self._syncwarp_before_release,
253
+ loc,
254
+ ip,
 
 
 
 
 
 
 
 
 
255
  )
256
 
257
+ # _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate
258
+ # to producer_commit / consumer_release above.
259
+
260
+
261
+ # ── PipelineCpAsync ──────────────────────────────────────────────────────────
262
 
263
 
264
  @dataclass(frozen=True)
265
+ class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg):
266
+ _elect_one_release: bool = False
267
+ _syncwarp_before_release: bool = True
 
268
 
269
  @staticmethod
270
+ def create(
271
+ *args,
272
+ elect_one_release: bool = False,
273
+ syncwarp_before_release: bool = True,
274
+ **kwargs,
275
+ ):
276
+ obj = PipelineCpAsyncOg.create(*args, **kwargs)
277
+ object.__setattr__(obj, "__class__", PipelineCpAsync)
278
+ object.__setattr__(obj, "_elect_one_release", elect_one_release)
279
+ object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release)
280
  return obj
281
 
282
+ @dsl_user_op
283
+ def consumer_release(self, state: PipelineState, *, loc=None, ip=None):
284
+ _call_with_elect_one(
285
+ PipelineCpAsyncOg.consumer_release,
286
+ self,
287
+ state,
288
+ self._elect_one_release,
289
+ self._syncwarp_before_release,
290
+ loc,
291
+ ip,
292
+ )
293
+
294
+ # _w_index variants inherited from _PipelineIndexPhaseMixin.
295
+
296
+
297
+ # ── PipelineTmaAsync ────────────────────────────────────────────────────────
298
+
299
+
300
+ @dataclass(frozen=True)
301
+ class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg):
302
+ """Override producer_acquire to take in extra_tx_count parameter."""
303
+
304
  @dsl_user_op
305
  def producer_acquire(
306
  self,
 
327
  self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)
328
 
329
 
330
+ PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync)
331
+
332
+
333
+ # ── PipelineTmaUmma ─────────────────────────────────────────────────────────
 
334
 
335
+
336
+ @dataclass(frozen=True)
337
+ class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg):
338
+ """Override producer_acquire to take in extra_tx_count parameter."""
 
 
 
339
 
340
  @dsl_user_op
341
  def producer_acquire(
 
376
  ip=ip,
377
  )
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
+ PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
+
383
+ # ── PipelineUmmaAsync ───────────────────────────────────────────────────────
 
 
 
 
384
 
385
 
386
  @dataclass(frozen=True)
387
+ class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg):
388
+ pass
 
 
 
 
 
389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
+ PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync)
 
 
 
 
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
+ # ── PipelineAsyncUmma ───────────────────────────────────────────────────────
 
 
395
 
396
 
397
  @dataclass(frozen=True)
398
+ class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg):
399
+ pass
 
 
 
 
 
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
+ PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma)
 
 
 
 
 
build/torch-cuda/quack/copy_utils.py CHANGED
@@ -15,6 +15,9 @@ from cutlass._mlir.dialects import llvm
15
  from cutlass._mlir import ir
16
  from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
17
 
 
 
 
18
 
19
  Sm100MmaPeerBitMask = 0xFEFFFFFF
20
 
@@ -41,6 +44,30 @@ def cvt_copy(
41
  cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @dsl_user_op
45
  def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
46
  dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
@@ -796,17 +823,17 @@ def gather_m_get_copy_fn(
796
  limit_m: Int32,
797
  limit_k: Int32,
798
  ) -> Callable:
799
- tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
800
- tAsA = thr_copy_A.partition_D(sA)
801
  # k-major
802
  assert tAsA.shape[2] == 1
803
  tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
804
 
805
- is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
806
  if const_expr(not is_even_m_smem):
807
- limit_m = min(limit_m, tile_shape_mk[0])
808
  elems_per_load = cute.size(tAsA.shape[0][0])
809
- cA = cute.make_identity_tensor(tile_shape_mk)
810
  tAcA = thr_copy_A.partition_S(cA)
811
  t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
812
  # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
@@ -828,13 +855,13 @@ def gather_m_get_copy_fn(
828
  else:
829
  m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
830
 
831
- mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
832
 
833
  def copy_fn(src_idx, dst_idx, pred: bool = False):
834
  tApA_k = None
835
  if const_expr(pred):
836
  tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
837
- limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
838
  for k in cutlass.range(cols_per_thread, unroll_full=True):
839
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
840
  mA_cur = mA_k[None, (None, src_idx)]
@@ -997,11 +1024,162 @@ def gather_m_get_tma_copy_fn(
997
  tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
998
 
999
  def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
 
1000
  col_idx = tile_K * src_idx
1001
  for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1002
  row_indices = [tSR_rAIdx[v, m] for v in range(4)]
1003
- smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
1004
  with cute.arch.elect_one():
1005
  tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
1006
 
1007
  return copy_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from cutlass._mlir import ir
16
  from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
17
 
18
+ from . import layout_utils
19
+ from .utils import make_vector
20
+
21
 
22
  Sm100MmaPeerBitMask = 0xFEFFFFFF
23
 
 
44
  cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
45
 
46
 
47
+ @dsl_user_op
48
+ def sr_cvt_copy(
49
+ tiled_copy: cute.TiledCopy,
50
+ src: cute.Tensor,
51
+ dst: cute.Tensor,
52
+ seed: Int32,
53
+ tidx: Int32,
54
+ *,
55
+ loc=None,
56
+ ip=None,
57
+ ) -> None:
58
+ """Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion."""
59
+ assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
60
+ from .rounding import convert_f32_to_bf16_sr
61
+ from cutlass.cute.tensor import TensorSSA
62
+
63
+ src_cvt = cute.make_rmem_tensor_like(src, dst.element_type)
64
+ src_vec = src.load()
65
+ raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip)
66
+ src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type))
67
+ src = src_cvt
68
+ cute.copy(tiled_copy, src, dst, loc=loc, ip=ip)
69
+
70
+
71
  @dsl_user_op
72
  def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
73
  dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip)
 
823
  limit_m: Int32,
824
  limit_k: Int32,
825
  ) -> Callable:
826
+ tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1])
827
+ tAsA = partition_D_position_independent(thr_copy_A, sA)
828
  # k-major
829
  assert tAsA.shape[2] == 1
830
  tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
831
 
832
+ is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0
833
  if const_expr(not is_even_m_smem):
834
+ limit_m = min(limit_m, tile_M)
835
  elems_per_load = cute.size(tAsA.shape[0][0])
836
+ cA = cute.make_identity_tensor((tile_M, tile_K))
837
  tAcA = thr_copy_A.partition_S(cA)
838
  t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
839
  # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
 
855
  else:
856
  m_idx[m] = 0 # It's ok to load row 0 in the case of OOB
857
 
858
+ mA_k = cute.logical_divide(mA, (None, tile_K))
859
 
860
  def copy_fn(src_idx, dst_idx, pred: bool = False):
861
  tApA_k = None
862
  if const_expr(pred):
863
  tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
864
+ limit_k_cur = limit_k - src_idx * tile_K
865
  for k in cutlass.range(cols_per_thread, unroll_full=True):
866
  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
867
  mA_cur = mA_k[None, (None, src_idx)]
 
1024
  tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
1025
 
1026
  def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
1027
+ tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
1028
  col_idx = tile_K * src_idx
1029
  for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1030
  row_indices = [tSR_rAIdx[v, m] for v in range(4)]
1031
+ smem_ptr = tSR_sA_cur[None, m, None].iterator
1032
  with cute.arch.elect_one():
1033
  tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
1034
 
1035
  return copy_fn
1036
+
1037
+
1038
+ @cute.jit
1039
+ def gather_k_get_tma_copy_fn(
1040
+ tma_atom: cute.CopyAtom,
1041
+ sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout
1042
+ sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem
1043
+ col_idx: Int32, # M offset in global tensor (contiguous dim for M-major)
1044
+ warp_idx: Int32,
1045
+ num_warps: int,
1046
+ num_cta: int = 1,
1047
+ ) -> Tuple[Callable, Callable]:
1048
+ """Build a copy function for TMA gather4 in K dimension (M-major A).
1049
+
1050
+ Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements.
1051
+ col_idx is the absolute M position in the global tensor.
1052
+ K indices come from sAIdx (prefetched to smem by the scheduler warp).
1053
+
1054
+ Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which:
1055
+ Issues gather4 calls with those K indices as row_indices
1056
+ """
1057
+ tile_K = cute.size(sAIdx, mode=[0])
1058
+ assert tile_K % 4 == 0
1059
+ cta_group = num_cta
1060
+
1061
+ # Tiled copy for loading K indices from smem to registers (4 per vector, across warps)
1062
+ copy_AIdx_s2r = cute.make_tiled_copy_tv(
1063
+ cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
1064
+ cute.make_layout(num_warps), # thr_layout
1065
+ cute.make_layout(4), # val_layout — 4 K indices per gather4
1066
+ )
1067
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
1068
+ warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
1069
+ tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4))
1070
+ # ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192))
1071
+ tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA))
1072
+ tma_desc_ptr = get_tma_desc_addr(tma_atom)
1073
+ tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
1074
+
1075
+ def prefetch_from_smem_fn(
1076
+ a_prefetch_pipeline,
1077
+ src_idx,
1078
+ dst_idx,
1079
+ a_prefetch_consumer_state,
1080
+ ) -> cute.Tensor:
1081
+ a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
1082
+ tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx])
1083
+ cute.arch.sync_warp()
1084
+ with cute.arch.elect_one():
1085
+ a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
1086
+ return tSR_rAIdx
1087
+
1088
+ def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer):
1089
+ # Issue gather4: col_idx = M position, row_indices = 4 K positions
1090
+ tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
1091
+ gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64
1092
+ for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1093
+ row_indices = [tSR_rAIdx[v, k] for v in range(4)]
1094
+ for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True):
1095
+ smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator
1096
+ with cute.arch.elect_one():
1097
+ tma_gather4_load_fn(
1098
+ smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices
1099
+ )
1100
+
1101
+ return copy_fn, prefetch_from_smem_fn
1102
+
1103
+
1104
+ # ---------------------------------------------------------------------------
1105
+ # Store helpers
1106
+ # ---------------------------------------------------------------------------
1107
+
1108
+
1109
+ @dsl_user_op
1110
+ @cute.jit
1111
+ def store(
1112
+ ptr: cute.Pointer,
1113
+ val,
1114
+ pred: Optional[Boolean] = None,
1115
+ cop: cutlass.Constexpr = None,
1116
+ *,
1117
+ loc=None,
1118
+ ip=None,
1119
+ ):
1120
+ """Store a scalar value via cute.arch.store.
1121
+
1122
+ ptr: cute.Pointer (any address space).
1123
+ val: DSL Numeric value.
1124
+ pred: None → unconditional. DSL Boolean → skipped when pred == 0.
1125
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1126
+ """
1127
+ if const_expr(pred is None):
1128
+ cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1129
+ else:
1130
+ if pred:
1131
+ cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip)
1132
+
1133
+
1134
+ @dsl_user_op
1135
+ @cute.jit
1136
+ def store_v2(
1137
+ ptr: cute.Pointer,
1138
+ v0,
1139
+ v1,
1140
+ pred: Optional[Boolean] = None,
1141
+ cop: cutlass.Constexpr = None,
1142
+ *,
1143
+ loc=None,
1144
+ ip=None,
1145
+ ):
1146
+ """Vectorized store of 2 elements via cute.arch.store.
1147
+
1148
+ Packs v0, v1 into an MLIR <2 x T> vector.
1149
+ ptr: cute.Pointer (any address space, must be aligned for vector width).
1150
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1151
+ """
1152
+ vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
1153
+ if const_expr(pred is None):
1154
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1155
+ else:
1156
+ if pred:
1157
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1158
+
1159
+
1160
+ @dsl_user_op
1161
+ @cute.jit
1162
+ def store_v4(
1163
+ ptr: cute.Pointer,
1164
+ v0,
1165
+ v1,
1166
+ v2,
1167
+ v3,
1168
+ pred: Optional[Boolean] = None,
1169
+ cop: cutlass.Constexpr = None,
1170
+ *,
1171
+ loc=None,
1172
+ ip=None,
1173
+ ):
1174
+ """Vectorized store of 4 elements via cute.arch.store.
1175
+
1176
+ Packs v0–v3 into an MLIR <4 x T> vector.
1177
+ ptr: cute.Pointer (any address space, must be aligned for vector width).
1178
+ cop: Cache operator — "wb" (default), "cg", "cs" (streaming), "wt".
1179
+ """
1180
+ vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
1181
+ if const_expr(pred is None):
1182
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
1183
+ else:
1184
+ if pred:
1185
+ cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip)
build/torch-cuda/quack/cute_dsl_utils.py CHANGED
@@ -4,6 +4,9 @@ from typing import Tuple, get_origin
4
  from functools import lru_cache
5
  from dataclasses import dataclass, fields
6
 
 
 
 
7
  import torch
8
 
9
  try:
@@ -14,7 +17,6 @@ except ImportError:
14
  import cutlass
15
  import cutlass.cute as cute
16
  from cutlass import Int32, Int64, Float16, BFloat16, Float32
17
- from cutlass.base_dsl.typing import JitArgument
18
  from cutlass.base_dsl.tvm_ffi_builder import spec
19
  from cutlass.cutlass_dsl import NumericMeta
20
 
@@ -65,8 +67,25 @@ def get_max_active_clusters(cluster_size):
65
  return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
66
 
67
 
 
 
 
 
 
 
 
 
 
68
  @lru_cache
69
  def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
 
 
 
 
 
 
 
 
70
  return torch.cuda.get_device_capability(device)
71
 
72
 
@@ -138,28 +157,3 @@ class ParamsBase:
138
  return values
139
 
140
  __new_from_mlir_values__ = _new_from_mlir_values
141
-
142
-
143
- @dataclass
144
- class ArgumentsBase(JitArgument):
145
- def __c_pointers__(self):
146
- _, non_constexpr_fields = _partition_fields(self)
147
- c_ptrs = []
148
- for obj in non_constexpr_fields.values():
149
- if hasattr(obj, "__c_pointers__"):
150
- c_ptrs.extend(obj.__c_pointers__())
151
- return c_ptrs
152
-
153
- def __get_mlir_types__(self):
154
- _, non_constexpr_fields = _partition_fields(self)
155
- types, self._values_pos = [], []
156
- for obj in non_constexpr_fields.values():
157
- if hasattr(obj, "__get_mlir_types__"):
158
- obj_types = obj.__get_mlir_types__()
159
- types.extend(obj_types)
160
- self._values_pos.append(len(obj_types))
161
- else:
162
- self._values_pos.append(0)
163
- return types
164
-
165
- __new_from_mlir_values__ = _new_from_mlir_values
 
4
  from functools import lru_cache
5
  from dataclasses import dataclass, fields
6
 
7
+ import os
8
+ import re
9
+
10
  import torch
11
 
12
  try:
 
17
  import cutlass
18
  import cutlass.cute as cute
19
  from cutlass import Int32, Int64, Float16, BFloat16, Float32
 
20
  from cutlass.base_dsl.tvm_ffi_builder import spec
21
  from cutlass.cutlass_dsl import NumericMeta
22
 
 
67
  return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
68
 
69
 
70
+ def _parse_arch_str(arch_str: str) -> Tuple[int, int]:
71
+ """Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple."""
72
+ match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE)
73
+ if not match:
74
+ raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')")
75
+ major, minor, _ = match.groups()
76
+ return int(major), int(minor)
77
+
78
+
79
  @lru_cache
80
  def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
81
+ """Return (major, minor) device capability.
82
+
83
+ Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation
84
+ without a GPU present.
85
+ """
86
+ arch_override = os.environ.get("QUACK_ARCH")
87
+ if arch_override is not None:
88
+ return _parse_arch_str(arch_override)
89
  return torch.cuda.get_device_capability(device)
90
 
91
 
 
157
  return values
158
 
159
  __new_from_mlir_values__ = _new_from_mlir_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/quack/layout_utils.py CHANGED
@@ -295,3 +295,37 @@ def mma_partition_A_vec(
295
  sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
296
  tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
297
  return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
296
  tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
297
  return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
298
+
299
+
300
+ def copy_partition_S_vec(
301
+ sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
302
+ ) -> cute.Tensor:
303
+ assert cute.rank(sVec) == 2
304
+ assert sVec.stride[0] == 1
305
+ stage = sVec.shape[1]
306
+ shape = (
307
+ (sVec.shape[0], expand_shape, stage)
308
+ if const_expr(is_colvec)
309
+ else (expand_shape, sVec.shape[0], stage)
310
+ )
311
+ stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
312
+ sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
313
+ tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr))
314
+ return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
315
+
316
+
317
+ def copy_partition_D_vec(
318
+ sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool
319
+ ) -> cute.Tensor:
320
+ assert cute.rank(sVec) == 2
321
+ assert sVec.stride[0] == 1
322
+ stage = sVec.shape[1]
323
+ shape = (
324
+ (sVec.shape[0], expand_shape, stage)
325
+ if const_expr(is_colvec)
326
+ else (expand_shape, sVec.shape[0], stage)
327
+ )
328
+ stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
329
+ sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
330
+ tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr))
331
+ return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
build/torch-cuda/quack/utils.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+
9
+ from cutlass import Float32, Int32, const_expr
10
+ from cutlass._mlir.dialects import arith as _arith
11
+ from cutlass._mlir.dialects import llvm, nvvm, vector
12
+ from cutlass.cutlass_dsl import T, dsl_user_op
13
+
14
+
15
+ @dsl_user_op
16
+ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
17
+ return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
18
+
19
+
20
+ @cute.jit
21
+ def load_scalar_or_pointer(x, dtype=Float32):
22
+ if const_expr(isinstance(x, cute.Pointer)):
23
+ return dtype(cute.make_tensor(x, cute.make_layout(1))[0])
24
+ else:
25
+ return x
26
+
27
+
28
+ @dsl_user_op
29
+ def set_block_rank(
30
+ smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
31
+ ) -> Int32:
32
+ """Map the given smem pointer to the address at another CTA rank in the cluster."""
33
+ smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
34
+ return Int32(
35
+ llvm.inline_asm(
36
+ T.i32(),
37
+ [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
38
+ "mapa.shared::cluster.u32 $0, $1, $2;",
39
+ "=r,r,r",
40
+ has_side_effects=False,
41
+ is_align_stack=False,
42
+ asm_dialect=llvm.AsmDialect.AD_ATT,
43
+ )
44
+ )
45
+
46
+
47
+ @dsl_user_op
48
+ def store_shared_remote(
49
+ val: float | Float32 | Int32 | cutlass.Int64,
50
+ smem_ptr: cute.Pointer,
51
+ mbar_ptr: cute.Pointer,
52
+ peer_cta_rank_in_cluster: cute.typing.Int,
53
+ *,
54
+ loc=None,
55
+ ip=None,
56
+ ) -> None:
57
+ remote_smem_ptr_i32 = set_block_rank(
58
+ smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
59
+ ).ir_value()
60
+ remote_mbar_ptr_i32 = set_block_rank(
61
+ mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
62
+ ).ir_value()
63
+ if const_expr(isinstance(val, float)):
64
+ val = Float32(val)
65
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
66
+ suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
67
+ constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
68
+ llvm.inline_asm(
69
+ None,
70
+ [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
71
+ f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
72
+ f"r,{constraint},r",
73
+ has_side_effects=True,
74
+ is_align_stack=False,
75
+ asm_dialect=llvm.AsmDialect.AD_ATT,
76
+ )
77
+
78
+
79
+ @dsl_user_op
80
+ def store_shared_remote_x4(
81
+ val0: Float32 | Int32,
82
+ val1: Float32 | Int32,
83
+ val2: Float32 | Int32,
84
+ val3: Float32 | Int32,
85
+ smem_ptr: cute.Pointer,
86
+ mbar_ptr: cute.Pointer,
87
+ peer_cta_rank_in_cluster: cute.typing.Int,
88
+ *,
89
+ loc=None,
90
+ ip=None,
91
+ ) -> None:
92
+ remote_smem_ptr_i32 = set_block_rank(
93
+ smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
94
+ ).ir_value()
95
+ remote_mbar_ptr_i32 = set_block_rank(
96
+ mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
97
+ ).ir_value()
98
+ assert isinstance(val0, (Float32, Int32)), "val must be Float32, or Int32"
99
+ dtype = Float32 if isinstance(val0, Float32) else Int32
100
+ suffix = {Float32: "f32", Int32: "s32"}[dtype]
101
+ constraint = {Float32: "f", Int32: "r"}[dtype]
102
+ llvm.inline_asm(
103
+ None,
104
+ [
105
+ remote_smem_ptr_i32,
106
+ remote_mbar_ptr_i32,
107
+ dtype(val0).ir_value(loc=loc, ip=ip),
108
+ dtype(val1).ir_value(loc=loc, ip=ip),
109
+ dtype(val2).ir_value(loc=loc, ip=ip),
110
+ dtype(val3).ir_value(loc=loc, ip=ip),
111
+ ],
112
+ "{\n\t"
113
+ f".reg .v4 .{suffix} abcd;\n\t"
114
+ f"mov.{suffix} abcd.x, $2;\n\t"
115
+ f"mov.{suffix} abcd.y, $3;\n\t"
116
+ f"mov.{suffix} abcd.z, $4;\n\t"
117
+ f"mov.{suffix} abcd.w, $5;\n\t"
118
+ f"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.{suffix} [$0], abcd, [$1];\n\t"
119
+ "}\n",
120
+ f"r,r,{constraint},{constraint},{constraint},{constraint}",
121
+ has_side_effects=True,
122
+ is_align_stack=False,
123
+ asm_dialect=llvm.AsmDialect.AD_ATT,
124
+ )
125
+
126
+
127
+ @dsl_user_op
128
+ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
129
+ if cutlass.const_expr(cutlass.CUDA_VERSION.major) == 12:
130
+ return Float32(
131
+ nvvm.fmin(
132
+ T.f32(),
133
+ Float32(a).ir_value(loc=loc, ip=ip),
134
+ Float32(b).ir_value(loc=loc, ip=ip),
135
+ loc=loc,
136
+ ip=ip,
137
+ )
138
+ )
139
+ return Float32(
140
+ nvvm.fmin(
141
+ Float32(a).ir_value(loc=loc, ip=ip),
142
+ Float32(b).ir_value(loc=loc, ip=ip),
143
+ loc=loc,
144
+ ip=ip,
145
+ )
146
+ )
147
+
148
+
149
+ @dsl_user_op
150
+ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
151
+ return Float32(
152
+ llvm.inline_asm(
153
+ T.f32(),
154
+ [Float32(a).ir_value(loc=loc, ip=ip)],
155
+ "sqrt.approx.f32 $0, $1;",
156
+ "=f,f",
157
+ has_side_effects=False,
158
+ is_align_stack=False,
159
+ asm_dialect=llvm.AsmDialect.AD_ATT,
160
+ )
161
+ )
162
+
163
+
164
+ @dsl_user_op
165
+ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
166
+ return Int32(
167
+ llvm.inline_asm(
168
+ T.i32(),
169
+ [Float32(a).ir_value(loc=loc, ip=ip)],
170
+ "cvt.rpi.ftz.s32.f32 $0, $1;",
171
+ "=r,f",
172
+ has_side_effects=False,
173
+ is_align_stack=False,
174
+ asm_dialect=llvm.AsmDialect.AD_ATT,
175
+ )
176
+ )
177
+
178
+
179
+ @cute.jit
180
+ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
181
+ """Fill out-of-bounds values in shared memory tensor.
182
+
183
+ Args:
184
+ tXsX: Shared memory tensor to fill
185
+ tXpX: Predicate tensor indicating valid elements
186
+ fill_value: Value to fill OOB locations with
187
+ """
188
+ tXrX_fill = cute.make_rmem_tensor_like(tXsX[(None, 0), None, 0])
189
+ tXrX_fill.fill(fill_value)
190
+ for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
191
+ for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
192
+ if const_expr(tXpX is not None):
193
+ if not tXpX[rest_v, 0, rest_k]:
194
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
195
+ else:
196
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # General-purpose DSL store / vector helpers
201
+ # ---------------------------------------------------------------------------
202
+
203
+
204
+ @dsl_user_op
205
+ def make_vector(elem_type, *values, loc=None, ip=None):
206
+ """Build an MLIR vector <N x elem_type> from N scalar DSL values.
207
+
208
+ Example: make_vector(cutlass.Uint32, v0, v1) -> <2 x i32> MLIR vector
209
+ """
210
+ from cutlass._mlir import ir
211
+
212
+ n = len(values)
213
+ mlir_ty = elem_type.mlir_type
214
+ vec_ty = ir.VectorType.get([n], mlir_ty)
215
+ vec = llvm.mlir_undef(vec_ty, loc=loc, ip=ip)
216
+ for i, v in enumerate(values):
217
+ vec = vector.insertelement(
218
+ elem_type(v).ir_value(loc=loc, ip=ip),
219
+ vec,
220
+ position=_arith.constant(T.i32(), i, loc=loc, ip=ip),
221
+ loc=loc,
222
+ ip=ip,
223
+ )
224
+ return vec
225
+
226
+
227
+ @dsl_user_op
228
+ def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
229
+ vec_f32x2 = vector.from_elements(
230
+ T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
231
+ )
232
+ vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
233
+ res = cutlass.Int64(
234
+ vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
235
+ )
236
+ return res
237
+
238
+
239
+ @dsl_user_op
240
+ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
241
+ vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
242
+ vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
243
+ res0 = Float32(
244
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
245
+ )
246
+ res1 = Float32(
247
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
248
+ )
249
+ return res0, res1
250
+
251
+
252
+ @cute.jit
253
+ def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
254
+ if const_expr(lane is None):
255
+ lane = cute.arch.lane_idx()
256
+ for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
257
+ offset = 1 << i
258
+ # Very important that we set mask_and_clamp to 0
259
+ partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
260
+ if lane >= offset:
261
+ val += partial_sum
262
+ return val
263
+
264
+
265
+ @dsl_user_op
266
+ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
267
+ from cutlass import CUDA_VERSION
268
+
269
+ # * NVVM call based on nvvm version
270
+ if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
271
+ # Old API: requires explicit result type as first positional argument
272
+ return nvvm.atomicrmw(
273
+ res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
274
+ )
275
+ else:
276
+ # New API: infers result type automatically
277
+ return nvvm.atomicrmw(
278
+ op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
279
+ )
280
+
281
+
282
+ @dsl_user_op
283
+ def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
284
+ from cutlass import CUDA_VERSION
285
+
286
+ # * NVVM call based on nvvm version
287
+ if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
288
+ # Old API: requires explicit result type as first positional argument
289
+ return nvvm.atomicrmw(
290
+ res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
291
+ )
292
+ else:
293
+ # New API: infers result type automatically
294
+ return nvvm.atomicrmw(
295
+ op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
296
+ )
297
+
298
+
299
+ @dsl_user_op
300
+ def issue_clc_query_nomulticast(
301
+ mbar_ptr: cute.Pointer,
302
+ clc_response_ptr: cute.Pointer,
303
+ loc=None,
304
+ ip=None,
305
+ ) -> None:
306
+ """
307
+ The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch
308
+ of a cluster that has not started running yet. It asynchronously writes an opaque response
309
+ to shared memory indicating whether the operation succeeded or failed. On success, the
310
+ opaque response contains the ctaid of the first CTA of the canceled cluster.
311
+
312
+ :param mbar_ptr: A pointer to the mbarrier address in SMEM
313
+ :type mbar_ptr: Pointer
314
+ :param clc_response_ptr: A pointer to the cluster launch control response address in SMEM
315
+ :type clc_response_ptr: Pointer
316
+ """
317
+ mbar_llvm_ptr = mbar_ptr.llvm_ptr
318
+ clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
319
+ nvvm.clusterlaunchcontrol_try_cancel(
320
+ clc_response_llvm_ptr,
321
+ mbar_llvm_ptr,
322
+ loc=loc,
323
+ ip=ip,
324
+ )
build/torch-cuda/seqlen_info.py CHANGED
@@ -5,6 +5,8 @@ import cutlass
5
  import cutlass.cute as cute
6
  from cutlass import Int32, const_expr
7
 
 
 
8
  """
9
  This consolidates all the info related to sequence length. This is so that we can do all
10
  the gmem reads once at the beginning of each tile, rather than having to repeat these reads
@@ -14,34 +16,61 @@ to compute various things like n_block_min, n_block_max, etc.
14
 
15
  @dataclass(frozen=True)
16
  class SeqlenInfo:
17
- offset: cutlass.Int32
18
- seqlen: cutlass.Int32
 
 
19
 
20
  @staticmethod
21
  def create(
22
- batch_idx: cutlass.Int32,
23
- seqlen_static: cutlass.Int32,
24
  cu_seqlens: Optional[cute.Tensor] = None,
25
  seqused: Optional[cute.Tensor] = None,
 
26
  ):
27
  offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
 
 
 
 
 
 
28
  if const_expr(seqused is not None):
29
  seqlen = seqused[batch_idx]
30
  elif const_expr(cu_seqlens is not None):
31
  seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
32
  else:
33
  seqlen = seqlen_static
34
- return SeqlenInfo(offset, seqlen)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  @dataclass(frozen=True)
38
  class SeqlenInfoQK:
39
- offset_q: cutlass.Int32
40
- offset_k: cutlass.Int32
41
- padded_offset_q: cutlass.Int32
42
- padded_offset_k: cutlass.Int32
43
- seqlen_q: cutlass.Int32
44
- seqlen_k: cutlass.Int32
45
  has_cu_seqlens_q: cutlass.Constexpr[bool]
46
  has_cu_seqlens_k: cutlass.Constexpr[bool]
47
  has_seqused_q: cutlass.Constexpr[bool]
@@ -49,27 +78,27 @@ class SeqlenInfoQK:
49
 
50
  @staticmethod
51
  def create(
52
- batch_idx: cutlass.Int32,
53
- seqlen_q_static: cutlass.Int32,
54
- seqlen_k_static: cutlass.Int32,
55
  mCuSeqlensQ: Optional[cute.Tensor] = None,
56
  mCuSeqlensK: Optional[cute.Tensor] = None,
57
  mSeqUsedQ: Optional[cute.Tensor] = None,
58
  mSeqUsedK: Optional[cute.Tensor] = None,
59
- tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
60
- tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
61
  ):
62
  offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
63
  offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
64
  padded_offset_q = (
65
  0
66
  if const_expr(mCuSeqlensQ is None)
67
- else (offset_q + batch_idx * tile_m) // tile_m * tile_m
68
  )
69
  padded_offset_k = (
70
  0
71
  if const_expr(mCuSeqlensK is None)
72
- else (offset_k + batch_idx * tile_n) // tile_n * tile_n
73
  )
74
  if const_expr(mSeqUsedQ is not None):
75
  seqlen_q = mSeqUsedQ[batch_idx]
@@ -87,10 +116,6 @@ class SeqlenInfoQK:
87
  if const_expr(mCuSeqlensK is None)
88
  else mCuSeqlensK[batch_idx + 1] - offset_k
89
  )
90
- has_cu_seqlens_q: int = mCuSeqlensQ is not None
91
- has_cu_seqlens_k: int = mCuSeqlensK is not None
92
- has_seqused_q: int = mSeqUsedQ is not None
93
- has_seqused_k: int = mSeqUsedK is not None
94
  return SeqlenInfoQK(
95
  offset_q,
96
  offset_k,
@@ -98,10 +123,10 @@ class SeqlenInfoQK:
98
  padded_offset_k,
99
  seqlen_q,
100
  seqlen_k,
101
- has_cu_seqlens_q,
102
- has_cu_seqlens_k,
103
- has_seqused_q,
104
- has_seqused_k,
105
  )
106
 
107
  def offset_batch_Q(
@@ -110,16 +135,38 @@ class SeqlenInfoQK:
110
  batch_idx: Int32,
111
  dim: int,
112
  padded: cutlass.Constexpr[bool] = False,
 
113
  ) -> cute.Tensor:
114
  """Seqlen must be the first dimension of mQ"""
115
- if const_expr(not self.has_cu_seqlens_q):
116
- idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
117
- return mQ[idx]
 
 
 
 
 
 
118
  else:
119
- offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
120
- offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
121
- idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
122
- return cute.domain_offset(idx, mQ)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def offset_batch_K(
125
  self,
@@ -127,12 +174,114 @@ class SeqlenInfoQK:
127
  batch_idx: Int32,
128
  dim: int,
129
  padded: cutlass.Constexpr[bool] = False,
 
 
130
  ) -> cute.Tensor:
131
  """Seqlen must be the first dimension of mK"""
132
- if const_expr(not self.has_cu_seqlens_k):
133
- idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
134
- return mK[idx]
 
 
 
 
 
 
135
  else:
136
- offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
137
- idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
138
- return cute.domain_offset(idx, mK)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import cutlass.cute as cute
6
  from cutlass import Int32, const_expr
7
 
8
+ from .quack import copy_utils
9
+
10
  """
11
  This consolidates all the info related to sequence length. This is so that we can do all
12
  the gmem reads once at the beginning of each tile, rather than having to repeat these reads
 
16
 
17
  @dataclass(frozen=True)
18
  class SeqlenInfo:
19
+ offset: Int32
20
+ offset_padded: Int32
21
+ seqlen: Int32
22
+ has_cu_seqlens: cutlass.Constexpr[bool] = False
23
 
24
  @staticmethod
25
  def create(
26
+ batch_idx: Int32,
27
+ seqlen_static: Int32,
28
  cu_seqlens: Optional[cute.Tensor] = None,
29
  seqused: Optional[cute.Tensor] = None,
30
+ tile: cutlass.Constexpr[int] = 128,
31
  ):
32
  offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]
33
+ offset_padded = (
34
+ 0
35
+ if const_expr(cu_seqlens is None)
36
+ # Add divby so that the compiler knows the alignment when moving by offset_padded
37
+ else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile)
38
+ )
39
  if const_expr(seqused is not None):
40
  seqlen = seqused[batch_idx]
41
  elif const_expr(cu_seqlens is not None):
42
  seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
43
  else:
44
  seqlen = seqlen_static
45
+ return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None)
46
+
47
+ def offset_batch(
48
+ self,
49
+ mT: cute.Tensor,
50
+ batch_idx: Int32,
51
+ dim: int,
52
+ padded: cutlass.Constexpr[bool] = False,
53
+ multiple: int = 1,
54
+ ) -> cute.Tensor:
55
+ """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0."""
56
+ if const_expr(not self.has_cu_seqlens):
57
+ idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim)
58
+ return mT[idx]
59
+ else:
60
+ off = multiple * (self.offset if const_expr(not padded) else self.offset_padded)
61
+ offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off)
62
+ idx = (offset,) + (None,) * (cute.rank(mT) - 1)
63
+ return cute.domain_offset(idx, mT)
64
 
65
 
66
  @dataclass(frozen=True)
67
  class SeqlenInfoQK:
68
+ offset_q: Int32
69
+ offset_k: Int32
70
+ padded_offset_q: Int32
71
+ padded_offset_k: Int32
72
+ seqlen_q: Int32
73
+ seqlen_k: Int32
74
  has_cu_seqlens_q: cutlass.Constexpr[bool]
75
  has_cu_seqlens_k: cutlass.Constexpr[bool]
76
  has_seqused_q: cutlass.Constexpr[bool]
 
78
 
79
  @staticmethod
80
  def create(
81
+ batch_idx: Int32,
82
+ seqlen_q_static: Int32,
83
+ seqlen_k_static: Int32,
84
  mCuSeqlensQ: Optional[cute.Tensor] = None,
85
  mCuSeqlensK: Optional[cute.Tensor] = None,
86
  mSeqUsedQ: Optional[cute.Tensor] = None,
87
  mSeqUsedK: Optional[cute.Tensor] = None,
88
+ tile_m: cutlass.Constexpr[Int32] = 128,
89
+ tile_n: cutlass.Constexpr[Int32] = 128,
90
  ):
91
  offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
92
  offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
93
  padded_offset_q = (
94
  0
95
  if const_expr(mCuSeqlensQ is None)
96
+ else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m)
97
  )
98
  padded_offset_k = (
99
  0
100
  if const_expr(mCuSeqlensK is None)
101
+ else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n)
102
  )
103
  if const_expr(mSeqUsedQ is not None):
104
  seqlen_q = mSeqUsedQ[batch_idx]
 
116
  if const_expr(mCuSeqlensK is None)
117
  else mCuSeqlensK[batch_idx + 1] - offset_k
118
  )
 
 
 
 
119
  return SeqlenInfoQK(
120
  offset_q,
121
  offset_k,
 
123
  padded_offset_k,
124
  seqlen_q,
125
  seqlen_k,
126
+ has_cu_seqlens_q=mCuSeqlensQ is not None,
127
+ has_cu_seqlens_k=mCuSeqlensK is not None,
128
+ has_seqused_q=mSeqUsedQ is not None,
129
+ has_seqused_k=mSeqUsedK is not None,
130
  )
131
 
132
  def offset_batch_Q(
 
135
  batch_idx: Int32,
136
  dim: int,
137
  padded: cutlass.Constexpr[bool] = False,
138
+ ragged: cutlass.Constexpr[bool] = False,
139
  ) -> cute.Tensor:
140
  """Seqlen must be the first dimension of mQ"""
141
+ if const_expr(not ragged):
142
+ if const_expr(not self.has_cu_seqlens_q):
143
+ idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
144
+ return mQ[idx]
145
+ else:
146
+ offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
147
+ offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q)
148
+ idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1)
149
+ return cute.domain_offset(idx, mQ)
150
  else:
151
+ if const_expr(not self.has_cu_seqlens_q):
152
+ offset_q = 0
153
+ idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
154
+ mQ = mQ[idx]
155
+ else:
156
+ offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
157
+ if const_expr(cute.rank(mQ.shape[0]) == 1):
158
+ return copy_utils.offset_ragged_tensor(
159
+ mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True
160
+ )
161
+ else: # PackGQA
162
+ assert cute.rank(mQ.shape[0]) == 2
163
+ # Unpack before calling offset_ragged_tensor, then pack
164
+ idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1)
165
+ mQ = mQ[idx]
166
+ mQ = copy_utils.offset_ragged_tensor(
167
+ mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True
168
+ )
169
+ return cute.group_modes(mQ, 0, 2)
170
 
171
  def offset_batch_K(
172
  self,
 
174
  batch_idx: Int32,
175
  dim: int,
176
  padded: cutlass.Constexpr[bool] = False,
177
+ ragged: cutlass.Constexpr[bool] = False,
178
+ multiple: int = 1,
179
  ) -> cute.Tensor:
180
  """Seqlen must be the first dimension of mK"""
181
+ if const_expr(not ragged):
182
+ if const_expr(not self.has_cu_seqlens_k):
183
+ idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
184
+ return mK[idx]
185
+ else:
186
+ offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
187
+ offset_k *= multiple
188
+ idx = (offset_k,) + (None,) * (cute.rank(mK) - 1)
189
+ return cute.domain_offset(idx, mK)
190
  else:
191
+ if const_expr(not self.has_cu_seqlens_k):
192
+ offset_k = 0
193
+ idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
194
+ mK = mK[idx]
195
+ else:
196
+ offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
197
+ offset_k *= multiple
198
+ return copy_utils.offset_ragged_tensor(
199
+ mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True
200
+ )
201
+
202
+
203
+ @dataclass(frozen=True)
204
+ class SeqlenInfoQKNewK:
205
+ """Sequence length info for append-KV with left-padding and new K support.
206
+
207
+ Extends SeqlenInfoQK with:
208
+ - leftpad_k: left padding for K (tokens to skip at the start of the KV cache)
209
+ - offset_k_new: offset into the new K tensor
210
+ - seqlen_k_og: original K length (before appending new K), excluding leftpad
211
+ - seqlen_k_new: length of new K to append
212
+ - seqlen_k: total K length (seqlen_k_og + seqlen_k_new)
213
+ - seqlen_rotary: position for rotary embedding computation
214
+ """
215
+
216
+ leftpad_k: Int32
217
+ offset_q: Int32
218
+ offset_k: Int32
219
+ offset_k_new: Int32
220
+ seqlen_q: Int32
221
+ seqlen_k_og: Int32
222
+ seqlen_k_new: Int32
223
+ seqlen_k: Int32
224
+ seqlen_rotary: Int32
225
+
226
+ @staticmethod
227
+ def create(
228
+ batch_idx: Int32,
229
+ seqlen_q_static: Int32,
230
+ seqlen_k_static: Int32,
231
+ shape_K_new_0: Int32,
232
+ mCuSeqlensQ: Optional[cute.Tensor] = None,
233
+ mCuSeqlensK: Optional[cute.Tensor] = None,
234
+ mCuSeqlensKNew: Optional[cute.Tensor] = None,
235
+ mSeqUsedQ: Optional[cute.Tensor] = None,
236
+ mSeqUsedK: Optional[cute.Tensor] = None,
237
+ mLeftpadK: Optional[cute.Tensor] = None,
238
+ mSeqlensRotary: Optional[cute.Tensor] = None,
239
+ ):
240
+ leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx]
241
+ offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
242
+ if const_expr(mCuSeqlensK is not None):
243
+ offset_k = mCuSeqlensK[batch_idx] + leftpad_k
244
+ else:
245
+ offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0
246
+ offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx]
247
+ # seqlen_q
248
+ if const_expr(mSeqUsedQ is not None):
249
+ seqlen_q = mSeqUsedQ[batch_idx]
250
+ elif const_expr(mCuSeqlensQ is not None):
251
+ seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx]
252
+ else:
253
+ seqlen_q = seqlen_q_static
254
+ # seqlen_k_og: original K length (excluding leftpad)
255
+ if const_expr(mSeqUsedK is not None):
256
+ seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k
257
+ elif const_expr(mCuSeqlensK is not None):
258
+ seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k
259
+ else:
260
+ seqlen_k_og = (
261
+ seqlen_k_static - leftpad_k
262
+ if const_expr(mCuSeqlensQ is not None)
263
+ else seqlen_k_static
264
+ )
265
+ # seqlen_k_new
266
+ if const_expr(mCuSeqlensKNew is None):
267
+ seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0
268
+ else:
269
+ seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx]
270
+ seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new
271
+
272
+ # seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided
273
+ if const_expr(mSeqlensRotary is not None):
274
+ seqlen_rotary = mSeqlensRotary[batch_idx]
275
+ else:
276
+ seqlen_rotary = seqlen_k_og + leftpad_k
277
+ return SeqlenInfoQKNewK(
278
+ leftpad_k,
279
+ offset_q,
280
+ offset_k,
281
+ offset_k_new,
282
+ seqlen_q,
283
+ seqlen_k_og,
284
+ seqlen_k_new,
285
+ seqlen_k,
286
+ seqlen_rotary,
287
+ )
build/torch-cuda/sm90_config_search.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v).
2
+
3
+ Enumerates tile sizes, swap modes, atom layouts, and staging options.
4
+ Checks GMMA divisibility, register budget, and shared memory budget.
5
+
6
+ Usage:
7
+ python flash_attn/cute/sm90_config_search.py --headdim 128
8
+ python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128
9
+ python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96
10
+ """
11
+
12
+ import math
13
+
14
+ # H100 hardware limits
15
+ SMEM_LIMIT = 224 * 1024 # 228 KB minus ~3 KB for LSE, dPsum, mbarriers
16
+ REG_LIMITS = {2: 216, 3: 128} # per-WG budget: 2WG=240-24, 3WG=160-32
17
+ THREADS_PER_WG = 128
18
+
19
+
20
+ def _divisors(n):
21
+ return [d for d in range(1, n + 1) if n % d == 0]
22
+
23
+
24
+ def _acc_regs(M, N, num_wg):
25
+ """Accumulator registers per thread per WG."""
26
+ return M * N // (num_wg * THREADS_PER_WG)
27
+
28
+
29
+ def _check_mma(M, N, num_wg, atom_layout_m, swap_AB):
30
+ """Check MMA feasibility. Returns regs per WG, or None if infeasible.
31
+
32
+ GMMA atom M=64. Swap exchanges (M, N) and atom layout.
33
+ Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8).
34
+ """
35
+ if swap_AB:
36
+ M, N = N, M
37
+ atom_layout_m = num_wg // atom_layout_m
38
+ atom_layout_n = num_wg // atom_layout_m
39
+ if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0:
40
+ return None
41
+ return _acc_regs(M, N, num_wg)
42
+
43
+
44
+ def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False):
45
+ """Total SMEM read traffic for one MMA (all WGs combined).
46
+
47
+ num_instr = (M_eff / 64) * wg_n instructions total.
48
+ Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16).
49
+ """
50
+ num_instr = (M_eff // 64) * wg_n
51
+ A_per = 64 * K_red * 2 if not is_rs else 0
52
+ B_per = (N_eff // wg_n) * K_red * 2
53
+ return num_instr * (A_per + B_per)
54
+
55
+
56
+ # ============================================================================
57
+ # Backward
58
+ # ============================================================================
59
+
60
+
61
+ def _check_bwd_config(
62
+ hdim,
63
+ hdimv,
64
+ tile_m,
65
+ tile_n,
66
+ num_wg,
67
+ SdP_swapAB,
68
+ dKV_swapAB,
69
+ dQ_swapAB,
70
+ AtomLayoutMSdP,
71
+ AtomLayoutNdKV,
72
+ AtomLayoutMdQ,
73
+ ):
74
+ reg_limit = REG_LIMITS[num_wg]
75
+
76
+ # MMA feasibility
77
+ regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB)
78
+ regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB)
79
+ regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB)
80
+ regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB)
81
+ if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)):
82
+ return None
83
+
84
+ # Peak regs: max(S+dP, dQ) + dK + dV
85
+ total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV
86
+ if total_regs > reg_limit:
87
+ return None
88
+
89
+ # SMEM
90
+ mma_dkv_is_rs = (
91
+ AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB
92
+ )
93
+ Q_stage, PdS_stage = 2, 1
94
+
95
+ for dO_stage in (2, 1):
96
+ sQ = tile_m * hdim * 2 * Q_stage
97
+ sK = tile_n * hdim * 2
98
+ sV = tile_n * hdimv * 2
99
+ sdO = tile_m * hdimv * 2 * dO_stage
100
+ sPdS = tile_m * tile_n * 2 * PdS_stage
101
+ sP = sPdS if not mma_dkv_is_rs else 0
102
+ sdQaccum = tile_m * hdim * 4
103
+ smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum
104
+ if smem <= SMEM_LIMIT:
105
+ break
106
+ else:
107
+ return None
108
+
109
+ # SMEM traffic
110
+ def _swap(a, b, s):
111
+ return (b, a) if s else (a, b)
112
+
113
+ def _wg_n(al_m, s):
114
+ return al_m if s else num_wg // al_m
115
+
116
+ M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB)
117
+ wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB)
118
+ traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP)
119
+ traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP)
120
+
121
+ wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB)
122
+ M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB)
123
+ traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)
124
+ M_dk, N_dk = _swap(tile_n, hdim, dKV_swapAB)
125
+ traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)
126
+
127
+ M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB)
128
+ wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB)
129
+ traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ)
130
+
131
+ traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0
132
+ traffic_dS_store = tile_m * tile_n * 2
133
+ traffic_dQ_smem = tile_m * hdim * 4 * 2 # store + TMA load
134
+
135
+ smem_traffic = (
136
+ traffic_S
137
+ + traffic_dP
138
+ + traffic_dV
139
+ + traffic_dK
140
+ + traffic_dQ
141
+ + traffic_P_store
142
+ + traffic_dS_store
143
+ + traffic_dQ_smem
144
+ )
145
+
146
+ return dict(
147
+ tile_m=tile_m,
148
+ tile_n=tile_n,
149
+ num_wg=num_wg,
150
+ Q_stage=Q_stage,
151
+ dO_stage=dO_stage,
152
+ PdS_stage=PdS_stage,
153
+ SdP_swapAB=SdP_swapAB,
154
+ dKV_swapAB=dKV_swapAB,
155
+ dQ_swapAB=dQ_swapAB,
156
+ AtomLayoutMSdP=AtomLayoutMSdP,
157
+ AtomLayoutNdKV=AtomLayoutNdKV,
158
+ AtomLayoutMdQ=AtomLayoutMdQ,
159
+ mma_dkv_is_rs=mma_dkv_is_rs,
160
+ regs_SdP=regs_SdP,
161
+ regs_dK=regs_dK,
162
+ regs_dV=regs_dV,
163
+ regs_dQ=regs_dQ,
164
+ total_regs=total_regs,
165
+ reg_limit=reg_limit,
166
+ smem_bytes=smem,
167
+ smem_kb=smem / 1024,
168
+ smem_traffic=smem_traffic,
169
+ smem_traffic_kb=smem_traffic / 1024,
170
+ smem_traffic_per_block=smem_traffic / (tile_m * tile_n),
171
+ )
172
+
173
+
174
+ def find_feasible_bwd_configs(
175
+ head_dim,
176
+ head_dim_v=None,
177
+ tile_m_choices=(64, 80, 96, 112, 128),
178
+ tile_n_choices=(64, 80, 96, 112, 128),
179
+ ):
180
+ if head_dim_v is None:
181
+ head_dim_v = head_dim
182
+ hdim = int(math.ceil(head_dim / 32) * 32)
183
+ hdimv = int(math.ceil(head_dim_v / 32) * 32)
184
+
185
+ results = []
186
+ for num_wg in (2, 3):
187
+ divs = _divisors(num_wg)
188
+ for tile_m in tile_m_choices:
189
+ for tile_n in tile_n_choices:
190
+ for SdP_swap in (False, True):
191
+ if (tile_n if SdP_swap else tile_m) % 64 != 0:
192
+ continue
193
+ for dKV_swap in (False, True):
194
+ if not dKV_swap and tile_n % 64 != 0:
195
+ continue
196
+ if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0):
197
+ continue
198
+ for dQ_swap in (False, True):
199
+ if (hdim if dQ_swap else tile_m) % 64 != 0:
200
+ continue
201
+ for a1 in divs:
202
+ for a2 in divs:
203
+ for a3 in divs:
204
+ cfg = _check_bwd_config(
205
+ hdim,
206
+ hdimv,
207
+ tile_m,
208
+ tile_n,
209
+ num_wg,
210
+ SdP_swap,
211
+ dKV_swap,
212
+ dQ_swap,
213
+ a1,
214
+ a2,
215
+ a3,
216
+ )
217
+ if cfg is not None:
218
+ results.append(cfg)
219
+
220
+ results.sort(key=lambda c: (-c["tile_n"], -c["tile_m"], c["smem_traffic_per_block"]))
221
+ return results
222
+
223
+
224
+ def print_bwd_configs(configs, max_results=20):
225
+ if not configs:
226
+ print("No feasible configs found!")
227
+ return
228
+ n = min(len(configs), max_results)
229
+ print(f"Found {len(configs)} feasible configs (showing top {n}):\n")
230
+ hdr = (
231
+ f"{'wg':>2} {'tm':>3} {'tn':>3} "
232
+ f"{'SdP':>3} {'dKV':>3} {'dQ':>3} "
233
+ f"{'aSdP':>4} {'adKV':>4} {'adQ':>4} "
234
+ f"{'Qs':>2} {'dOs':>3} "
235
+ f"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3} "
236
+ f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}"
237
+ )
238
+ print(hdr)
239
+ print("-" * len(hdr))
240
+ B = lambda b: "T" if b else "F"
241
+ for c in configs[:max_results]:
242
+ print(
243
+ f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} "
244
+ f"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3} "
245
+ f"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4} "
246
+ f"{c['Q_stage']:>2} {c['dO_stage']:>3} "
247
+ f"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} "
248
+ f"{c['total_regs']:>4}/{c['reg_limit']:<3} "
249
+ f"{c['smem_kb']:>4.0f}K "
250
+ f"{c['smem_traffic_kb']:>6.0f}K "
251
+ f"{c['smem_traffic_per_block']:>6.1f}"
252
+ )
253
+
254
+
255
+ # ============================================================================
256
+ # Forward
257
+ # ============================================================================
258
+
259
+
260
+ def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg):
261
+ reg_limit = REG_LIMITS[num_wg]
262
+ tile_m = num_wg * 64
263
+
264
+ if tile_n % 8 != 0:
265
+ return None
266
+
267
+ regs_S = _acc_regs(tile_m, tile_n, num_wg)
268
+ regs_O = _acc_regs(tile_m, hdimv, num_wg)
269
+ regs_P = regs_S // 2 # bf16 = half of f32
270
+
271
+ if overlap_wg:
272
+ total_regs = regs_S + regs_P + regs_O
273
+ else:
274
+ total_regs = regs_S + regs_O
275
+
276
+ if total_regs > reg_limit:
277
+ return None
278
+
279
+ # SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS
280
+ sQ = tile_m * hdim * 2
281
+ sK = tile_n * hdim * 2 * 2
282
+ sV = tile_n * hdimv * 2 * 2
283
+ sO = tile_m * hdimv * 2
284
+ sP = tile_m * tile_n * 2 if not pv_is_rs else 0
285
+ smem = max(sQ, sO) + sK + sV + sP
286
+ if smem > SMEM_LIMIT:
287
+ return None
288
+
289
+ # SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1)
290
+ traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2)
291
+ A_pv = 64 * tile_n * 2 if not pv_is_rs else 0
292
+ traffic_O = num_wg * (A_pv + hdimv * tile_n * 2)
293
+ traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0
294
+ smem_traffic = traffic_S + traffic_O + traffic_P_store
295
+
296
+ return dict(
297
+ tile_m=tile_m,
298
+ tile_n=tile_n,
299
+ num_wg=num_wg,
300
+ pv_is_rs=pv_is_rs,
301
+ overlap_wg=overlap_wg,
302
+ regs_S=regs_S,
303
+ regs_O=regs_O,
304
+ regs_P=regs_P,
305
+ total_regs=total_regs,
306
+ reg_limit=reg_limit,
307
+ smem_bytes=smem,
308
+ smem_kb=smem / 1024,
309
+ smem_traffic=smem_traffic,
310
+ smem_traffic_kb=smem_traffic / 1024,
311
+ smem_traffic_per_block=smem_traffic / (tile_m * tile_n),
312
+ )
313
+
314
+
315
+ def find_feasible_fwd_configs(
316
+ head_dim, head_dim_v=None, tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192)
317
+ ):
318
+ if head_dim_v is None:
319
+ head_dim_v = head_dim
320
+ hdim = int(math.ceil(head_dim / 32) * 32)
321
+ hdimv = int(math.ceil(head_dim_v / 32) * 32)
322
+
323
+ results = []
324
+ for num_wg in (2, 3):
325
+ for tile_n in tile_n_choices:
326
+ for pv_is_rs in (True, False):
327
+ for overlap_wg in (True, False):
328
+ cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg)
329
+ if cfg is not None:
330
+ results.append(cfg)
331
+
332
+ results.sort(key=lambda c: (-c["tile_n"], c["smem_traffic_per_block"]))
333
+ return results
334
+
335
+
336
+ def print_fwd_configs(configs, max_results=20):
337
+ if not configs:
338
+ print("No feasible configs found!")
339
+ return
340
+ n = min(len(configs), max_results)
341
+ print(f"Found {len(configs)} feasible configs (showing top {n}):\n")
342
+ hdr = (
343
+ f"{'wg':>2} {'tm':>3} {'tn':>3} "
344
+ f"{'RS':>2} {'olap':>4} "
345
+ f"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3} "
346
+ f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}"
347
+ )
348
+ print(hdr)
349
+ print("-" * len(hdr))
350
+ B = lambda b: "T" if b else "F"
351
+ for c in configs[:max_results]:
352
+ print(
353
+ f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} "
354
+ f"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4} "
355
+ f"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} "
356
+ f"{c['total_regs']:>4}/{c['reg_limit']:<3} "
357
+ f"{c['smem_kb']:>4.0f}K "
358
+ f"{c['smem_traffic_kb']:>6.0f}K "
359
+ f"{c['smem_traffic_per_block']:>6.1f}"
360
+ )
361
+
362
+
363
+ # ============================================================================
364
+ # CLI
365
+ # ============================================================================
366
+
367
+ if __name__ == "__main__":
368
+ import argparse
369
+
370
+ parser = argparse.ArgumentParser(description="Search feasible SM90 MMA configs")
371
+ parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both")
372
+ parser.add_argument(
373
+ "--headdim", type=str, default="128", help="Head dim, or hdim-hdimv (e.g. 192-128)"
374
+ )
375
+ parser.add_argument("--tile-m", type=str, default="64,80,96,112,128", help="Bwd tile_m choices")
376
+ parser.add_argument(
377
+ "--tile-n",
378
+ type=str,
379
+ default=None,
380
+ help="tile_n choices (default: fwd up to 192, bwd up to 128)",
381
+ )
382
+ parser.add_argument("-n", "--num-results", type=int, default=30)
383
+ args = parser.parse_args()
384
+
385
+ parts = args.headdim.split("-")
386
+ hdim = int(parts[0])
387
+ hdimv = int(parts[1]) if len(parts) > 1 else hdim
388
+
389
+ TN_FWD = "64,80,96,112,128,144,160,176,192"
390
+ TN_BWD = "64,80,96,112,128"
391
+
392
+ if args.mode in ("fwd", "both"):
393
+ tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(","))
394
+ print(f"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\n")
395
+ print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results)
396
+ print()
397
+
398
+ if args.mode in ("bwd", "both"):
399
+ tm = tuple(int(x) for x in args.tile_m.split(","))
400
+ tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(","))
401
+ print(f"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\n")
402
+ print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results)
build/torch-cuda/softmax.py CHANGED
@@ -10,7 +10,7 @@ import cutlass.cute as cute
10
  from cutlass import Float32
11
 
12
  from .quack import layout_utils
13
- from . import utils
14
  from .quack.cute_dsl_utils import ParamsBase
15
  from .seqlen_info import SeqlenInfoQK
16
 
 
10
  from cutlass import Float32
11
 
12
  from .quack import layout_utils
13
+ from . import utils as utils
14
  from .quack.cute_dsl_utils import ParamsBase
15
  from .seqlen_info import SeqlenInfoQK
16
 
build/torch-cuda/tile_scheduler.py CHANGED
@@ -1,6 +1,7 @@
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
- from typing import Optional, Tuple
 
4
  from dataclasses import dataclass
5
 
6
  try:
@@ -9,17 +10,80 @@ except ImportError: # Python < 3.12
9
  from typing_extensions import override
10
 
11
  import cutlass
 
12
  from cutlass._mlir import ir
13
  import cutlass.cute as cute
14
  from cutlass import Int32, const_expr
15
  from cutlass.cute import FastDivmodDivisor
 
16
 
17
  from .quack.cute_dsl_utils import ParamsBase
18
 
19
- from . import utils
20
  from .fast_math import clz
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class WorkTileInfo(cutlass.utils.WorkTileInfo):
24
  """Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
25
 
@@ -31,6 +95,47 @@ class WorkTileInfo(cutlass.utils.WorkTileInfo):
31
  return WorkTileInfo(new_tile_idx, new_is_valid_tile)
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @dataclass
35
  class TileSchedulerArguments(ParamsBase):
36
  num_block: Int32
@@ -51,6 +156,7 @@ class TileSchedulerArguments(ParamsBase):
51
  lpt: cutlass.Constexpr[bool] = False
52
  is_split_kv: cutlass.Constexpr[bool] = False
53
  head_swizzle: cutlass.Constexpr[bool] = False
 
54
 
55
 
56
  class SingleTileScheduler:
@@ -63,6 +169,7 @@ class SingleTileScheduler:
63
  num_splits_divmod: FastDivmodDivisor
64
  is_split_kv: cutlass.Constexpr[bool] = False
65
  cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
 
66
 
67
  @staticmethod
68
  def create(
@@ -76,6 +183,7 @@ class SingleTileScheduler:
76
  FastDivmodDivisor(args.num_splits),
77
  args.is_split_kv,
78
  args.cluster_shape_mn,
 
79
  )
80
 
81
  def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
@@ -86,18 +194,26 @@ class SingleTileScheduler:
86
  self._ip = ip
87
 
88
  @staticmethod
89
- def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
 
 
 
 
 
 
 
 
 
90
  return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
91
 
92
  @staticmethod
93
- def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
94
- # if const_expr(cute.size(params.cluster_shape_mn) == 1):
95
- # blk_coord = cute.arch.block_idx()
96
- # else:
97
- # # All CTAs in a cluster must get the same block coordinate
98
- # blk_coord = cute.arch.cluster_idx()
99
- # Temporary set to block_idx until we sort out the best way to handle cluster
100
- blk_coord = cute.arch.block_idx()
101
  return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
102
 
103
  # called by host
@@ -110,8 +226,13 @@ class SingleTileScheduler:
110
  ) -> Tuple[Int32, Int32, Int32]:
111
  # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
112
  assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
 
 
 
 
 
113
  return (
114
- cute.round_up(params.num_block, params.cluster_shape_mn[0]),
115
  params.num_head * params.num_splits,
116
  params.num_batch,
117
  )
@@ -135,6 +256,10 @@ class SingleTileScheduler:
135
 
136
  def advance_to_next_work(self, *, loc=None, ip=None):
137
  self._is_first_block = False
 
 
 
 
138
 
139
  def __extract_mlir_values__(self):
140
  values, self._values_pos = [], []
@@ -180,18 +305,28 @@ class StaticPersistentTileScheduler:
180
  self._ip = ip
181
 
182
  @staticmethod
183
- def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
 
 
 
 
 
 
 
 
 
184
  return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
185
 
186
  @staticmethod
187
- def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
 
 
188
  if const_expr(cute.size(params.cluster_shape_m) == 1):
189
  tile_idx = cute.arch.block_idx()[0]
190
  else:
191
  tile_idx = cute.arch.cluster_idx()[0]
192
  return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
193
 
194
- # called by host
195
  @staticmethod
196
  def get_grid_shape(
197
  params: Params,
@@ -201,18 +336,14 @@ class StaticPersistentTileScheduler:
201
  ) -> Tuple[Int32, Int32, Int32]:
202
  hardware_info = cutlass.utils.HardwareInfo()
203
  sm_count = hardware_info.get_device_multiprocessor_count()
204
- # Grid must be a multiple of cluster_shape_m for CUDA cluster launch.
205
  max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
206
  grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
207
  return (grid_x, Int32(1), Int32(1))
208
 
209
- # @cute.jit
210
  def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
211
  hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
212
  batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
213
  is_valid = self._tile_idx < self.params.total_blocks_cluster
214
- # if cute.arch.thread_idx()[0] == 0:
215
- # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
216
  return WorkTileInfo(
217
  (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
218
  )
@@ -228,6 +359,10 @@ class StaticPersistentTileScheduler:
228
  self._tile_idx += cute.arch.grid_dim()[0]
229
  else:
230
  self._tile_idx += cute.arch.cluster_dim()[0]
 
 
 
 
231
 
232
  def __extract_mlir_values__(self):
233
  values, self._values_pos = [], []
@@ -254,32 +389,41 @@ class SingleTileLPTScheduler:
254
  total_blocks: Int32
255
  num_splits: Int32
256
  num_block: Int32
 
 
257
  l2_minor: Int32
258
- num_block_divmod: FastDivmodDivisor
259
  num_head_divmod: FastDivmodDivisor
260
  l2_minor_divmod: FastDivmodDivisor
261
  l2_major_divmod: FastDivmodDivisor
262
  l2_minor_residual_divmod: FastDivmodDivisor
263
  num_hb_quotient: Int32
 
264
  is_split_kv: cutlass.Constexpr[bool] = False
 
 
 
265
 
266
  @staticmethod
267
  @cute.jit
268
  def create(
269
- args: TileSchedulerArguments, *, loc=None, ip=None
 
 
 
 
270
  ) -> "SingleTileLPTScheduler.Params":
271
- # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
 
 
272
  size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
273
  size_one_head = size_one_kv_head
274
  size_l2 = 50 * 1024 * 1024 # 40 MB for K & V
275
  # Swizzle is the size of each "section". Round swizzle to a power of 2
276
  # Need to be careful about the case where only one head will fit
277
  # swizzle is how many heads can fit in L2
278
- # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
279
- # Seems faster if swizzle if a power of 2
280
  log2_floor = lambda n: 31 - clz(n)
281
  swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
282
- # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
283
  # If we're in the last section (called residual), we don't want to divide by
284
  # swizzle. Instead we want to divide by the remainder.
285
  num_hb_quotient = (args.num_head * args.num_batch) // swizzle
@@ -287,37 +431,84 @@ class SingleTileLPTScheduler:
287
  return SingleTileLPTScheduler.Params(
288
  total_blocks=args.num_block * args.num_head * args.num_batch,
289
  num_block=args.num_block,
 
 
290
  l2_minor=Int32(swizzle),
291
- num_block_divmod=FastDivmodDivisor(args.num_block),
292
  num_head_divmod=FastDivmodDivisor(args.num_head),
293
  l2_minor_divmod=FastDivmodDivisor(swizzle),
294
  l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
295
- l2_minor_residual_divmod=FastDivmodDivisor(
296
- max(num_hb_remainder, 1)
297
- ), # don't divide by 0
298
  num_hb_quotient=Int32(num_hb_quotient),
299
  num_splits=args.num_splits,
 
300
  is_split_kv=args.is_split_kv,
 
 
 
301
  )
302
 
303
- def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
 
 
 
 
 
 
 
 
 
304
  self.params = params
305
  self._tile_idx = tile_idx
306
  self._split_idx = split_idx
 
307
  self._loc = loc
308
  self._ip = ip
309
 
310
  @staticmethod
311
- def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
312
- return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  @staticmethod
315
  @cute.jit
316
- def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  tile_idx, split_idx, _ = cute.arch.block_idx()
318
  return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
319
 
320
- # called by host
321
  @staticmethod
322
  def get_grid_shape(
323
  params: Params,
@@ -325,10 +516,40 @@ class SingleTileLPTScheduler:
325
  loc=None,
326
  ip=None,
327
  ) -> Tuple[Int32, Int32, Int32]:
 
 
328
  return (params.total_blocks, params.num_splits, Int32(1))
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  @cute.jit
331
  def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
 
 
 
 
 
332
  params = self.params
333
  # Implement LPT scheduling coordinate calculation
334
  bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
@@ -342,25 +563,45 @@ class SingleTileLPTScheduler:
342
  bidhb_actual = bidhb * params.l2_minor + bidhb_residual
343
  batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
344
  # Longest-processing-time-first
345
- block = params.num_block - 1 - block
 
346
  is_valid = self._tile_idx < params.total_blocks
347
  return WorkTileInfo(
348
  (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
349
  )
350
 
 
351
  def initial_work_tile_info(self, *, loc=None, ip=None):
 
 
 
 
352
  return self.get_current_work(loc=loc, ip=ip)
353
 
354
  def prefetch_next_work(self, *, loc=None, ip=None):
355
- pass
 
356
 
357
  def advance_to_next_work(self, *, loc=None, ip=None):
 
 
 
 
 
358
  # Single tile scheduler - set to invalid tile_idx to indicate no more work
359
  self._tile_idx = self.params.total_blocks
 
 
 
 
 
360
 
361
  def __extract_mlir_values__(self):
362
  values, self._values_pos = [], []
363
- for obj in [self.params, self._tile_idx, self._split_idx]:
 
 
 
364
  obj_values = cutlass.extract_mlir_values(obj)
365
  values += obj_values
366
  self._values_pos.append(len(obj_values))
@@ -368,10 +609,13 @@ class SingleTileLPTScheduler:
368
 
369
  def __new_from_mlir_values__(self, values):
370
  obj_list = []
371
- for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
 
 
 
372
  obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
373
  values = values[n_items:]
374
- return self.__class__(*(tuple(obj_list)), loc=self._loc)
375
 
376
 
377
  class SingleTileLPTBwdScheduler:
@@ -395,8 +639,8 @@ class SingleTileLPTBwdScheduler:
395
  ) -> "SingleTileLPTBwdScheduler.Params":
396
  size_l2 = 50 * 1024 * 1024
397
  size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
398
- # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
399
- size_one_dqaccum_head = 0
400
  size_one_head = size_one_qdo_head + size_one_dqaccum_head
401
  log2_floor = lambda n: 31 - clz(n)
402
  swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
@@ -430,7 +674,16 @@ class SingleTileLPTBwdScheduler:
430
  self._ip = ip
431
 
432
  @staticmethod
433
- def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
 
 
 
 
 
 
 
 
 
434
  return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)
435
 
436
  @staticmethod
@@ -481,6 +734,7 @@ class SingleTileLPTBwdScheduler:
481
  def advance_to_next_work(self, *, loc=None, ip=None):
482
  # Single tile scheduler - set to invalid tile_idx to indicate no more work
483
  self._tile_idx = self.params.total_blocks
 
484
 
485
  def __extract_mlir_values__(self):
486
  values, self._values_pos = [], []
@@ -514,20 +768,38 @@ class SingleTileVarlenScheduler:
514
  is_split_kv: cutlass.Constexpr[bool] = False
515
  head_swizzle: cutlass.Constexpr[bool] = False
516
  cluster_shape_m: cutlass.Constexpr[int] = 1
 
517
 
518
  @staticmethod
519
  @cute.jit
520
  def create(
521
- args: TileSchedulerArguments, *, loc=None, ip=None
 
 
 
 
522
  ) -> "SingleTileVarlenScheduler.Params":
 
 
 
523
  size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
524
- max_kvblock_in_l2 = size_l2 // (
 
525
  (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
526
  )
 
 
 
 
527
  assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
528
  "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
529
  )
530
  assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
 
 
 
 
 
531
  return SingleTileVarlenScheduler.Params(
532
  num_head=args.num_head,
533
  num_batch=args.num_batch,
@@ -542,22 +814,65 @@ class SingleTileVarlenScheduler:
542
  is_split_kv=args.is_split_kv,
543
  head_swizzle=args.head_swizzle,
544
  cluster_shape_m=args.cluster_shape_mn[0],
 
545
  )
546
 
547
- def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
 
 
 
 
 
 
 
 
 
548
  self.params = params
549
  self._tile_idx = tile_idx
550
  self._split_idx = split_idx
551
  self._is_first_block = True
 
552
  self._loc = loc
553
  self._ip = ip
554
 
555
  @staticmethod
556
- def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
557
- return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)
 
 
 
 
 
 
 
 
558
 
559
  @staticmethod
560
- def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  tile_idx, split_idx, _ = cute.arch.block_idx()
562
  return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
563
 
@@ -573,7 +888,7 @@ class SingleTileVarlenScheduler:
573
  params.total_q
574
  + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
575
  ) // params.tile_shape_mn[0]
576
- # round down to nearest multiple of cluster since odd excess is always padding
577
  total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
578
  return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
579
 
@@ -601,7 +916,8 @@ class SingleTileVarlenScheduler:
601
  )
602
 
603
  @cute.jit
604
- def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
 
605
  params = self.params
606
  lane_idx = cute.arch.lane_idx()
607
  num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
@@ -654,6 +970,7 @@ class SingleTileVarlenScheduler:
654
  num_n_blocks = (
655
  num_m_blocks
656
  * params.tile_shape_mn[0]
 
657
  // params.qhead_per_kvhead_packgqa
658
  // params.tile_shape_mn[1]
659
  )
@@ -698,19 +1015,62 @@ class SingleTileVarlenScheduler:
698
  split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
699
  return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
  def initial_work_tile_info(self, *, loc=None, ip=None):
702
- return self.get_current_work(loc=loc, ip=ip)
 
 
 
 
 
 
 
 
 
 
 
703
 
704
  def prefetch_next_work(self, *, loc=None, ip=None):
705
- pass
 
706
 
707
  def advance_to_next_work(self, *, loc=None, ip=None):
708
- # Single tile scheduler - set to invalid tile_idx to indicate no more work
 
 
 
 
709
  self._is_first_block = False
 
 
 
 
 
710
 
711
  def __extract_mlir_values__(self):
712
  values, self._values_pos = [], []
713
- for obj in [self.params, self._tile_idx, self._split_idx]:
 
 
 
714
  obj_values = cutlass.extract_mlir_values(obj)
715
  values += obj_values
716
  self._values_pos.append(len(obj_values))
@@ -718,10 +1078,10 @@ class SingleTileVarlenScheduler:
718
 
719
  def __new_from_mlir_values__(self, values):
720
  obj_list = []
721
- for obj, n_items in zip(
722
- [self.params, self._tile_idx, self._split_idx],
723
- self._values_pos,
724
- ):
725
  obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
726
  values = values[n_items:]
727
- return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)
 
1
  # Copyright (c) 2025, Tri Dao.
2
 
3
+ from enum import IntEnum, auto
4
+ from typing import Optional, Tuple, Protocol, runtime_checkable
5
  from dataclasses import dataclass
6
 
7
  try:
 
10
  from typing_extensions import override
11
 
12
  import cutlass
13
+ from cutlass.pipeline import PipelineClcFetchAsync, PipelineState
14
  from cutlass._mlir import ir
15
  import cutlass.cute as cute
16
  from cutlass import Int32, const_expr
17
  from cutlass.cute import FastDivmodDivisor
18
+ from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams
19
 
20
  from .quack.cute_dsl_utils import ParamsBase
21
 
22
+ from . import utils as utils
23
  from .fast_math import clz
24
 
25
 
26
+ class SchedulingMode(IntEnum):
27
+ NONE = auto()
28
+ STATIC = auto()
29
+ DYNAMIC = auto()
30
+ CLC = auto()
31
+
32
+
33
+ @dataclass
34
+ class ClcState(ParamsBase):
35
+ """Owns the runtime state shared by CLC-capable tile schedulers.
36
+
37
+ `FlashAttentionForwardSm100` constructs this state because it owns the CLC
38
+ response buffer, mbarrier storage, and launch geometry needed to initialize
39
+ the hardware scheduler and async pipeline. Individual tile schedulers then
40
+ consume this state and map the returned hardware work tiles into their own
41
+ logical `WorkTileInfo` coordinates.
42
+
43
+ To add CLC support to a scheduler:
44
+ - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler
45
+ - accept `clc: ClcState | None` in `create(...)` / `__init__`
46
+ - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates
47
+ """
48
+
49
+ _hw_scheduler: ClcDynamicPersistentTileScheduler
50
+ _pipeline: PipelineClcFetchAsync
51
+ _consumer_state: PipelineState
52
+ _producer_state: PipelineState
53
+
54
+ @staticmethod
55
+ def create(
56
+ *,
57
+ hw_scheduler: ClcDynamicPersistentTileScheduler,
58
+ pipeline: PipelineClcFetchAsync,
59
+ consumer_state: PipelineState,
60
+ producer_state: PipelineState,
61
+ ) -> "ClcState":
62
+ return ClcState(hw_scheduler, pipeline, consumer_state, producer_state)
63
+
64
+ def initial_work_tile_info(self):
65
+ return self._hw_scheduler.initial_work_tile_info()
66
+
67
+ def get_current_work(self):
68
+ return self._hw_scheduler.get_current_work()
69
+
70
+ def prefetch_next_work(self, *, loc=None, ip=None):
71
+ self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip)
72
+ mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip)
73
+ self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip)
74
+ self._producer_state.advance(loc=loc, ip=ip)
75
+
76
+ def consumer_wait(self, *, loc=None, ip=None):
77
+ self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip)
78
+
79
+ def consumer_release(self, *, loc=None, ip=None):
80
+ self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip)
81
+ self._consumer_state.advance(loc=loc, ip=ip)
82
+
83
+ def producer_tail(self, *, loc=None, ip=None):
84
+ self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip)
85
+
86
+
87
  class WorkTileInfo(cutlass.utils.WorkTileInfo):
88
  """Altered WorkTileInfo which includes four axes: (block, head, batch, split)"""
89
 
 
95
  return WorkTileInfo(new_tile_idx, new_is_valid_tile)
96
 
97
 
98
+ @runtime_checkable
99
+ class TileSchedulerProtocol(Protocol):
100
+ """Protocol defining the interface all tile schedulers must implement.
101
+
102
+ Schedulers are responsible for:
103
+ 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split)
104
+ 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic)
105
+ """
106
+
107
+ def get_current_work(self) -> WorkTileInfo:
108
+ """Get the current work tile coordinates."""
109
+ ...
110
+
111
+ def initial_work_tile_info(self) -> WorkTileInfo:
112
+ """Get the initial work tile for this CTA."""
113
+ ...
114
+
115
+ def advance_to_next_work(self, *, loc=None, ip=None):
116
+ """Consumer-side advance: move to next tile and return it.
117
+
118
+ For static schedulers: grid-stride increment + get_current_work.
119
+ For CLC schedulers: consumer wait + get_current_work + consumer release + state advance.
120
+ """
121
+ ...
122
+
123
+ def prefetch_next_work(self, *, loc=None, ip=None) -> None:
124
+ """Producer-side prefetch of next work tile (no-op for static schedulers).
125
+
126
+ For CLC schedulers: producer acquire + issue CLC query + producer state advance.
127
+ Only called by the scheduler warp.
128
+ """
129
+ ...
130
+
131
+ def producer_tail(self, *, loc=None, ip=None) -> None:
132
+ """Producer-side cleanup after the last tile.
133
+
134
+ No-op for static schedulers. For CLC schedulers: pipeline producer_tail.
135
+ """
136
+ ...
137
+
138
+
139
  @dataclass
140
  class TileSchedulerArguments(ParamsBase):
141
  num_block: Int32
 
156
  lpt: cutlass.Constexpr[bool] = False
157
  is_split_kv: cutlass.Constexpr[bool] = False
158
  head_swizzle: cutlass.Constexpr[bool] = False
159
+ use_cluster_idx: cutlass.Constexpr[bool] = False
160
 
161
 
162
  class SingleTileScheduler:
 
169
  num_splits_divmod: FastDivmodDivisor
170
  is_split_kv: cutlass.Constexpr[bool] = False
171
  cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
172
+ use_cluster_idx: cutlass.Constexpr[bool] = False
173
 
174
  @staticmethod
175
  def create(
 
183
  FastDivmodDivisor(args.num_splits),
184
  args.is_split_kv,
185
  args.cluster_shape_mn,
186
+ args.use_cluster_idx,
187
  )
188
 
189
  def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
 
194
  self._ip = ip
195
 
196
  @staticmethod
197
+ def to_underlying_arguments(
198
+ args: TileSchedulerArguments,
199
+ *,
200
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
201
+ loc=None,
202
+ ip=None,
203
+ ) -> Params:
204
+ assert scheduling_mode == SchedulingMode.STATIC, (
205
+ f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}"
206
+ )
207
  return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)
208
 
209
  @staticmethod
210
+ def create(
211
+ params: Params, clc: ClcState | None = None, *, loc=None, ip=None
212
+ ) -> "SingleTileScheduler":
213
+ if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx):
214
+ blk_coord = cute.arch.block_idx()
215
+ else:
216
+ blk_coord = cute.arch.cluster_idx()
 
217
  return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)
218
 
219
  # called by host
 
226
  ) -> Tuple[Int32, Int32, Int32]:
227
  # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
228
  assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
229
+ if const_expr(params.use_cluster_idx):
230
+ # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters
231
+ grid_x = params.num_block * params.cluster_shape_mn[0]
232
+ else:
233
+ grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0])
234
  return (
235
+ grid_x,
236
  params.num_head * params.num_splits,
237
  params.num_batch,
238
  )
 
256
 
257
  def advance_to_next_work(self, *, loc=None, ip=None):
258
  self._is_first_block = False
259
+ return self.get_current_work()
260
+
261
+ def producer_tail(self, *, loc=None, ip=None):
262
+ pass
263
 
264
  def __extract_mlir_values__(self):
265
  values, self._values_pos = [], []
 
305
  self._ip = ip
306
 
307
  @staticmethod
308
+ def to_underlying_arguments(
309
+ args: TileSchedulerArguments,
310
+ *,
311
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
312
+ loc=None,
313
+ ip=None,
314
+ ) -> Params:
315
+ assert scheduling_mode == SchedulingMode.STATIC, (
316
+ f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}"
317
+ )
318
  return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)
319
 
320
  @staticmethod
321
+ def create(
322
+ params: Params, clc: ClcState | None = None, *, loc=None, ip=None
323
+ ) -> "StaticPersistentTileScheduler":
324
  if const_expr(cute.size(params.cluster_shape_m) == 1):
325
  tile_idx = cute.arch.block_idx()[0]
326
  else:
327
  tile_idx = cute.arch.cluster_idx()[0]
328
  return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)
329
 
 
330
  @staticmethod
331
  def get_grid_shape(
332
  params: Params,
 
336
  ) -> Tuple[Int32, Int32, Int32]:
337
  hardware_info = cutlass.utils.HardwareInfo()
338
  sm_count = hardware_info.get_device_multiprocessor_count()
 
339
  max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
340
  grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
341
  return (grid_x, Int32(1), Int32(1))
342
 
 
343
  def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
344
  hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
345
  batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
346
  is_valid = self._tile_idx < self.params.total_blocks_cluster
 
 
347
  return WorkTileInfo(
348
  (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
349
  )
 
359
  self._tile_idx += cute.arch.grid_dim()[0]
360
  else:
361
  self._tile_idx += cute.arch.cluster_dim()[0]
362
+ return self.get_current_work()
363
+
364
+ def producer_tail(self, *, loc=None, ip=None):
365
+ pass
366
 
367
  def __extract_mlir_values__(self):
368
  values, self._values_pos = [], []
 
389
  total_blocks: Int32
390
  num_splits: Int32
391
  num_block: Int32
392
+ num_head: Int32
393
+ num_batch: Int32
394
  l2_minor: Int32
 
395
  num_head_divmod: FastDivmodDivisor
396
  l2_minor_divmod: FastDivmodDivisor
397
  l2_major_divmod: FastDivmodDivisor
398
  l2_minor_residual_divmod: FastDivmodDivisor
399
  num_hb_quotient: Int32
400
+ num_splits_divmod: FastDivmodDivisor
401
  is_split_kv: cutlass.Constexpr[bool] = False
402
+ cluster_shape_m: cutlass.Constexpr[int] = 1
403
+ scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
404
+ lpt: cutlass.Constexpr[bool] = True
405
 
406
  @staticmethod
407
  @cute.jit
408
  def create(
409
+ args: TileSchedulerArguments,
410
+ *,
411
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
412
+ loc=None,
413
+ ip=None,
414
  ) -> "SingleTileLPTScheduler.Params":
415
+ assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
416
+ f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
417
+ )
418
  size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
419
  size_one_head = size_one_kv_head
420
  size_l2 = 50 * 1024 * 1024 # 40 MB for K & V
421
  # Swizzle is the size of each "section". Round swizzle to a power of 2
422
  # Need to be careful about the case where only one head will fit
423
  # swizzle is how many heads can fit in L2
424
+ # Seems faster if swizzle is a power of 2
 
425
  log2_floor = lambda n: 31 - clz(n)
426
  swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
 
427
  # If we're in the last section (called residual), we don't want to divide by
428
  # swizzle. Instead we want to divide by the remainder.
429
  num_hb_quotient = (args.num_head * args.num_batch) // swizzle
 
431
  return SingleTileLPTScheduler.Params(
432
  total_blocks=args.num_block * args.num_head * args.num_batch,
433
  num_block=args.num_block,
434
+ num_head=args.num_head,
435
+ num_batch=args.num_batch,
436
  l2_minor=Int32(swizzle),
 
437
  num_head_divmod=FastDivmodDivisor(args.num_head),
438
  l2_minor_divmod=FastDivmodDivisor(swizzle),
439
  l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
440
+ l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)),
 
 
441
  num_hb_quotient=Int32(num_hb_quotient),
442
  num_splits=args.num_splits,
443
+ num_splits_divmod=FastDivmodDivisor(args.num_splits),
444
  is_split_kv=args.is_split_kv,
445
+ cluster_shape_m=args.cluster_shape_mn[0],
446
+ scheduling_mode=scheduling_mode,
447
+ lpt=args.lpt,
448
  )
449
 
450
+ def __init__(
451
+ self,
452
+ params: Params,
453
+ tile_idx: Int32,
454
+ split_idx: Int32,
455
+ clc: ClcState | None = None,
456
+ *,
457
+ loc=None,
458
+ ip=None,
459
+ ):
460
  self.params = params
461
  self._tile_idx = tile_idx
462
  self._split_idx = split_idx
463
+ self.clc = clc
464
  self._loc = loc
465
  self._ip = ip
466
 
467
  @staticmethod
468
+ def to_underlying_arguments(
469
+ args: TileSchedulerArguments,
470
+ *,
471
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
472
+ loc=None,
473
+ ip=None,
474
+ ) -> Params:
475
+ return SingleTileLPTScheduler.Params.create(
476
+ args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
477
+ )
478
+
479
+ @staticmethod
480
+ def _clc_grid_shape(params: Params):
481
+ num_batch_splits = (
482
+ params.num_batch * params.num_splits
483
+ if const_expr(params.is_split_kv)
484
+ else params.num_batch
485
+ )
486
+ return (
487
+ cute.round_up(params.num_block, params.cluster_shape_m),
488
+ params.num_head,
489
+ num_batch_splits,
490
+ )
491
 
492
  @staticmethod
493
  @cute.jit
494
+ def clc_problem_shape(params: Params):
495
+ return ClcDynamicPersistentTileSchedulerParams(
496
+ problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params),
497
+ cluster_shape_mnk=(params.cluster_shape_m, 1, 1),
498
+ )
499
+
500
+ @staticmethod
501
+ @cute.jit
502
+ def create(
503
+ params: Params, clc: ClcState | None = None, *, loc=None, ip=None
504
+ ) -> "SingleTileLPTScheduler":
505
+ if const_expr(params.scheduling_mode == SchedulingMode.CLC):
506
+ return SingleTileLPTScheduler(
507
+ params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip
508
+ )
509
  tile_idx, split_idx, _ = cute.arch.block_idx()
510
  return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
511
 
 
512
  @staticmethod
513
  def get_grid_shape(
514
  params: Params,
 
516
  loc=None,
517
  ip=None,
518
  ) -> Tuple[Int32, Int32, Int32]:
519
+ if const_expr(params.scheduling_mode == SchedulingMode.CLC):
520
+ return SingleTileLPTScheduler._clc_grid_shape(params)
521
  return (params.total_blocks, params.num_splits, Int32(1))
522
 
523
+ @cute.jit
524
+ def clc_work_to_coords(self, work) -> WorkTileInfo:
525
+ """Convert CLC response (block, head, batch_split) to WorkTileInfo.
526
+
527
+ CLC returns raw grid coordinates — no L2 swizzle (hardware decides order).
528
+ We only apply cluster division, optional LPT block reversal, and split_kv unpacking.
529
+ """
530
+ block_idx = work.tile_idx[0]
531
+ if const_expr(self.params.cluster_shape_m > 1):
532
+ block_idx = block_idx // self.params.cluster_shape_m
533
+ if const_expr(self.params.lpt):
534
+ # Longest-processing-time-first: reverse block order
535
+ block_idx = self.params.num_block - 1 - block_idx
536
+ split_idx = Int32(0)
537
+ if const_expr(self.params.is_split_kv):
538
+ batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod)
539
+ else:
540
+ batch_idx = work.tile_idx[2]
541
+ return WorkTileInfo(
542
+ (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)),
543
+ work.is_valid_tile,
544
+ )
545
+
546
  @cute.jit
547
  def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
548
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
549
+ work = self.clc.get_current_work()
550
+ self._tile_idx = work.tile_idx[0]
551
+ return self.clc_work_to_coords(work)
552
+ # Static path: L2-swizzled coordinate mapping
553
  params = self.params
554
  # Implement LPT scheduling coordinate calculation
555
  bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
 
563
  bidhb_actual = bidhb * params.l2_minor + bidhb_residual
564
  batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
565
  # Longest-processing-time-first
566
+ if const_expr(params.lpt):
567
+ block = params.num_block - 1 - block
568
  is_valid = self._tile_idx < params.total_blocks
569
  return WorkTileInfo(
570
  (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
571
  )
572
 
573
+ @cute.jit
574
  def initial_work_tile_info(self, *, loc=None, ip=None):
575
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
576
+ work = self.clc.initial_work_tile_info()
577
+ self._tile_idx = work.tile_idx[0]
578
+ return self.clc_work_to_coords(work)
579
  return self.get_current_work(loc=loc, ip=ip)
580
 
581
  def prefetch_next_work(self, *, loc=None, ip=None):
582
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
583
+ self.clc.prefetch_next_work(loc=loc, ip=ip)
584
 
585
  def advance_to_next_work(self, *, loc=None, ip=None):
586
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
587
+ self.clc.consumer_wait(loc=loc, ip=ip)
588
+ work = self.get_current_work()
589
+ self.clc.consumer_release(loc=loc, ip=ip)
590
+ return work
591
  # Single tile scheduler - set to invalid tile_idx to indicate no more work
592
  self._tile_idx = self.params.total_blocks
593
+ return self.get_current_work()
594
+
595
+ def producer_tail(self, *, loc=None, ip=None):
596
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
597
+ self.clc.producer_tail(loc=loc, ip=ip)
598
 
599
  def __extract_mlir_values__(self):
600
  values, self._values_pos = [], []
601
+ objs = [self.params, self._tile_idx, self._split_idx]
602
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
603
+ objs += [self.clc]
604
+ for obj in objs:
605
  obj_values = cutlass.extract_mlir_values(obj)
606
  values += obj_values
607
  self._values_pos.append(len(obj_values))
 
609
 
610
  def __new_from_mlir_values__(self, values):
611
  obj_list = []
612
+ objs = [self.params, self._tile_idx, self._split_idx]
613
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
614
+ objs += [self.clc]
615
+ for obj, n_items in zip(objs, self._values_pos):
616
  obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
617
  values = values[n_items:]
618
+ return self.__class__(*obj_list, loc=self._loc)
619
 
620
 
621
  class SingleTileLPTBwdScheduler:
 
639
  ) -> "SingleTileLPTBwdScheduler.Params":
640
  size_l2 = 50 * 1024 * 1024
641
  size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
642
+ size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
643
+ # size_one_dqaccum_head = 0
644
  size_one_head = size_one_qdo_head + size_one_dqaccum_head
645
  log2_floor = lambda n: 31 - clz(n)
646
  swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
 
674
  self._ip = ip
675
 
676
  @staticmethod
677
+ def to_underlying_arguments(
678
+ args: TileSchedulerArguments,
679
+ *,
680
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
681
+ loc=None,
682
+ ip=None,
683
+ ) -> Params:
684
+ assert scheduling_mode == SchedulingMode.STATIC, (
685
+ f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}"
686
+ )
687
  return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)
688
 
689
  @staticmethod
 
734
  def advance_to_next_work(self, *, loc=None, ip=None):
735
  # Single tile scheduler - set to invalid tile_idx to indicate no more work
736
  self._tile_idx = self.params.total_blocks
737
+ return self.get_current_work()
738
 
739
  def __extract_mlir_values__(self):
740
  values, self._values_pos = [], []
 
768
  is_split_kv: cutlass.Constexpr[bool] = False
769
  head_swizzle: cutlass.Constexpr[bool] = False
770
  cluster_shape_m: cutlass.Constexpr[int] = 1
771
+ scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC
772
 
773
  @staticmethod
774
  @cute.jit
775
  def create(
776
+ args: TileSchedulerArguments,
777
+ *,
778
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
779
+ loc=None,
780
+ ip=None,
781
  ) -> "SingleTileVarlenScheduler.Params":
782
+ assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), (
783
+ f"Only STATIC and CLC are supported, got {scheduling_mode!r}"
784
+ )
785
  size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
786
+ # if backward, this is qdo block size
787
+ kv_block_size = (
788
  (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
789
  )
790
+ # if backward, add dqaccum block size to calculate swizzle
791
+ if args.head_swizzle:
792
+ kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
793
+ max_kvblock_in_l2 = size_l2 // kv_block_size
794
  assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
795
  "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
796
  )
797
  assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
798
+ # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the
799
+ # flattened-tile decode so cluster unpacking semantics are explicit.
800
+ assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, (
801
+ "Varlen CLC currently requires cluster_shape_mn[0] == 1"
802
+ )
803
  return SingleTileVarlenScheduler.Params(
804
  num_head=args.num_head,
805
  num_batch=args.num_batch,
 
814
  is_split_kv=args.is_split_kv,
815
  head_swizzle=args.head_swizzle,
816
  cluster_shape_m=args.cluster_shape_mn[0],
817
+ scheduling_mode=scheduling_mode,
818
  )
819
 
820
+ def __init__(
821
+ self,
822
+ params: Params,
823
+ tile_idx: Int32,
824
+ split_idx: Int32,
825
+ clc: ClcState | None = None,
826
+ *,
827
+ loc=None,
828
+ ip=None,
829
+ ):
830
  self.params = params
831
  self._tile_idx = tile_idx
832
  self._split_idx = split_idx
833
  self._is_first_block = True
834
+ self.clc = clc
835
  self._loc = loc
836
  self._ip = ip
837
 
838
  @staticmethod
839
+ def to_underlying_arguments(
840
+ args: TileSchedulerArguments,
841
+ *,
842
+ scheduling_mode: SchedulingMode = SchedulingMode.STATIC,
843
+ loc=None,
844
+ ip=None,
845
+ ) -> Params:
846
+ return SingleTileVarlenScheduler.Params.create(
847
+ args, scheduling_mode=scheduling_mode, loc=loc, ip=ip
848
+ )
849
 
850
  @staticmethod
851
+ @cute.jit
852
+ def clc_problem_shape(params: Params):
853
+ return ClcDynamicPersistentTileSchedulerParams(
854
+ problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params),
855
+ cluster_shape_mnk=(1, 1, 1),
856
+ )
857
+
858
+ @staticmethod
859
+ @cute.jit
860
+ def create(
861
+ params: Params, clc: ClcState | None = None, *, loc=None, ip=None
862
+ ) -> "SingleTileVarlenScheduler":
863
+ if const_expr(params.scheduling_mode == SchedulingMode.CLC):
864
+ block_idx = cute.arch.block_idx()
865
+ split_idx = Int32(0)
866
+ if const_expr(params.is_split_kv):
867
+ split_idx = block_idx[1]
868
+ return SingleTileVarlenScheduler(
869
+ params,
870
+ block_idx[0],
871
+ split_idx,
872
+ clc,
873
+ loc=loc,
874
+ ip=ip,
875
+ )
876
  tile_idx, split_idx, _ = cute.arch.block_idx()
877
  return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)
878
 
 
888
  params.total_q
889
  + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
890
  ) // params.tile_shape_mn[0]
891
+ # Round down to nearest multiple of cluster since odd excess is always padding.
892
  total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
893
  return (total_blocks_max * params.num_head, params.num_splits, Int32(1))
894
 
 
916
  )
917
 
918
  @cute.jit
919
+ def _varlen_coord_map(self) -> WorkTileInfo:
920
+ """Map self._tile_idx to (block, head, batch) via warp-level prefix sums."""
921
  params = self.params
922
  lane_idx = cute.arch.lane_idx()
923
  num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
 
970
  num_n_blocks = (
971
  num_m_blocks
972
  * params.tile_shape_mn[0]
973
+ * params.cluster_shape_m
974
  // params.qhead_per_kvhead_packgqa
975
  // params.tile_shape_mn[1]
976
  )
 
1015
  split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
1016
  return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)
1017
 
1018
+ @cute.jit
1019
+ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
1020
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1021
+ clc_work = self.clc.get_current_work()
1022
+ # Default to grid_dim (one past last valid flat index) so _varlen_coord_map
1023
+ # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when
1024
+ # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural
1025
+ # mismatch on self inside the runtime if.
1026
+ new_tile_idx = cute.arch.grid_dim()[0]
1027
+ new_split_idx = Int32(0)
1028
+ if clc_work.is_valid_tile:
1029
+ new_tile_idx = clc_work.tile_idx[0]
1030
+ if const_expr(self.params.is_split_kv):
1031
+ new_split_idx = clc_work.tile_idx[1]
1032
+ self._tile_idx = new_tile_idx
1033
+ self._split_idx = new_split_idx
1034
+ return self._varlen_coord_map()
1035
+
1036
+ @cute.jit
1037
  def initial_work_tile_info(self, *, loc=None, ip=None):
1038
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1039
+ clc_work = self.clc.initial_work_tile_info()
1040
+ # See get_current_work for why grid_dim and local-then-assign.
1041
+ new_tile_idx = cute.arch.grid_dim()[0]
1042
+ new_split_idx = Int32(0)
1043
+ if clc_work.is_valid_tile:
1044
+ new_tile_idx = clc_work.tile_idx[0]
1045
+ if const_expr(self.params.is_split_kv):
1046
+ new_split_idx = clc_work.tile_idx[1]
1047
+ self._tile_idx = new_tile_idx
1048
+ self._split_idx = new_split_idx
1049
+ return self._varlen_coord_map()
1050
 
1051
  def prefetch_next_work(self, *, loc=None, ip=None):
1052
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1053
+ self.clc.prefetch_next_work(loc=loc, ip=ip)
1054
 
1055
  def advance_to_next_work(self, *, loc=None, ip=None):
1056
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1057
+ self.clc.consumer_wait(loc=loc, ip=ip)
1058
+ work = self.get_current_work()
1059
+ self.clc.consumer_release(loc=loc, ip=ip)
1060
+ return work
1061
  self._is_first_block = False
1062
+ return self.get_current_work()
1063
+
1064
+ def producer_tail(self, *, loc=None, ip=None):
1065
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1066
+ self.clc.producer_tail(loc=loc, ip=ip)
1067
 
1068
  def __extract_mlir_values__(self):
1069
  values, self._values_pos = [], []
1070
+ objs = [self.params, self._tile_idx, self._split_idx]
1071
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1072
+ objs += [self.clc]
1073
+ for obj in objs:
1074
  obj_values = cutlass.extract_mlir_values(obj)
1075
  values += obj_values
1076
  self._values_pos.append(len(obj_values))
 
1078
 
1079
  def __new_from_mlir_values__(self, values):
1080
  obj_list = []
1081
+ objs = [self.params, self._tile_idx, self._split_idx]
1082
+ if const_expr(self.params.scheduling_mode == SchedulingMode.CLC):
1083
+ objs += [self.clc]
1084
+ for obj, n_items in zip(objs, self._values_pos):
1085
  obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
1086
  values = values[n_items:]
1087
+ return self.__class__(*obj_list, loc=self._loc)
build/torch-cuda/utils.py CHANGED
@@ -3,12 +3,14 @@
3
  import math
4
  import hashlib
5
  import inspect
 
6
  from typing import Type, Callable, Optional, Tuple, overload
7
 
8
  import cutlass
9
  import cutlass.cute as cute
10
 
11
- from cutlass import Float32, const_expr
 
12
  from cutlass.cutlass_dsl import T, dsl_user_op
13
  from cutlass._mlir.dialects import nvvm, llvm
14
  from cutlass.cute.runtime import from_dlpack
@@ -54,6 +56,17 @@ POLY_EX2 = {
54
  ),
55
  }
56
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def _compute_base_hash(func: Callable) -> str:
59
  """Compute hash from source code or bytecode and closure values."""
@@ -123,6 +136,40 @@ def create_softcap_scoremod(softcap_val):
123
  return scoremod_premask_fn
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
127
  return (
128
  from_dlpack(x, assumed_align=alignment)
@@ -215,6 +262,21 @@ def warp_reduce(
215
  return val
216
 
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  @dsl_user_op
219
  def fmax(
220
  a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
@@ -429,8 +491,48 @@ def shuffle_sync(
429
  return val[0]
430
 
431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  @dsl_user_op
433
  def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
 
 
 
 
 
 
434
  return cutlass.Uint32(
435
  llvm.inline_asm(
436
  T.i32(),
@@ -438,7 +540,7 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) ->
438
  cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
439
  cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
440
  ],
441
- "shr.s32 $0, $1, $2;",
442
  "=r,r,r",
443
  has_side_effects=False,
444
  is_align_stack=False,
 
3
  import math
4
  import hashlib
5
  import inspect
6
+ import os
7
  from typing import Type, Callable, Optional, Tuple, overload
8
 
9
  import cutlass
10
  import cutlass.cute as cute
11
 
12
+ from cutlass import Float32, Int32, const_expr
13
+ from cutlass.cute import FastDivmodDivisor
14
  from cutlass.cutlass_dsl import T, dsl_user_op
15
  from cutlass._mlir.dialects import nvvm, llvm
16
  from cutlass.cute.runtime import from_dlpack
 
56
  ),
57
  }
58
 
59
+ _fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1"
60
+ _fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1"
61
+
62
+
63
+ def _get_use_clc_scheduler_default() -> bool:
64
+ return _fa_clc_enabled
65
+
66
+
67
+ def _get_disable_2cta_default() -> bool:
68
+ return _fa_disable_2cta_enabled
69
+
70
 
71
  def _compute_base_hash(func: Callable) -> str:
72
  """Compute hash from source code or bytecode and closure values."""
 
136
  return scoremod_premask_fn
137
 
138
 
139
+ LOG2_E = math.log2(math.e)
140
+
141
+
142
+ def compute_softmax_scale_log2(softmax_scale, score_mod):
143
+ """Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used.
144
+
145
+ When score_mod is None, fold the log2(e) factor into softmax_scale_log2 and set softmax_scale
146
+ to None. When score_mod is present, keep softmax_scale separate so it can be applied before
147
+ the score_mod, and set softmax_scale_log2 to just the change-of-base constant.
148
+
149
+ Returns (softmax_scale_log2, softmax_scale).
150
+ """
151
+ if const_expr(score_mod is None):
152
+ return softmax_scale * LOG2_E, None
153
+ else:
154
+ return LOG2_E, softmax_scale
155
+
156
+
157
+ def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None):
158
+ """Compute FastDivmodDivisor pairs for aux_tensors index computation.
159
+
160
+ Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None.
161
+ """
162
+ if const_expr(aux_tensors is None):
163
+ return None
164
+ seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1)
165
+ seqlen_k = (
166
+ cute.size(mK.shape[0])
167
+ if const_expr(mPageTable is None)
168
+ else mK.shape[0] * mPageTable.shape[1]
169
+ )
170
+ return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))
171
+
172
+
173
  def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
174
  return (
175
  from_dlpack(x, assumed_align=alignment)
 
262
  return val
263
 
264
 
265
+ @dsl_user_op
266
+ def smid(*, loc=None, ip=None) -> Int32:
267
+ return Int32(
268
+ llvm.inline_asm(
269
+ T.i32(),
270
+ [],
271
+ "mov.u32 $0, %smid;",
272
+ "=r",
273
+ has_side_effects=False,
274
+ is_align_stack=False,
275
+ asm_dialect=llvm.AsmDialect.AD_ATT,
276
+ )
277
+ )
278
+
279
+
280
  @dsl_user_op
281
  def fmax(
282
  a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
 
491
  return val[0]
492
 
493
 
494
+ @dsl_user_op
495
+ def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
496
+ """
497
+ Left-shift val by shift bits using PTX shl.b32 (sign-agnostic).
498
+
499
+ Named ``shl_u32`` (not ``shl_b32``) because python type annotations
500
+ distinguish signed/unsigned.
501
+
502
+ PTX semantics (§9.7.8.8): "Shift amounts greater than the register width N
503
+ are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0.
504
+
505
+ This differs from C/C++ and LLVM IR, where shifting by >= the type width is
506
+ undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain
507
+ Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer
508
+ may treat the result as poison and eliminate dependent code. Inline PTX
509
+ bypasses the LLVM IR shift entirely — the instruction is emitted verbatim
510
+ into PTX where clamping makes it safe for all shift amounts.
511
+ """
512
+ return cutlass.Uint32(
513
+ llvm.inline_asm(
514
+ T.i32(),
515
+ [
516
+ cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
517
+ cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
518
+ ],
519
+ "shl.b32 $0, $1, $2;",
520
+ "=r,r,r",
521
+ has_side_effects=False,
522
+ is_align_stack=False,
523
+ asm_dialect=llvm.AsmDialect.AD_ATT,
524
+ )
525
+ )
526
+
527
+
528
  @dsl_user_op
529
  def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
530
+ """
531
+ Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills).
532
+
533
+ See ``shl_u32`` docstring for why inline PTX is used instead of plain
534
+ CuTeDSL shift operators (LLVM shift-by-type-width UB).
535
+ """
536
  return cutlass.Uint32(
537
  llvm.inline_asm(
538
  T.i32(),
 
540
  cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
541
  cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
542
  ],
543
+ "shr.u32 $0, $1, $2;",
544
  "=r,r,r",
545
  has_side_effects=False,
546
  is_align_stack=False,