""" Block-sparse runtime utilities for CUTE DSL kernels. This module contains runtime execution functions for block-sparse attention kernels. These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. """ from typing import Callable, Optional from functools import partial import math import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from .quack import copy_utils # Import data structures from block_sparsity from .block_sparsity import BlockSparseTensors from .named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] # # For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active # KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so # the softmax warp-group has no row stats to publish. # # The correction warp-group seeds fully-masked-row stats and runs the usual correction # epilogue so output/LSE have well-defined values. Both warp-groups must still perform # the softmax<->correction mbarrier handshake so phases advance correctly across # empty->empty and empty->non-empty tile sequences. # # In the no-sink case, this corresponds to the usual fully-masked-row convention: # output is zero and LSE is -inf. # # Barrier contract (each is `mbar_ptr + + stage`): # # Producer/consumer pairs: # - `mbar_softmax_corr_full` : softmax arrive -> correction wait # - `mbar_softmax_corr_empty` : correction arrive -> softmax wait # - `mbar_P_full_O_rescaled` : softmax arrive (+ correction arrive) -> MMA wait # - `mbar_P_full_2` : softmax arrive -> MMA wait # - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate) # # Empty tile (`total_block_cnt == 0`): # - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`). # It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal. 
# At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty` # before each tile (when block-sparse) to drain a prior correction arrival and keep # phases aligned across non-empty -> empty transitions. # - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`, # and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable). # - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction # (and correction<->epilogue) handshakes advance phases. # # Non-empty tile: # - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to # publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases; # arrives `mbar_P_full_*` when P is stored. # - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty` # to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed. # # Backward (SM100): # - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute. # - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles # skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward). # - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros # even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`). @cute.jit def load_block_list( block_indices: cute.Tensor, block_count, load_q_with_first: cutlass.Constexpr, first_block_preloaded: cutlass.Constexpr, kv_producer_state, load_Q, load_K, load_V, pipeline_k, pipeline_v, use_tma_q: cutlass.Constexpr, tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, ): """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. for the intra_wg_overlap case, we overlap the loads of K and V. 
And this means we need to pipeline the last V load from the partial block case, with the loads for the full blocks. Set first_block_preloaded when the caller has already issued the first K load for the list. Note: we iterate along the block_n indices in reverse. Returns: Updated kv_producer_state after processing the block list. """ if block_count > 0: if const_expr(not intra_wg_overlap): # Peel first iteration: the first block may need to load Q alongside K, # Parameters are already Constexpr, so no need to wrap in const_expr() n_block_first = block_indices[block_count - 1] extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) if const_expr(load_q_with_first and use_tma_q): load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) load_K(src_idx=n_block_first, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_first, producer_state=kv_producer_state) kv_producer_state.advance() for offset in cutlass.range(1, block_count): n_block = block_indices[block_count - 1 - offset] pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block, producer_state=kv_producer_state) kv_producer_state.advance() else: n_block_first = block_indices[block_count - 1] if const_expr(not first_block_preloaded): extra_tx = ( tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 ) pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) if const_expr(load_q_with_first and use_tma_q): load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) load_K(src_idx=n_block_first, producer_state=kv_producer_state) for idx in cutlass.range(block_count - 1, unroll=1): n_block_prev = block_indices[block_count - 1 - idx] n_block = block_indices[block_count - 2 - idx] 
kv_producer_state_prev = kv_producer_state.clone() kv_producer_state.advance() pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) return kv_producer_state @cute.jit def finish_overlap_v_load( block_indices: cute.Tensor, block_count, load_V, pipeline_v, kv_producer_state, ): """Load the final V block after overlapped K/V loads.""" if block_count > 0: n_block_last = block_indices[0] pipeline_v.producer_acquire(kv_producer_state) load_V(src_idx=n_block_last, producer_state=kv_producer_state) kv_producer_state.advance() return kv_producer_state @cute.jit def sparse_tensor_m_block( m_block, qhead_per_kvhead: cutlass.Constexpr[int], q_subtile_factor: cutlass.Constexpr[int], ): """Map packed m_block indices to block-sparse tensor indices.""" block = m_block if const_expr(qhead_per_kvhead != 1): block = block // qhead_per_kvhead if const_expr(q_subtile_factor != 1): block = block // q_subtile_factor return block @cute.jit def produce_block_sparse_loads( blocksparse_tensors: BlockSparseTensors, batch_idx, head_idx, m_block, kv_producer_state, load_Q, load_K, load_V, pipeline_k, pipeline_v, use_tma_q: cutlass.Constexpr, tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, q_subtile_factor: cutlass.Constexpr[int] = 1, ): """Iterate over the mask and full block lists for a single tile. The masked (partial) list may leave the last V load pending when intra-warp-group overlap is enabled. The first full block must consume that pending V while issuing its own K load on the next pipeline stage. In the intra-wg-overlap path, the last masked block leaves its V copy in flight while we advance the producer state to start the next full K. Either the full list overlaps that pending V load, or, if no full blocks exist, we explicitly drain it. 
Args: qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. """ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 if mask_empty: # No masked blocks: the full list owns the initial Q+K load. kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, use_tma_q=use_tma_q, tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0: kv_producer_state = finish_overlap_v_load( curr_full_block_idx, curr_full_block_cnt, load_V, pipeline_v, kv_producer_state, ) else: # Masked blocks present: load Q together with the first masked K so consumers can # start immediately. When overlap is disabled this fully drains the list. 
kv_producer_state = load_block_list( curr_mask_block_idx, curr_mask_block_cnt, load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, use_tma_q=use_tma_q, tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) if full_empty: if const_expr(intra_wg_overlap): kv_producer_state = finish_overlap_v_load( curr_mask_block_idx, curr_mask_block_cnt, load_V, pipeline_v, kv_producer_state, ) else: if const_expr(intra_wg_overlap): # Bridge the masked list to the full list by overlapping the pending masked V # with the first full K load. n_block_mask_last = curr_mask_block_idx[0] n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1] kv_producer_state_prev = kv_producer_state.clone() kv_producer_state.advance() pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_full_first, producer_state=kv_producer_state) pipeline_v.producer_acquire(kv_producer_state_prev) load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=False, first_block_preloaded=True, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, use_tma_q=use_tma_q, tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) kv_producer_state = finish_overlap_v_load( curr_full_block_idx, curr_full_block_cnt, load_V, pipeline_v, kv_producer_state, ) else: # Non-overlap path with both lists: run the full list normally (skipping the Q # reload because the masked list already issued it). 
kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=False, first_block_preloaded=False, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, use_tma_q=use_tma_q, tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) return kv_producer_state @cute.jit def consume_block_sparse_loads( blocksparse_tensors: BlockSparseTensors, batch_idx, head_idx, m_block, seqlen, kv_consumer_state, mma_pv_fn, mma_one_n_block, process_first_half_block, process_last_half_block, mask_fn, score_mod_fn, O_should_accumulate, mask_mod, fastdiv_mods, intra_wg_overlap: cutlass.Constexpr, warp_scheduler_barrier_sync: Callable, warp_scheduler_barrier_arrive: Callable, qhead_per_kvhead: cutlass.Constexpr[int] = 1, q_subtile_factor: cutlass.Constexpr[int] = 1, ): """Consume the mask and full block lists for a single tile on the consumer side. Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses the same sparse tensor indexing. Args: qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. 
""" mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 if const_expr(not intra_wg_overlap): if curr_mask_block_cnt > 0: mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial( mask_fn, mask_mod=mask_mod, mask_seqlen=True, fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, ), is_first_n_block=True, ) O_should_accumulate = True for i in cutlass.range(1, curr_mask_block_cnt): mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), is_first_n_block=False, ) O_should_accumulate = True if curr_full_block_cnt == 0: warp_scheduler_barrier_arrive() if curr_full_block_cnt > 0: full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] if curr_mask_block_cnt == 0: warp_scheduler_barrier_sync() kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_seqlen=True), is_first_n_block=True, ) O_should_accumulate = True for i in cutlass.range(1, curr_full_block_cnt): full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] kv_consumer_state = 
mma_one_n_block( kv_consumer_state, n_block=full_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_seqlen=False), is_first_n_block=False, ) O_should_accumulate = True else: kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), is_first_n_block=False, ) O_should_accumulate = True for i in cutlass.range(1, curr_full_block_cnt): full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), is_first_n_block=False, ) O_should_accumulate = True warp_scheduler_barrier_arrive() else: if curr_mask_block_cnt > 0: mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] kv_consumer_state = process_first_half_block( n_block=mask_n_block, seqlen=seqlen, kv_consumer_state=kv_consumer_state, mask_fn=partial( mask_fn, mask_mod=mask_mod, mask_seqlen=True, fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, ), score_mod_fn=score_mod_fn, is_first_block=True, ) for i in cutlass.range(1, curr_mask_block_cnt): mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), ) O_should_accumulate = True if curr_full_block_cnt > 0: full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] if curr_mask_block_cnt == 0: kv_consumer_state = process_first_half_block( n_block=full_n_block, seqlen=seqlen, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, is_first_block=True, 
) else: kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), ) O_should_accumulate = True for i in cutlass.range(1, curr_full_block_cnt): full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, seqlen=seqlen, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), ) O_should_accumulate = True if curr_mask_block_cnt + curr_full_block_cnt > 0: kv_consumer_state = process_last_half_block( kv_consumer_state=kv_consumer_state, zero_init=not O_should_accumulate, ) O_should_accumulate = True return kv_consumer_state, O_should_accumulate, processed_any @cute.jit def load_block_list_sm100( block_indices: cute.Tensor, block_count, load_q_with_first: cutlass.Constexpr, q_stage: cutlass.Constexpr, kv_producer_state, load_Q, load_K, load_V, pipeline_kv, ): """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" if block_count > 0: # First iteration: load Q alongside K if requested n_block_first = block_indices[block_count - 1] if const_expr(load_q_with_first): # SM100 loads Q0 and optionally Q1 load_Q(block=0, stage=0) if const_expr(q_stage == 2): load_Q(block=1, stage=1) # SM100 doesn't use producer_acquire for pipeline_kv in load path # The pipeline barriers are handled inside load_KV load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() # Remaining blocks for offset in cutlass.range(1, block_count): n_block = block_indices[block_count - 1 - offset] load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() load_V(block=n_block, 
producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() return kv_producer_state # SM100-specific tile processor using SM100 helpers @cute.jit def produce_block_sparse_loads_sm100( blocksparse_tensors: BlockSparseTensors, batch_idx, head_idx, m_block, kv_producer_state, load_Q, load_K, load_V, pipeline_kv, q_stage: cutlass.Constexpr, q_producer_phase: Int32, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr, ): """SM100 entry point for sparse block iteration. SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use simplified block processing that just calls producer_acquire without extras. Args: m_block: which tile of m we are processing qhead_per_kvhead: Constexpr pack factor """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 q_phase_flipped = False if mask_empty: # No masked blocks: process full list with Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_kv=pipeline_kv, ) q_phase_flipped = not full_empty else: # Process masked blocks with Q loading kv_producer_state = load_block_list_sm100( curr_mask_block_idx, curr_mask_block_cnt, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, 
load_K=load_K, load_V=load_V, pipeline_kv=pipeline_kv, ) q_phase_flipped = True if not full_empty: # Process full blocks without Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=False, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_kv=pipeline_kv, ) if q_phase_flipped: q_producer_phase ^= 1 return kv_producer_state, q_producer_phase @cute.jit def get_total_block_count( blocksparse_tensors: BlockSparseTensors, batch_idx, head_idx, m_block, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr, ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors if const_expr(full_block_cnt is not None): return ( mask_block_cnt[batch_idx, head_idx, m_block_sparse] + full_block_cnt[batch_idx, head_idx, m_block_sparse] ) else: return mask_block_cnt[batch_idx, head_idx, m_block_sparse] @cute.jit def handle_block_sparse_empty_tile_correction_sm100( tidx: Int32, q_stage: cutlass.Constexpr, m_block_size: cutlass.Constexpr, qhead_per_kvhead, pack_gqa: cutlass.Constexpr, is_split_kv: cutlass.Constexpr, learnable_sink, mLSE, seqlen, m_block: Int32, head_idx: Int32, batch_idx: Int32, split_idx: Int32, sScale: cute.Tensor, stats: list, correction_epilogue: Callable, thr_mma_pv: cute.core.ThrMma, tOtO: cute.Tensor, sO: cute.Tensor, pipeline_sm_stats: cutlass.pipeline.PipelineAsync, sm_stats_barrier: cutlass.pipeline.NamedBarrier, pipeline_o_epi: cutlass.pipeline.PipelineAsync, sm_stats_consumer_phase: Int32, o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, mO_cur: Optional[cute.Tensor] = None, gO: Optional[cute.Tensor] = None, gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): """Handle SM100 forward block-sparse tiles with no active KV blocks. This path is taken when `total_block_cnt == 0`. 
The softmax warp-group still arrives `mbar_softmax_corr_full` (synthetic "no work") so the correction warp-group can: - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE - run `correction_epilogue` with `scale=0` so the output tile is written as zeros (independent of any prior tmem contents) - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty` (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles This helper intentionally does not touch `mbar_P_full_*` since no P is produced. See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. """ LOG2_E = Float32(math.log2(math.e)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 for stage in cutlass.range_constexpr(q_stage): row_sum_value = Float32(1.0) row_max_value = ( -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None ) if const_expr(learnable_sink is not None): sink_val = -Float32.inf if const_expr(not pack_gqa): sink_val = Float32(learnable_sink[head_idx]) elif tidx < m_block_size: q_head_idx = ( (q_stage * m_block + stage) * m_block_size + tidx ) % qhead_per_kvhead + head_idx * qhead_per_kvhead sink_val = Float32(learnable_sink[q_head_idx]) if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): if row_max_value == -Float32.inf: row_max_value = sink_val * (LOG2_E / softmax_scale_log2) row_sum_value = Float32(1.0) else: row_sum_value = row_sum_value + cute.math.exp2( sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True ) if tidx < m_block_size: scale_row_idx = tidx + stage * m_block_size sScale[scale_row_idx] = row_sum_value if const_expr(mLSE is not None or learnable_sink is not None): sScale[scale_row_idx + q_stage * m_block_size] = row_max_value acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value stats[stage] = (row_sum_value, row_max_value, acc_flag) # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. 
# pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx) pipeline_sm_stats.consumer_release_w_index(stage) if const_expr(gmem_tiled_copy_O is None): pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) correction_epilogue( thr_mma_pv, tOtO[None, None, None, stage], tidx, stage, m_block, seqlen.seqlen_q, Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, gO[None, None, stage], gmem_tiled_copy_O, ) if const_expr(gmem_tiled_copy_O is None): pipeline_o_epi.producer_commit_w_index(stage) sm_stats_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 return ( sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, ) @cute.jit def softmax_block_sparse_sm100( blocksparse_tensors: BlockSparseTensors, batch_idx, head_idx, m_block, softmax_step: Callable, mask_fn: Callable, mask_fn_none: Callable, mma_si_consumer_phase: Int32, si_corr_producer_phase: Int32, s0_s1_sequence_phase: Int32, pipeline_sm_stats: cutlass.pipeline.PipelineAsync, sm_stats_barrier: cutlass.pipeline.NamedBarrier, q_stage: cutlass.Constexpr, stage_idx: Int32, check_m_boundary: bool, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr[int] = 1, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None total_block_cnt = 
curr_mask_block_cnt + curr_full_block_cnt if total_block_cnt == 0: # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. # pipeline_sm_stats.producer_commit_w_index(stage_idx) sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: if curr_mask_block_cnt > 0: mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, ) = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, mask_n_block, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary), ) for i in cutlass.range(1, curr_mask_block_cnt): mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, ) = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, mask_n_block, mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary), ) if curr_full_block_cnt > 0: full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] if curr_mask_block_cnt == 0: ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, ) = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, full_n_block, is_first=True, mask_fn=partial( mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary ), ) else: ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, ) = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, full_n_block, is_first=False, mask_fn=partial( mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary ), ) for i in cutlass.range(1, curr_full_block_cnt): full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, ) = softmax_step( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, full_n_block, mask_fn=partial( mask_fn_none, mask_seqlen=False, 
check_q_boundary=check_m_boundary ), ) return ( mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, total_block_cnt == 0, )


# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
#


@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward.

    The result is the first count tensor's entry at
    [batch_idx, head_idx, n_block], plus the third (full-block) count tensor's
    entry when that tensor is not None, multiplied by `subtile_factor`.

    `m_block_max` is accepted but not read by this function.
    """
    # Only the two count tensors (positions 0 and 2) are needed here; the index
    # tensors are discarded.
    q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
    total = q_block_cnt[batch_idx, head_idx, n_block]
    # The presence of full-block counts is resolved with const_expr, i.e. it is
    # not a runtime branch.
    if const_expr(full_block_cnt is not None):
        total = total + full_block_cnt[batch_idx, head_idx, n_block]
    return total * subtile_factor


@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Iterates over every (sparse block, subtile) pair reported by
    `get_block_sparse_iteration_info_bwd` for this `n_block` and, per
    iteration, issues the Q + LSE loads (when `should_load_Q`) and the
    dO + dPsum loads (when `should_load_dO`).

    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
    K is issued against the Q pipeline barrier and V against the dO pipeline
    barrier, with `tma_copy_bytes_K` / `tma_copy_bytes_V` passed as
    `extra_tx_count` to the corresponding `producer_acquire`.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        # The is_full_block flag is not needed by the loader, so it is dropped.
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        # Unlike the SM90 helpers below, an m_block at or beyond m_block_max is
        # not skipped: it is clamped to m_block_max - 1 and still loaded.
        # m_block_max == 0 (the default) leaves m_block unclamped.
        # NOTE(review): presumably the clamp keeps the load in bounds while the
        # producer iteration count stays equal to loop_count — confirm against
        # the consumer side.
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)

        if iter_idx == 0:
            # First block: load K/V alongside Q/dO
            if const_expr(should_load_Q):
                # K shares the Q stage's barrier; its bytes are passed as
                # extra_tx_count on acquire.
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                # LSE uses its own pipeline but the same producer state as Q.
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                # V shares the dO stage's barrier; its bytes are passed as
                # extra_tx_count on acquire.
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                # dPsum uses its own pipeline but the same producer state as dO.
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum


@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Extract block-sparse iteration info for backward pass.

    Slices the count/index tensors at [batch_idx, head_idx, n_block]. When the
    full-block count tensor is None, `curr_full_cnt` is Int32(0) and
    `curr_full_idx` is None.

    `total_count` is (curr_q_cnt + curr_full_cnt) * subtile_factor, where the
    full count is only added when full-block tensors are present.
    `m_block_max` is accepted but not read by this function.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    # The trailing None keeps the last mode, giving the index list for this n_block.
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    sparse_block_count = curr_q_cnt
    if const_expr(full_cnt is not None):
        sparse_block_count = sparse_block_count + curr_full_cnt
    # Each sparse block expands into subtile_factor iterations.
    total_count = sparse_block_count * subtile_factor

    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count


@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Derive m_block index and is_full_block flag from iteration index.

    `iter_idx` is split into a sparse-block position (iter_idx // subtile_factor)
    and a subtile offset (iter_idx % subtile_factor). Positions below
    `curr_q_cnt` read from `curr_q_idx`; the remaining positions read from
    `curr_full_idx` and are flagged as full blocks. When `curr_full_idx` is
    None, every position reads from `curr_q_idx`.

    `curr_full_cnt` and `m_block_max` are accepted but not read here.

    Returns (m_block, is_full_block):
    - m_block: The actual Q-tile block index
    - is_full_block: True if this is a full block (no mask_mod needed)
    """
    sparse_iter_idx = iter_idx // subtile_factor
    subtile_offset = iter_idx % subtile_factor

    sparse_m_block = Int32(0)
    is_full_block = False
    if const_expr(curr_full_idx is not None):
        if sparse_iter_idx < curr_q_cnt:
            sparse_m_block = curr_q_idx[sparse_iter_idx]
        else:
            # Positions past the first curr_q_cnt entries index the full-block list.
            sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
            is_full_block = True
    else:
        sparse_m_block = curr_q_idx[sparse_iter_idx]

    # Convert the sparse block index into an m_block index.
    return sparse_m_block * subtile_factor + subtile_offset, is_full_block


@cute.jit
def _load_q_do_block_sm90(
    m_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    load_kv: bool,
):
    """Load one Q/dO block, optionally loading K/V on first iteration.

    When `load_kv` is true, K is issued against the Q pipeline barrier and V
    against the dO pipeline barrier, with `tma_copy_bytes_K` /
    `tma_copy_bytes_V` passed as `extra_tx_count` on acquire.

    When `Q_stage_eq_dO_stage` is true, the dO/dPsum loads use
    `producer_state_Q` instead of `producer_state_dO`. Both states are advanced
    in either case.

    Returns updated (producer_state_Q, producer_state_dO).
    """
    if load_kv:
        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
    else:
        pipeline_Q.producer_acquire(producer_state_Q)
    load_Q(m_block, producer_state=producer_state_Q)
    load_LSE(m_block, producer_state=producer_state_Q)

    # Select the state used for the dO side; this is decided with const_expr,
    # not at runtime.
    producer_state_dO_cur = (
        producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
    )
    if load_kv:
        pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
    else:
        pipeline_dO.producer_acquire(producer_state_dO_cur)
    load_dO(m_block, producer_state=producer_state_dO_cur)
    load_dPsum(m_block, producer_state=producer_state_dO_cur)

    # Advance after both halves so the dO side can reuse the pre-advance Q state.
    producer_state_Q.advance()
    producer_state_dO.advance()
    return producer_state_Q, producer_state_dO


@cute.jit
def produce_block_sparse_q_loads_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
):
    """SM90 backward block sparse loading with separate partial/full loops.

    K/V are loaded with the first valid block. Iterates partial blocks first,
    then full blocks, matching consumer order. A block is valid when its
    m_block is below `m_block_max`; other blocks are skipped without touching
    the pipelines, so K/V are only loaded if at least one block is valid.

    Returns updated (producer_state_Q, producer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Tracks whether K/V have been issued; shared across both loops below.
    kv_loaded = False

    # Partial blocks (first block list).
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                m_block,
                producer_state_Q,
                producer_state_dO,
                pipeline_Q,
                pipeline_dO,
                load_K,
                load_V,
                load_Q,
                load_dO,
                load_LSE,
                load_dPsum,
                tma_copy_bytes_K,
                tma_copy_bytes_V,
                Q_stage_eq_dO_stage,
                load_kv=not kv_loaded,
            )
            kv_loaded = True

    # Full blocks (second block list), only when full-block tensors exist.
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                    m_block,
                    producer_state_Q,
                    producer_state_dO,
                    pipeline_Q,
                    pipeline_dO,
                    load_K,
                    load_V,
                    load_Q,
                    load_dO,
                    load_LSE,
                    load_dPsum,
                    tma_copy_bytes_K,
                    tma_copy_bytes_V,
                    Q_stage_eq_dO_stage,
                    load_kv=not kv_loaded,
                )
                kv_loaded = True

    return producer_state_Q, producer_state_dO


@cute.jit
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    score_mod_fn=None,
    score_mod_bwd_fn=None,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """SM90 backward block sparse MMA consumption with separate partial/full loops.

    Partial blocks are processed first (with mask_mod applied), then full blocks
    (without mask_mod). This ensures mask_mod is only applied where needed.
    Both mask functions still pass mask_seqlen=True together with the
    causal/local flags.

    Blocks whose m_block is not below `m_block_max` are skipped. Note that the
    default `m_block_max=0` therefore skips every block. `dKV_accumulate` is
    False for the first processed block and True for every later one.

    Returns updated (consumer_state_Q, consumer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Shared across both loops: flips to True after the first processed block.
    dKV_accumulate = False

    # Mask for partial blocks: includes the user mask_mod.
    mask_fn_partial = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        mask_mod=mask_mod,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    # Mask for full blocks: identical except that mask_mod is not passed.
    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    # Partial blocks.
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                score_mod_fn=score_mod_fn,
                score_mod_bwd_fn=score_mod_bwd_fn,
                dKV_accumulate=dKV_accumulate,
            )
            dKV_accumulate = True

    # Full blocks, only when full-block tensors exist.
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    score_mod_fn=score_mod_fn,
                    score_mod_bwd_fn=score_mod_bwd_fn,
                    dKV_accumulate=dKV_accumulate,
                )
                dKV_accumulate = True

    return consumer_state_Q, consumer_state_dO


@cute.jit
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Store dQaccum for a single m_block.

    Runs two passes over the MMA warp groups:
    1. Wait on outstanding bulk async groups and arrive on the per-warp-group
       `dQEmptyWG0 + i` named barrier.
    2. Wait on the per-warp-group `dQFullWG0 + i` named barrier, then issue a
       bulk f32 reduce-add from `sdQaccum[None, i]` into
       `gdQaccum[None, i, m_block]` and commit the bulk group.

    All named barriers are sized `num_threads_per_warp_group + WARP_SIZE`.
    NOTE(review): presumably one MMA warp group plus this storing warp —
    confirm against the MMA side.
    """
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        # Allow progressively fewer pending bulk groups as warp_group_idx grows
        # (num_mma_warp_groups - 1 down to 0).
        # NOTE(review): presumably this ensures the earlier reduce has finished
        # reading this warp group's sdQaccum slice before it is released.
        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        # The reduce-add is issued inside elect_one(); the commit is outside it.
        with cute.arch.elect_one():
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, warp_group_idx].iterator,
                gdQaccum[None, warp_group_idx, m_block].iterator,
                tma_copy_bytes_dQ,
            )
        cute.arch.cp_async_bulk_commit_group()


@cute.jit
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """SM90 backward block sparse dQaccum store with separate partial/full loops.

    Iterates partial blocks first, then full blocks, matching producer/consumer
    order. Blocks whose m_block is not below `m_block_max` are skipped; each
    remaining block is stored through `_store_one_dQaccum_sm90`.

    Returns nothing.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Partial blocks.
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_mma_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )

    # Full blocks, only when full-block tensors exist.
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_mma_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )