| |
|
|
| from typing import Optional, Callable, TypeAlias |
| from dataclasses import dataclass |
|
|
| import cutlass |
| import cutlass.cute as cute |
| from cutlass import Float32, Int32, Uint32, const_expr |
|
|
| from .quack import layout_utils |
| from . import utils as utils |
| from .seqlen_info import SeqlenInfoQK |
|
|
# Signature of a chunk-wise bitmask generator: given a chunk index, return a
# 32-bit mask where bit i set means "keep element chunk*32 + i".
MaskGenFn: TypeAlias = Callable[[int], Uint32]
# R2P masks cover 32 elements per chunk (one bit per element of a 32-bit mask).
MASK_R2P_CHUNK_SIZE: int = 32
|
|
|
|
@cute.jit
def r2p_bitmask_below(limit: Int32, s: int) -> Uint32:
    """32-bit R2P bitmask keeping positions < limit (exclusive upper bound).

    Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
    Uses inline PTX to avoid shift-by-type-width UB.

    Args:
        limit: Exclusive element-space upper bound (runtime value).
        s: Constexpr chunk index; the mask covers elements
           [s*32, (s+1)*32) of the overall element space.
    """
    # Number of high bits to clear in this chunk; clamped at 0 so chunks fully
    # below `limit` yield all-ones. m can reach 32 (all-zero mask), which is
    # why the shift goes through utils.shr_u32 instead of Python `>>`.
    m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)
    return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m))
|
|
|
|
@cute.jit
def r2p_bitmask_above(limit: Int32, s: int) -> Uint32:
    """32-bit R2P bitmask keeping positions >= limit (inclusive lower bound).

    Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).
    Uses inline PTX to avoid shift-by-type-width UB.

    Args:
        limit: Inclusive element-space lower bound (runtime value).
        s: Constexpr chunk index; the mask covers elements
           [s*32, (s+1)*32) of the overall element space.
    """
    # Number of low bits to clear in this chunk; clamped at 0 so chunks fully
    # above `limit` yield all-ones. n can reach 32 (all-zero mask), hence the
    # PTX shift helper rather than Python `<<`.
    n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0)
    return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n))
|
|
|
|
@cute.jit
def mask_r2p_lambda(
    X: cute.Tensor,
    mask_gen_fn: cutlass.Constexpr[MaskGenFn],
    rank1: bool = False,
) -> None:
    """Apply R2P masking with a custom bitmask generator.

    mask_gen_fn(chunk_idx: constexpr int) -> Uint32:
        Returns a 32-bit bitmask for the chunk. Bit i set means column
        chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf.

    Args:
        X: Accumulator tensor modified in place. With rank1=False the last
           mode is the column dimension and all rows share each column mask;
           with rank1=True, X is treated as a flat 1-D element vector.
        mask_gen_fn: Constexpr callable producing one 32-bit mask per chunk.
        rank1: Selects the 1-D indexing path (X[c]) vs. 2-D (X[r, c]).
    """
    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))

    CHUNK_SIZE = MASK_R2P_CHUNK_SIZE
    # Fully unrolled at trace time: one generated mask per 32-element chunk.
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)):
        mask = mask_gen_fn(s)
        # Last chunk may be partial, hence min() on the inner trip count.
        for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)):
            # Extract bit i; nonzero => this element is kept.
            in_bound = cutlass.Boolean(mask & (Uint32(1) << i))
            c = s * CHUNK_SIZE + i
            if const_expr(rank1):
                X[c] = X[c] if in_bound else -Float32.inf
            else:
                # One column mask bit applies to every row of that column.
                for r in cutlass.range_constexpr(cute.size(X.shape[0])):
                    X[r, c] = X[r, c] if in_bound else -Float32.inf
|
|
|
|
@cute.jit
def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:
    """Convert an SM90 MMA column coordinate into an R2P element index.

    The SM90 MMA accumulator owns two consecutive columns out of every
    group of eight (column coordinates 0, 1, 8, 9, 16, 17, ...), while R2P
    element indices are densely packed (0, 1, 2, 3, ...). This maps a
    column-space threshold into element space so it can feed
    r2p_bitmask_below / r2p_bitmask_above.
    """
    # Two owned elements per 8-column group...
    elems_from_full_groups = col_limit // 8 * 2
    # ...plus 0, 1, or 2 elements from the partial group at the boundary.
    elems_in_partial_group = min(col_limit % 8, 2)
    return elems_from_full_groups + elems_in_partial_group
|
|
|
|
@cute.jit
def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:
    """Convert a row coordinate to an R2P element index in the warp-group
    interleaved layout.

    In the SM100 backward pass, 2 warp groups share TMEM and the TMEM load
    atom hands out rows in interleaved blocks of `num_rep`: elements
    0..num_rep-1 cover rows 0..num_rep-1 (warp group 0), elements
    num_rep..2*num_rep-1 cover rows num_rep*num_wg.. (warp group 1), etc.
    Row-space thresholds (causal limits, window bounds, uih_len) must pass
    through this mapping before use with r2p_bitmask_above/below.

    Rows belonging to the other warp group (the gap between blocks) clamp
    to the boundary element index — safe, since R2P thresholds are
    monotonic.

    Example with num_rep=16, num_wg=2:
        row 0 -> elem 0, row 15 -> elem 15,
        row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped),
        row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31.
    """
    # One period covers this warp group's block plus the gap owned by others.
    period = num_rep * num_wg
    complete_periods = x // period
    # min() clamps rows that fall inside the other warp groups' blocks.
    within_period = min(x % period, num_rep)
    return complete_periods * num_rep + within_period
|
|
|
|
@dataclass(frozen=True)
class AttentionMask:
    """Masking of attention-score accumulator tiles (SM90/SM100).

    Writes -inf into accumulator positions excluded by sequence-length
    bounds, causal masking, local (sliding-window) masking, or an optional
    user-supplied ``mask_mod`` predicate. ``m_block`` tiles seqlen_q and
    ``n_block`` tiles seqlen_k; when ``swap_AB`` is set the accumulator is
    laid out transposed, so the ROW/COL roles flip.
    """

    tile_m: cutlass.Constexpr[int]  # tile extent along the M (seqlen_q) dimension
    tile_n: cutlass.Constexpr[int]  # tile extent along the N (seqlen_k) dimension
    seqlen_info: SeqlenInfoQK  # runtime Q/K sequence lengths
    window_size_left: Optional[Int32] = None  # local-attention left bound; None = unbounded
    window_size_right: Optional[Int32] = None  # local-attention right bound; None = unbounded
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # >1 when PackGQA packs query heads into rows
    swap_AB: cutlass.Constexpr[bool] = False  # True when the accumulator is (N, M) instead of (M, N)

    @property
    def seqlen_q(self) -> Int32:
        """Runtime query sequence length."""
        return self.seqlen_info.seqlen_q

    @property
    def seqlen_k(self) -> Int32:
        """Runtime key/value sequence length."""
        return self.seqlen_info.seqlen_k

    @cute.jit
    def apply_mask(
        self,
        acc_S: cute.Tensor,
        batch_idx: cutlass.Int32,
        head_idx: cutlass.Int32,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        thr_mma: cute.TiledMma,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
    ) -> None:
        """SM90 path: mask the S accumulator tile in place.

        Dispatches at trace time on the constexpr flags:
        seqlen-only, mask_mod, or causal/local. Where possible the
        contiguous-threshold paths use R2P bitmasks instead of per-element
        compares.

        Args:
            acc_S: accumulator tile, modified in place.
            batch_idx / head_idx: coordinates forwarded to ``mask_mod``.
            m_block / n_block: tile indices along seqlen_q / seqlen_k.
            thr_mma: per-thread MMA used to derive logical coordinates.
            mask_seqlen: apply seqlen_q/seqlen_k out-of-bounds masking.
            mask_causal / mask_local: mutually exclusive masking modes.
            mask_mod: optional (b, h, q, kv, seqlen_info, aux) -> bool predicate.
            aux_tensors: opaque extra tensors passed through to ``mask_mod``.
            fastdiv_mods: optional fast-division helpers used to wrap
                row/col indices before calling ``mask_mod``.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)

        # Thread-0's partition of the identity tensor; used together with this
        # thread's own offset to recover logical (row, col) coordinates.
        t0ScS_mn = layout_utils.reshape_acc_to_mn(
            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
        )
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0
        thr_col_offset = tScS_mn[0][COL]

        # Clamp so the limit arithmetic below stays valid for degenerate blocks.
        if n_block < 0:
            n_block = 0
        # Column threshold in this thread's local coordinates: columns at or
        # beyond it fall outside seqlen_k.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                r2p = const_expr(not self.swap_AB)
                if const_expr(not r2p):
                    # Fallback: per-column compare, applied to every row.
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]
                else:
                    # R2P fast path: convert the column threshold to element
                    # space, then mask whole 32-element chunks at once.
                    seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit)
                    mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s))

        elif const_expr(
            not mask_causal and not mask_local and mask_mod is not None
        ):
            # mask_mod path: evaluate the predicate per element.
            nrow = const_expr(cute.size(tScS_mn.shape[0]))
            ncol = const_expr(cute.size(tScS_mn.shape[1]))
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            wrap_aux_indices = const_expr(
                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
            )

            for r in cutlass.range_constexpr(nrow):
                # Recover the global row index for this accumulator row.
                local_row = tScS_mn[r, 0][ROW]
                global_row_idx = local_row + m_block * self.tile_m
                row_for_mod = global_row_idx
                head_idx_for_mod = head_idx
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    # PackGQA: rows interleave query heads; split the row index
                    # into (true row, head offset).
                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa
                row_for_seqlen = row_for_mod
                if const_expr(wrap_aux_indices):
                    # Wrap the row index so mask_mod's aux lookups stay in range.
                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])

                for col in cutlass.range_constexpr(ncol):
                    col_idx_local = t0ScS_mn[0, col][COL]
                    # Thread offset + thread-0 coordinate + tile origin.
                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n
                    col_for_mod = global_col_idx
                    if const_expr(wrap_aux_indices):
                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])

                    # mask_mod consumes SSA values; convert scalars at the call site.
                    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)
                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    if const_expr(mask_seqlen):
                        # Seqlen bounds take precedence over the predicate.
                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (
                            global_col_idx >= self.seqlen_k
                        )
                        if out_of_bounds:
                            acc_S_mn[r, col] = -cutlass.Float32.inf
                        else:
                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf
                    else:
                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf

        else:
            # Causal or local (sliding-window) masking.
            if const_expr(not self.swap_AB):
                # Per-row column thresholds.
                threads_per_row = thr_mma.tv_layout_C.shape[0][0]
                mma_m_idx = None
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert not self.swap_AB, "swap_AB with PackGQA not supported yet"
                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (
                        "threads_per_row must divide WARP_SIZE"
                    )
                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row
                    tidx = thr_mma.thr_idx
                    # Each thread computes the de-packed row index for its own
                    # row; shuffle_sync below broadcasts it across the row group.
                    mma_m_idx = (
                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]
                    ) // self.qhead_per_kvhead_packgqa
                # Offset so that col_limit_right = row_idx + causal_row_offset
                # gives the exclusive causal column bound (diagonal aligned to
                # the bottom-right of the seqlen_q x seqlen_k rectangle).
                causal_row_offset = (
                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset
                )
                if const_expr(mask_causal):
                    r2p = const_expr(not self.swap_AB)
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        # Global (or de-packed, for PackGQA) row index.
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        col_limit_right = row_idx + causal_row_offset
                        if const_expr(mask_seqlen):
                            # Fold the seqlen_k bound into the causal bound.
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        if const_expr(not r2p):
                            # Fallback: per-column compare on this row.
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                acc_S_mn[r, c] = (
                                    -Float32.inf
                                    if t0ScS_mn[0, c][1] >= col_limit_right
                                    else acc_S_mn[r, c]
                                )
                        else:
                            # R2P fast path on this row's element vector.
                            col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right)
                            mask_r2p_lambda(
                                acc_S_mn[r, None],
                                lambda s: r2p_bitmask_below(col_limit_r2p, s),
                                rank1=True,
                            )
                else:
                    # Local attention: two-sided window around the diagonal.
                    local_row_offset_right = (
                        causal_row_offset + self.window_size_right
                        if const_expr(self.window_size_right is not None)
                        else None
                    )
                    local_row_offset_left = (
                        causal_row_offset - 1 - self.window_size_left
                        if const_expr(self.window_size_left is not None)
                        else None
                    )
                    r2p_local = const_expr(not self.swap_AB)
                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                        if const_expr(self.qhead_per_kvhead_packgqa == 1):
                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m
                        else:
                            row_idx = utils.shuffle_sync(
                                mma_m_idx, r % threads_per_row, width=threads_per_row
                            )
                        if const_expr(self.window_size_right is not None):
                            col_limit_right = row_idx + local_row_offset_right
                        else:
                            # No right window: only the tile edge bounds us.
                            col_limit_right = self.tile_n
                        if const_expr(mask_seqlen):
                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                        col_limit_left = (
                            row_idx + local_row_offset_left
                            if const_expr(self.window_size_left is not None)
                            else 0
                        )
                        if const_expr(not r2p_local):
                            # Fallback: keep columns in [col_limit_left, col_limit_right).
                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                                col_idx = t0ScS_mn[0, c][1]
                                if col_idx >= col_limit_right or col_idx < col_limit_left:
                                    acc_S_mn[r, c] = -Float32.inf
                        else:
                            col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right)
                            col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left)

                            # Window = below(right) AND above(left).
                            def mask_gen_fn(s: int) -> Uint32:
                                return r2p_bitmask_below(
                                    col_limit_right_r2p, s
                                ) & r2p_bitmask_above(col_limit_left_r2p, s)

                            mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True)
            else:
                # swap_AB: accumulator is transposed, so thresholds become
                # per-column row limits instead of per-row column limits.
                assert self.qhead_per_kvhead_packgqa == 1
                thr_row_offset = tScS_mn[0][ROW]
                causal_row_offset = (
                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset
                )
                if const_expr(mask_causal):
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # Rows above row_limit_top are masked; a column fully
                        # past seqlen_k masks the whole column (limit = tile_m).
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else col0 - causal_row_offset
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if t0ScS_mn[r, 0][ROW] < row_limit_top
                                else acc_S_mn[r, c]
                            )
                else:
                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):
                        col0 = t0ScS_mn[0, c][COL]
                        # Two-sided window in row space: keep rows in
                        # [row_limit_top, row_limit_bot].
                        row_limit_top = (
                            self.tile_m
                            if col0 >= seqlenk_col_limit and mask_seqlen
                            else (
                                col0 - causal_row_offset - self.window_size_right
                                if const_expr(self.window_size_right is not None)
                                else 0
                            )
                        )
                        row_limit_bot = (
                            col0 - causal_row_offset + self.window_size_left
                            if const_expr(self.window_size_left is not None)
                            else self.tile_m
                        )
                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):
                            row_idx = t0ScS_mn[r, 0][ROW]
                            acc_S_mn[r, c] = (
                                -Float32.inf
                                if row_idx < row_limit_top or row_idx > row_limit_bot
                                else acc_S_mn[r, c]
                            )

    @cute.jit
    def apply_mask_sm100(
        self,
        acc_S: cute.Tensor,
        m_block: Int32,
        n_block: Int32,
        thr_mma: cute.TiledMma,
        thr_tmem_load: cute.TiledCopy,
        mask_seqlen: cutlass.Constexpr[bool],
        mask_causal: cutlass.Constexpr[bool],
        mask_local: cutlass.Constexpr[bool] = False,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        head_divmod=None,
        check_q_boundary: bool = False,
    ) -> None:
        """SM100 path: mask the flat (1-D per-thread) S accumulator.

        Coordinates come from the TMEM-load partition of an identity tensor,
        so acc_S is indexed with a single element index `i` rather than
        (row, col). Dispatch mirrors apply_mask: seqlen-only, mask_mod, or
        causal/local, with R2P fast paths for contiguous thresholds.

        Args:
            thr_tmem_load: TMEM-load tiled copy used to partition coordinates.
            head_divmod: fast-divmod helper for PackGQA head extraction
                (required when qhead_per_kvhead_packgqa != 1).
            check_q_boundary: also mask rows past seqlen_q on the mask_mod path.
            Remaining parameters as in apply_mask.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        acc_shape = (self.tile_m, self.tile_n)
        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
        tScS = thr_mma.partition_C(cS)
        tScS = tScS[(None, None), 0, 0]
        # Coordinates of the elements this thread loads from TMEM.
        tScS_t2r = thr_tmem_load.partition_D(tScS)

        # Clamp degenerate block indices before computing limits.
        if n_block < 0:
            n_block = 0
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n
        # R2P path is always enabled here; the flag keeps the element-wise
        # fallback available for debugging.
        r2p = True
        if const_expr(not mask_causal and not mask_local and mask_mod is None):
            if const_expr(mask_seqlen):
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        # Element-wise fallback: compare each element's column
                        # coordinate against the seqlen_k limit.
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]
                else:
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_below(seqlenk_col_limit, s),
                        rank1=True,
                    )

        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # mask_mod path: evaluate the predicate per element.
            has_fastdiv = const_expr(
                fastdiv_mods is not None
                and fastdiv_mods[0] is not None
                and fastdiv_mods[1] is not None
            )
            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)

            ncol = const_expr(cute.size(tScS_t2r.shape))
            for i in cutlass.range_constexpr(ncol):
                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]
                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]
                global_row = row_coord + m_block * self.tile_m
                global_col = col_coord + n_block * self.tile_n

                # PackGQA: split the packed row into (row, head offset).
                if const_expr(self.qhead_per_kvhead_packgqa != 1):
                    assert head_divmod is not None
                    mask_row, head_offset = divmod(global_row, head_divmod)
                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset
                else:
                    head_idx_for_mod = head_idx
                    mask_row = global_row

                # Optionally wrap indices so mask_mod's aux lookups stay in range.
                mask_row_for_mod = mask_row
                if const_expr(has_fastdiv and aux_tensors is not None):
                    if check_q_boundary:
                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])
                global_col_for_mod = global_col
                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):
                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])

                # mask_mod consumes SSA values.
                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)
                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)
                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)
                mask_value = mask_mod(
                    batch_idx_ssa,
                    head_idx_ssa,
                    mask_row_ssa,
                    kv_idx_ssa,
                    self.seqlen_info,
                    aux_tensors,
                )
                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                acc_S[i] = acc_S[i] if cond else -Float32.inf
                if const_expr(mask_seqlen):
                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]
                if check_q_boundary:
                    acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]

        else:
            # Causal or local masking. All of this thread's elements share one
            # row, so a single pair of column limits covers the whole vector.
            causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q
            row_idx = tScS_t2r[0][0] + m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa != 1):
                # PackGQA: de-pack the row index before the diagonal math.
                row_idx = row_idx // self.qhead_per_kvhead_packgqa
            if const_expr(mask_causal):
                # Exclusive right bound along the causal diagonal.
                col_limit_right = row_idx + causal_row_offset + 1
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)

                ncol = const_expr(cute.size(tScS_t2r.shape))
                if const_expr(not r2p):
                    for i in cutlass.range(ncol, unroll_full=True):
                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]
                else:
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_below(col_limit_right, s),
                        rank1=True,
                    )
            else:
                # Local attention: two-sided window around the diagonal.
                local_row_offset_right = (
                    causal_row_offset + 1 + self.window_size_right
                    if const_expr(self.window_size_right is not None)
                    else None
                )
                local_row_offset_left = (
                    causal_row_offset - self.window_size_left
                    if const_expr(self.window_size_left is not None)
                    else None
                )
                if const_expr(self.window_size_right is not None):
                    col_limit_right = row_idx + local_row_offset_right
                else:
                    col_limit_right = self.tile_n
                if const_expr(mask_seqlen):
                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)
                col_limit_left = (
                    row_idx + local_row_offset_left
                    if const_expr(self.window_size_left is not None)
                    else 0
                )
                if const_expr(not r2p):
                    # Fallback: keep columns in [col_limit_left, col_limit_right).
                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
                        col_idx = tScS_t2r[i][1]
                        acc_S[i] = (
                            -Float32.inf
                            if col_idx >= col_limit_right or col_idx < col_limit_left
                            else acc_S[i]
                        )
                else:
                    # Window = below(right) AND above(left), per 32-element chunk.

                    def mask_gen_fn(s: int) -> Uint32:
                        return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above(
                            col_limit_left, s
                        )

                    mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True)

    @cute.jit
    def apply_mask_sm100_transposed(
        self,
        acc_S: cute.Tensor,
        tScS_t2r: cute.Tensor,
        t0ScS_t2r: cute.Tensor,
        m_block: cutlass.Int32,
        n_block: cutlass.Int32,
        mask_seqlen: cutlass.Constexpr,
        mask_causal: cutlass.Constexpr,
        mask_local: cutlass.Constexpr,
        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,
        batch_idx: Int32 = None,
        head_idx: Int32 = None,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        is_full_block: bool = False,
        check_m_boundary: bool = True,
    ) -> None:
        """
        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.

        Coordinate convention:
        - ROW corresponds to Q (m_block)
        - COL corresponds to KV (n_block)

        Unlike apply_mask_sm100, the coordinate tensors (tScS_t2r, t0ScS_t2r)
        are supplied by the caller.

        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.
        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).
            When iterating m_blocks in forward order, only the last m_block may be partial.
        """
        assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
        ROW = 0 if const_expr(not self.swap_AB) else 1
        COL = 1 if const_expr(not self.swap_AB) else 0

        thr_col_offset = tScS_t2r[0][COL]
        # KV-column threshold in this thread's local coordinates.
        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset

        if const_expr(not mask_causal and not mask_local and mask_mod is not None):
            # mask_mod path, with a cheap bypass for blocks known to be fully
            # valid (is_full_block): only seqlen bounds then apply.
            if is_full_block:
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        # Whole tile is past seqlen_k.
                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                            acc_S[i] = -cutlass.Float32.inf
                    elif check_m_boundary:
                        # Partial tile: per-element Q and KV bounds checks.
                        ncol = const_expr(cute.size(tScS_t2r.shape))
                        for i in cutlass.range_constexpr(ncol):
                            row_coord = tScS_t2r[i][ROW]
                            col_coord = tScS_t2r[i][COL]
                            global_q = row_coord + m_block * self.tile_m
                            global_kv = col_coord + n_block * self.tile_n
                            q_out_of_bounds = global_q >= self.seqlen_q
                            kv_out_of_bounds = global_kv >= self.seqlen_k
                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]
            else:
                # Full mask_mod evaluation per element.
                has_fastdiv = const_expr(
                    fastdiv_mods is not None
                    and fastdiv_mods[0] is not None
                    and fastdiv_mods[1] is not None
                )
                wrap_aux_indices = const_expr(
                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)
                )
                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)
                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)

                ncol = const_expr(cute.size(tScS_t2r.shape))
                for i in cutlass.range_constexpr(ncol):
                    row_coord = tScS_t2r[i][ROW]
                    col_coord = tScS_t2r[i][COL]
                    global_q = row_coord + m_block * self.tile_m
                    global_kv = col_coord + n_block * self.tile_n

                    # Optionally wrap indices for mask_mod's aux lookups.
                    q_idx_for_mod = global_q
                    kv_idx_for_mod = global_kv
                    if const_expr(wrap_aux_indices):
                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])
                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])

                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)
                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)

                    mask_value = mask_mod(
                        batch_idx_ssa,
                        head_idx_ssa,
                        q_idx_ssa,
                        kv_idx_ssa,
                        self.seqlen_info,
                        aux_tensors,
                    )
                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))
                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf

                    if const_expr(mask_seqlen):
                        # Q bound is skipped when the caller guarantees a
                        # non-boundary m_block (check_m_boundary=False).
                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)
                        kv_out_of_bounds = global_kv >= self.seqlen_k
                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds
                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]

        elif const_expr(not mask_causal and not mask_local):
            # No causal/local/mask_mod: only mask tiles fully past seqlen_k.
            if const_expr(mask_seqlen):
                if seqlenk_col_limit <= 0:
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = -cutlass.Float32.inf
        else:
            # Causal / local masking in the transposed layout: thresholds are
            # expressed as row limits (Q direction) per thread.
            thr_row_offset = tScS_t2r[0][ROW]
            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset
            causal_offset = seqlenq_row_limit - seqlenk_col_limit
            if const_expr(mask_causal):
                # Rows above row_limit_top are masked (upper triangle in the
                # transposed tile).
                row_limit_top = causal_offset
                if const_expr(mask_seqlen):
                    # A tile fully past seqlen_k masks every row.
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        acc_S[i] = (
                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]
                        )
                else:
                    # Convert the row threshold into the warp-group interleaved
                    # element space before building the R2P mask.
                    num_rep = cute.size(tScS_t2r, mode=[0])
                    num_wg = 2
                    row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
                    mask_r2p_lambda(
                        acc_S,
                        lambda s: r2p_bitmask_above(row_limit, s),
                        rank1=True,
                    )
            else:
                # Local attention: keep rows in [row_limit_top, row_limit_bot].
                if const_expr(self.window_size_right is not None):
                    row_limit_top = causal_offset - self.window_size_right
                else:
                    row_limit_top = 0
                if const_expr(self.window_size_left is not None):
                    row_limit_bot = causal_offset + self.window_size_left
                if const_expr(mask_seqlen):
                    if seqlenk_col_limit <= 0:
                        row_limit_top = self.tile_m
                r2p = True
                if const_expr(not r2p):
                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
                        row_idx = t0ScS_t2r[i][ROW]
                        local_mask = row_idx < row_limit_top
                        if const_expr(self.window_size_left is not None):
                            local_mask |= row_idx > row_limit_bot
                        acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]
                else:
                    # Window = above(top) AND below(bot+1), in interleaved
                    # element space.

                    def mask_gen_fn(s: int) -> Uint32:
                        num_rep = cute.size(tScS_t2r, mode=[0])
                        num_wg = 2

                        row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)
                        mask = r2p_bitmask_above(row_limit, s)

                        if const_expr(self.window_size_left is not None):
                            # +1 converts the inclusive bottom bound into an
                            # exclusive limit for r2p_bitmask_below.
                            row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg)
                            mask = mask & r2p_bitmask_below(row_limit_bottom, s)

                        return mask

                    mask_r2p_lambda(
                        acc_S,
                        mask_gen_fn,
                        rank1=True,
                    )
|
|