danieldk's picture
danieldk HF Staff
Build uploaded using `kernels`.
57a7570 verified
# Copyright (c) 2025, Tri Dao.
from typing import Optional, Callable, TypeAlias
from dataclasses import dataclass
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, Uint32, const_expr
from .quack import layout_utils
from . import utils as utils
from .seqlen_info import SeqlenInfoQK
# Signature of a per-chunk bitmask generator as consumed by mask_r2p_lambda: it receives
# the (compile-time) chunk index and returns a Uint32 in which a set bit means "keep".
MaskGenFn: TypeAlias = Callable[[int], Uint32]
# Number of elements covered by one bitmask chunk (one bit per element of a 32-bit word).
MASK_R2P_CHUNK_SIZE: int = 32
@cute.jit
def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
    """Return the keep-mask of chunk `s` for an exclusive upper bound `limit`.

    Bit i of the result stands for element index s * MASK_R2P_CHUNK_SIZE + i.
    Bits whose element index is < limit are 1 (keep); all others are 0 (mask).
    The mask is an all-ones word shifted right by the number of elements of
    this chunk that lie at or beyond `limit`.  The shift is routed through
    utils.shr_u32 (inline PTX according to the original author's note) so that
    shifting by the full type width is not undefined behaviour.
    """
    chunk_end = (s + 1) * MASK_R2P_CHUNK_SIZE
    # How many of the chunk's top positions fall outside the bound (never negative).
    num_dropped = max(chunk_end - limit, 0)
    all_ones = Uint32(0xFFFFFFFF)
    return utils.shr_u32(all_ones, Uint32(num_dropped))
@cute.jit
def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
    """Return the keep-mask of chunk `s` for an inclusive lower bound `limit`.

    Bit i of the result stands for element index s * MASK_R2P_CHUNK_SIZE + i.
    Bits whose element index is >= limit are 1 (keep); all others are 0 (mask).
    The mask is an all-ones word shifted left by the number of elements of this
    chunk that lie below `limit`.  The shift is routed through utils.shl_u32
    (inline PTX according to the original author's note) so that shifting by
    the full type width is not undefined behaviour.
    """
    chunk_start = s * MASK_R2P_CHUNK_SIZE
    # How many of the chunk's bottom positions fall below the bound (never negative).
    num_dropped = max(limit - chunk_start, 0)
    all_ones = Uint32(0xFFFFFFFF)
    return utils.shl_u32(all_ones, Uint32(num_dropped))
@cute.jit
def mask_r2p_lambda(
    X: cute.Tensor,
    mask_gen_fn: cutlass.Constexpr[MaskGenFn],
    rank1: bool = False,
) -> None:
    """Apply R2P masking to `X` in place using a custom bitmask generator.

    The last mode of `X` (or all of `X` when `rank1` is True) is processed in
    chunks of MASK_R2P_CHUNK_SIZE elements.  For each chunk the generator is
    called once and every element whose bit is clear is overwritten with -inf.

    Args:
        X: Tensor to mask.  Indexed as X[c] when `rank1` is True, otherwise as
            X[r, c] with r ranging over the size of X.shape[0].
        mask_gen_fn: mask_gen_fn(chunk_idx: constexpr int) -> Uint32.
            Returns a 32-bit bitmask for the chunk. Bit i set means column
            chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf.
        rank1: If True, treat `X` as one flat row of cute.size(X.shape) elements;
            if False, apply the same column mask to every row of `X`.
    """
    # Number of columns to mask: the whole tensor in rank-1 mode, else the size of the last mode.
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))
    # 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep).
    CHUNK_SIZE = MASK_R2P_CHUNK_SIZE
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)):
        mask = mask_gen_fn(s)
        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        # The last chunk may be partial, hence the min() on the trip count.
        for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)):
            # Test bit i of the chunk mask: set => keep the element.
            in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
            c = s * CHUNK_SIZE + i
            if const_expr(rank1):
                X[c] = X[c] if in_bound else -Float32.inf
            else:
                for r in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[r, c] = X[r, c] if in_bound else -Float32.inf
@cute.jit
def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:
    """Map an SM90 MMA column threshold to the matching R2P element threshold.

    Per the original note, the accumulator columns held by a thread are the
    non-contiguous sequence 0, 1, 8, 9, 16, 17, ... while the elements
    themselves are numbered contiguously 0, 1, 2, 3, 4, 5, ...  Every group of
    8 columns therefore contributes 2 elements, and within a group at most 2
    columns count.  The result is suitable as the `limit` argument of
    r2p_bitmask_below / r2p_bitmask_above.
    """
    group_idx = col_limit // 8
    col_in_group = col_limit % 8
    # Only the first two columns of each 8-column group map to elements.
    elems_in_group = min(col_in_group, 2)
    return group_idx * 2 + elems_in_group
@cute.jit
def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
    """Map a row threshold to an R2P element index for the warp-group interleaved layout.

    As described by the original author: in the SM100 backward pass two warp
    groups share TMEM and the TMEM load atom hands out rows in an interleaved
    fashion.  Elements 0..num_rep-1 correspond to rows 0..num_rep-1, elements
    num_rep..2*num_rep-1 correspond to rows
    num_rep*num_wg..num_rep*num_wg+num_rep-1, and so on.  Row-space thresholds
    (causal limits, window bounds, uih_len) have to be translated with this
    function before being handed to r2p_bitmask_above / r2p_bitmask_below.

    A row that falls in the gap between two owned runs is clamped to the
    element index at the boundary; this is safe because the thresholds are
    monotonic.

    Example with num_rep=16, num_wg=2:
        row 0  -> elem 0,            row 15 -> elem 15,
        row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped),
        row 32 -> elem 16,           row 33 -> elem 17,  row 47 -> elem 31.
    """
    # Rows repeat with this period; only the first num_rep rows of each period are owned.
    period = num_rep * num_wg
    full_periods = x // period
    # Position inside the current period, clamped to the owned run length.
    owned_in_period = min(x % period, num_rep)
    return full_periods * num_rep + owned_in_period
@dataclass(frozen=True)
class AttentionMask:
    """Masks attention-score accumulators in place by writing -inf into excluded entries.

    The methods support sequence-length (out-of-bounds) masking, causal masking,
    local (sliding-window) masking and a user-supplied `mask_mod` callable.
    Coordinates of the elements a thread owns are obtained by partitioning an
    identity tensor of shape (tile_m, tile_n) (reversed when `swap_AB` is set).
    """

    # Tile extents along the Q (m) and KV (n) dimensions; block indices are scaled by these.
    tile_m: cutlass.Constexpr[int]
    tile_n: cutlass.Constexpr[int]
    # Source of seqlen_q / seqlen_k; also forwarded unchanged to mask_mod.
    seqlen_info: SeqlenInfoQK
    # Sliding-window extents for local masking; None disables the corresponding bound.
    window_size_left: Optional[Int32] = None
    window_size_right: Optional[Int32] = None
    # When != 1, accumulator rows are divided by this value to recover the Q index.
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # only pass in if we're doing PackGQA
    # When True the accumulator coordinate components are swapped (ROW = 1, COL = 0).
    swap_AB: cutlass.Constexpr[bool] = False

    @property
    def seqlen_q(self) -> Int32:
        """Query sequence length, read from `seqlen_info`."""
        return self.seqlen_info.seqlen_q

    @property
    def seqlen_k(self) -> Int32:
        """Key/value sequence length, read from `seqlen_info`."""
        return self.seqlen_info.seqlen_k

    @cute.jit
    def apply_mask(
        self,
        acc_S: cute.Tensor,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        thr_mma: cute.TiledMma,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
    ) -> None:
        """Mask the MMA accumulator `acc_S` in place.

        The accumulator is viewed as (row, col) through
        layout_utils.reshape_acc_to_mn.  The R2P paths convert column limits with
        sm90_col_to_r2p_idx, so this presumably targets the SM90 accumulator
        layout (NOTE(review): confirm against callers).

        Exactly one of three compile-time branches is taken:
          * no causal/local and no mask_mod: only seqlen_k masking (if mask_seqlen);
          * no causal/local, mask_mod given: element-wise mask_mod evaluation,
            plus seqlen_q / seqlen_k bounds when mask_seqlen;
          * causal or local: per-row column limits (or, with swap_AB, per-column
            row limits).

        Args:
            acc_S: Accumulator fragment to mask.
            batch_idx: Batch index; only forwarded to mask_mod.
            head_idx: Head index; forwarded to mask_mod (adjusted under PackGQA).
            m_block: Tile index along Q.
            n_block: Tile index along KV; negative values are treated as 0.
            thr_mma: Used to partition the identity tensor into the coordinates
                owned by this thread and by thread 0.
            mask_seqlen: Compile-time flag enabling sequence-length masking.
            mask_causal: Compile-time flag enabling causal masking.
            mask_local: Compile-time flag enabling local masking; mutually
                exclusive with mask_causal.
            mask_mod: Called as mask_mod(batch, head, q_idx, kv_idx, seqlen_info,
                aux_tensors) with SSA-wrapped indices; a true result keeps the element.
            aux_tensors: Forwarded unchanged to mask_mod.
            fastdiv_mods: Pair of divisors.  When both are set, mask_seqlen is
                on and aux_tensors is given, the q / kv indices passed to
                mask_mod are replaced by their remainders (presumably modulo
                the sequence lengths - TODO confirm).
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)
        # We use t0ScS as these indices are known at compile time. We then must subtract the
        # column limit by the thread column offset.
        t0ScS_mn = layout_utils.reshape_acc_to_mn(
            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
        )
        # Which component of a coordinate holds the Q (ROW) and KV (COL) index.
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        # seqlen_k expressed relative to this tile and this thread's first column, so it can
        # be compared directly against thread-0 column coordinates.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                # The R2P bitmask path is only used for the non-swapped layout.
                r2p = const_expr(not self.swap_AB)
                if const_expr(not r2p):
                    # traverse column index.
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
                else:
                    # Convert the column limit to an element index, then keep elements below it.
                    seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit)
                    mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s))
        elif const_expr(
            not mask_causal and not mask_local and mask_mod is not None
        ):  # FlexAttention mask mod
            nrow = const_expr(cute.size(tScS_mn.shape[0]))
            ncol = const_expr(cute.size(tScS_mn.shape[1]))
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            # Indices given to mask_mod are wrapped only when all three conditions hold.
            wrap_aux_indices = const_expr(
                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
            )
            for r in cutlass.range_constexpr(nrow):
                # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.
                local_row = tScS_mn[r, 0][ROW]
                global_row_idx = local_row + m_block * self.tile_m
                row_for_mod = global_row_idx
                head_idx_for_mod = head_idx
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: the packed row index splits into a head offset and a Q index.
                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
                # Keep the unwrapped Q index for the seqlen_q bounds check below.
                row_for_seqlen = row_for_mod
                if const_expr(wrap_aux_indices):
                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])
                for col in cutlass.range_constexpr(ncol):
                    col_idx_local = t0ScS_mn[0, col][COL]
                    # Convert to absolute column index
                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
                    col_for_mod = global_col_idx
                    if const_expr(wrap_aux_indices):
                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])
                    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    if const_expr(mask_seqlen):
                        # Out-of-bounds elements are masked regardless of what mask_mod returned.
                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
                            global_col_idx >= self.seqlen_k
                        )
                        if out_of_bounds:
                            acc_S_mn[r, col] = -cutlass.Float32.inf
                        else:
                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
                    else:
                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
        else:  # Causal or local
            if const_expr(not self.swap_AB):
                # If PackGQA, we split the work of compute divmod among threads in the same row
                threads_per_row = thr_mma.tv_layout_C.shape[0][0]
                mma_m_idx = None
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (
                        "threads_per_row must divide WARP_SIZE"
                    )
                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
                    tidx = thr_mma.thr_idx
                    # Each thread computes the Q index for one row; rows are fetched later
                    # from the owning thread via shuffle_sync.
                    mma_m_idx = (
                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
                    ) // self.qhead_per_kvhead_packgqa
                # Adding a row index to this offset gives the exclusive right column limit
                # of that row in thread-0 column coordinates.
                causal_row_offset = (
                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
                )
                if const_expr(mask_causal):
                    # Always True in this (non-swap_AB) branch, so the R2P path below is taken.
                    r2p = const_expr(not self.swap_AB)  # R2P trick, see apply_mask_sm100
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        # get the column index limit based on current row. Only consider the row index, so the column index sets to 0.
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        col_limit_right = row_idx + causal_row_offset
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        if const_expr(not r2p):
                            # traverse column index.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if t0ScS_mn[0, c][1] >= col_limit_right
                                    else acc_S_mn[r, c]
                                )
                        else:
                            col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right)
                            mask_r2p_lambda(
                                acc_S_mn[r, None],
                                lambda s: r2p_bitmask_below(col_limit_r2p, s),
                                rank1=True,
                            )
                else:  # Local
                    # Row-to-column-limit offsets for the right (exclusive) and left
                    # (inclusive) window edges; None when that edge is unbounded.
                    local_row_offset_right = (
                        causal_row_offset + self.window_size_right
                        if const_expr(self.window_size_right is not None)
                        else None
                    )
                    local_row_offset_left = (
                        causal_row_offset - 1 - self.window_size_left
                        if const_expr(self.window_size_left is not None)
                        else None
                    )
                    # Always True in this (non-swap_AB) branch.
                    r2p_local = const_expr(not self.swap_AB)
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        if const_expr(self.window_size_right is not None):
                            col_limit_right = row_idx + local_row_offset_right
                        else:
                            col_limit_right = self.tile_n
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        col_limit_left = (
                            row_idx + local_row_offset_left
                            if const_expr(self.window_size_left is not None)
                            else 0
                        )
                        if const_expr(not r2p_local):
                            # traverse column index.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                col_idx = t0ScS_mn[0, c][1]
                                if col_idx >= col_limit_right or col_idx < col_limit_left:
                                    acc_S_mn[r, c] = -Float32.inf
                        else:
                            col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right)
                            col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left)

                            # Keep elements inside [col_limit_left, col_limit_right).
                            def mask_gen_fn(s: int) -> Uint32:
                                return r2p_bitmask_below(
                                    col_limit_right_r2p, s
                                ) & r2p_bitmask_above(col_limit_left_r2p, s)

                            mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True)
            else:  # swap_AB
                assert self.qhead_per_kvhead_packgqa == 1
                thr_row_offset = tScS_mn[0][ROW]
                # Subtracting this from a thread-0 column coordinate gives the first kept
                # row (in thread-0 row coordinates) of that column under causal masking.
                causal_row_offset = (
                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
                )
                if const_expr(mask_causal):
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else col0 - causal_row_offset
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if t0ScS_mn[r, 0][ROW] < row_limit_top
                                else acc_S_mn[r, c]
                            )
                else:
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # If col0 is beyond the column limit, we want to mask out the entire
                        # column, by setting row limit to be self.tile_m.
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else (
                                col0 - causal_row_offset - self.window_size_right
                                if const_expr(self.window_size_right is not None)
                                else 0
                            )
                        )
                        row_limit_bot = (
                            col0 - causal_row_offset + self.window_size_left
                            if const_expr(self.window_size_left is not None)
                            else self.tile_m
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            row_idx = t0ScS_mn[r, 0][ROW]
                            # Rows are kept only when row_limit_top <= row_idx <= row_limit_bot.
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if row_idx < row_limit_top or row_idx > row_limit_bot
                                else acc_S_mn[r, c]
                            )

    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        thr_mma: cute.TiledMma,
        thr_tmem_load: cute.TiledCopy,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        head_divmod=None,
        check_q_boundary: bool = False,
    ) -> None:
        """Mask `acc_S` in place using coordinates obtained through `thr_tmem_load` (SM100 path).

        `acc_S` is indexed with a single flat index.  In the causal / local
        branches one row index (that of element 0) is used for all of the
        thread's elements.  The R2P bitmask path is hard-wired on (`r2p = True`);
        the element-wise fallbacks sit behind `const_expr(not r2p)`.

        Args:
            acc_S: Score fragment to mask.
            m_block: Tile index along Q.
            n_block: Tile index along KV; negative values are treated as 0.
            thr_mma: Used to partition the identity tensor.
            thr_tmem_load: Its partition_D gives the per-element coordinates.
            mask_seqlen: Compile-time flag enabling seqlen_k masking.
            mask_causal: Compile-time flag enabling causal masking.
            mask_local: Compile-time flag enabling local masking; mutually
                exclusive with mask_causal.
            mask_mod: Called as mask_mod(batch, head, q_idx, kv_idx, seqlen_info,
                aux_tensors); a true result keeps the element.  Only used when
                neither causal nor local masking is selected.
            batch_idx: Batch index forwarded to mask_mod.
            head_idx: Head index forwarded to mask_mod (adjusted under PackGQA).
            aux_tensors: Forwarded unchanged to mask_mod.
            fastdiv_mods: Pair of divisors used to wrap the q / kv indices given
                to mask_mod (presumably modulo the sequence lengths - TODO confirm).
            head_divmod: Divisor used under PackGQA to split the packed row index
                into (q index, head offset); must not be None in that case.
            check_q_boundary: When True, rows >= seqlen_q are masked in the
                mask_mod branch and the q index given to mask_mod is wrapped.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS = tScS[(None, None), 0, 0]
        tScS_t2r = thr_tmem_load.partition_D(tScS)
        # To handle edge cases of completely masked out rows where n_block_max = 0,
        # we treat negative n_blocks as 0th n_block
        # TODO: find more transparent solution
        if n_block < 0:
            n_block = 0
        # seqlen_k relative to the start of this KV tile.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        r2p = True
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        # if tScS_t2r[i][1] >= seqlenk_col_limit:
                        #     acc_S[i] = -Float32.inf
                        # For some reason the 2 lines above generate really bad SASS
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
                else:
                    # The column limit is used directly as the element-index limit here.
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_below(seqlenk_col_limit, s),
                        rank1=True,
                    )
        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case w/ mask_mod
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
            ncol = const_expr(cute.size(tScS_t2r.shape))
            for i in cutlass.range_constexpr(ncol):
                # Pick the Q / KV component of the coordinate according to swap_AB.
                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
                global_row = row_coord + m_block * self.tile_m
                global_col = col_coord + n_block * self.tile_n
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert head_divmod is not None
                    # PackGQA: split the packed row into the Q index and the head offset.
                    mask_row, head_offset = divmod(global_row, head_divmod)
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                else:
                    head_idx_for_mod = head_idx
                    mask_row = global_row
                mask_row_for_mod = mask_row
                if const_expr(has_fastdiv and aux_tensors is not None):
                    if check_q_boundary:
                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
                global_col_for_mod = global_col
                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])
                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
                mask_value = mask_mod(
                    batch_idx_ssa,
                    head_idx_ssa,
                    mask_row_ssa,
                    kv_idx_ssa,
                    self.seqlen_info,
                    aux_tensors,
                )
                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                acc_S[i] = acc_S[i] if cond else -Float32.inf
                # Bounds checks use the unwrapped indices and override the mask_mod result.
                if const_expr(mask_seqlen):
                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
                if check_q_boundary:
                    acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]
        else:  # Causal or local
            causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q
            # Row of element 0; used for every element this thread holds.
            row_idx = tScS_t2r[0][0] + m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa != 1):
                row_idx = row_idx // self.qhead_per_kvhead_packgqa
            if const_expr(mask_causal):
                # Exclusive right column limit for this row, relative to the tile start.
                col_limit_right = row_idx + causal_row_offset + 1
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                # if cute.arch.thread_idx()[0] % 32 == 0:
                #     cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)
                ncol = const_expr(cute.size(tScS_t2r.shape))
                if const_expr(not r2p):
                    for i in cutlass.range(ncol, unroll_full=True):
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
                else:
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_below(col_limit_right, s),
                        rank1=True,
                    )
            else:
                # Offsets for the exclusive right and inclusive left window edges;
                # None when that edge is unbounded.
                local_row_offset_right = (
                    causal_row_offset + 1 + self.window_size_right
                    if const_expr(self.window_size_right is not None)
                    else None
                )
                local_row_offset_left = (
                    causal_row_offset - self.window_size_left
                    if const_expr(self.window_size_left is not None)
                    else None
                )
                if const_expr(self.window_size_right is not None):
                    col_limit_right = row_idx + local_row_offset_right
                else:
                    col_limit_right = self.tile_n
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                col_limit_left = (
                    row_idx + local_row_offset_left
                    if const_expr(self.window_size_left is not None)
                    else 0
                )
                if const_expr(not r2p):
                    # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        col_idx = tScS_t2r[i][1]
                        acc_S[i] = (
                            -Float32.inf
                            if col_idx >= col_limit_right or col_idx < col_limit_left
                            else acc_S[i]
                        )
                else:
                    # Dual-bound R2P masking for SM100.
                    # Masks elements where: NOT (col_limit_left <= col < col_limit_right)
                    def mask_gen_fn(s: int) -> Uint32:
                        return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above(
                            col_limit_left, s
                        )

                    mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True)

    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        mask_local: cutlass.Constexpr,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        is_full_block: bool = False,
        check_m_boundary: bool = True,
    ) -> None:
        """
        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.

        Coordinate convention:
        - ROW corresponds to Q (m_block)
        - COL corresponds to KV (n_block)

        acc_S: Score fragment, masked in place with a single flat index.
        tScS_t2r: Per-element coordinates of this thread; element 0 supplies the
            thread's row / column offsets.
        t0ScS_t2r: Coordinates used by the non-R2P fallbacks (presumably those of
            thread 0, by analogy with t0ScS_mn in apply_mask - TODO confirm).
        mask_mod, batch_idx, head_idx, aux_tensors: mask_mod is called as
            mask_mod(batch, head, q_idx, kv_idx, seqlen_info, aux_tensors); a true
            result keeps the element.
        fastdiv_mods: Pair of divisors; when both are set, mask_seqlen is on and
            aux_tensors is given, q / kv indices passed to mask_mod are replaced
            by their remainders.
        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
            When iterating m_blocks in forward order, only the last m_block may be partial.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        # assert t0ScS_t2r[0][COL] == 0, "col0 == 0"  # tmp comment for 2-cta bwd
        thr_col_offset = tScS_t2r[0][COL]
        # seqlen_k relative to this tile and this thread's first KV coordinate;
        # <= 0 is treated below as "everything this thread holds is out of bounds in K".
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # Block sparse case with mask_mod (backward)
            #
            # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).
            # These already account for swap_AB.
            #
            # FULL blocks: mask_mod returns True for all elements, so skip it.
            # Still need seqlen bounds check (elements may be OOB on last m_block).
            # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
            if is_full_block:
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        # Entire tile is OOB for K
                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                            acc_S[i] = -cutlass.Float32.inf
                    elif check_m_boundary:
                        # Last m_block: check Q and K boundaries
                        ncol = const_expr(cute.size(tScS_t2r.shape))
                        for i in cutlass.range_constexpr(ncol):
                            row_coord = tScS_t2r[i][ROW]
                            col_coord = tScS_t2r[i][COL]
                            global_q = row_coord + m_block * self.tile_m
                            global_kv = col_coord + n_block * self.tile_n
                            q_out_of_bounds = global_q >= self.seqlen_q
                            kv_out_of_bounds = global_kv >= self.seqlen_k
                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
            else:
                # Partial block
                has_fastdiv = const_expr(
                    fastdiv_mods is not None
                    and fastdiv_mods[0] is not None
                    and fastdiv_mods[1] is not None
                )
                wrap_aux_indices = const_expr(
                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
                )
                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)
                ncol = const_expr(cute.size(tScS_t2r.shape))
                for i in cutlass.range_constexpr(ncol):
                    row_coord = tScS_t2r[i][ROW]
                    col_coord = tScS_t2r[i][COL]
                    global_q = row_coord + m_block * self.tile_m
                    global_kv = col_coord + n_block * self.tile_n
                    q_idx_for_mod = global_q
                    kv_idx_for_mod = global_kv
                    if const_expr(wrap_aux_indices):
                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])
                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf
                    if const_expr(mask_seqlen):
                        # check_m_boundary=False skips q check for non-boundary m_blocks
                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
                        kv_out_of_bounds = global_kv >= self.seqlen_k
                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
        elif const_expr(not mask_causal and not mask_local):
            # Plain seqlen masking: only the K bound is checked here, and only as
            # an all-or-nothing test on this thread's elements.
            if const_expr(mask_seqlen):
                if seqlenk_col_limit <= 0:
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = -cutlass.Float32.inf
        else:  # Causal or local
            thr_row_offset = tScS_t2r[0][ROW]
            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
            # First kept row under causal masking, in the same relative row coordinates
            # as seqlenq_row_limit.
            causal_offset = seqlenq_row_limit - seqlenk_col_limit
            if const_expr(mask_causal):
                # tidx = cute.arch.thread_idx()[0] % 256
                # if tidx < 32:
                #     cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])
                row_limit_top = causal_offset
                if const_expr(mask_seqlen):
                    # If col is beyond the column limit, we want to mask out the entire
                    # column, by setting row limit to be self.tile_m.
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = (
                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
                        )
                else:
                    num_rep = cute.size(tScS_t2r, mode=[0])  # 16 or 32
                    num_wg = 2
                    # Translate the row threshold to an element index, then keep
                    # elements at or above it.
                    row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_above(row_limit, s),
                        rank1=True,
                    )
            else:
                if const_expr(self.window_size_right is not None):
                    row_limit_top = causal_offset - self.window_size_right
                else:
                    row_limit_top = 0
                # row_limit_bot is only defined (and only used) when a left window is set.
                if const_expr(self.window_size_left is not None):
                    row_limit_bot = causal_offset + self.window_size_left
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        row_idx = t0ScS_t2r[i][ROW]
                        local_mask = row_idx < row_limit_top
                        if const_expr(self.window_size_left is not None):
                            local_mask |= row_idx > row_limit_bot
                        acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
                else:

                    # Keep rows in [row_limit_top, row_limit_bot]; the bottom bound is
                    # inclusive, hence the +1 before conversion to an exclusive limit.
                    def mask_gen_fn(s: int) -> Uint32:
                        num_rep = cute.size(tScS_t2r, mode=[0])
                        num_wg = 2
                        row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
                        mask = r2p_bitmask_above(row_limit, s)
                        if const_expr(self.window_size_left is not None):
                            row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg)
                            mask = mask & r2p_bitmask_below(row_limit_bottom, s)
                        return mask

                    mask_r2p_lambda(
                        acc_S,
                        mask_gen_fn,
                        rank1=True,
                    )