| """ |
| Block-sparse runtime utilities for CUTE DSL kernels. |
| |
| This module contains runtime execution functions for block-sparse attention kernels. |
| These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. |
| """ |
|
|
| from typing import Callable, Optional |
| from functools import partial |
| import math |
| import cutlass |
| import cutlass.cute as cute |
| from cutlass import Float32, Int32, const_expr |
|
|
| from .quack import copy_utils |
|
|
| |
| from .block_sparsity import BlockSparseTensors |
| from .named_barrier import NamedBarrierBwd |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Issue the K and V loads for every entry of one sparse block list.

    Entries are visited in reverse: ``block_indices[block_count - 1]`` first,
    ``block_indices[0]`` last. Q is loaded separately on its own mbarrier before this
    function is called, so only K and V are handled here.

    Without ``intra_wg_overlap`` each block's K and V are acquired and loaded on the same
    pipeline state, and the state is advanced once per block.

    With ``intra_wg_overlap`` the K load runs one pipeline state ahead of the V load.
    When this function returns on that path, K for ``block_indices[0]`` has been issued
    on the returned state but its V has not, and the state has not been advanced past
    it. The caller must issue that V itself (see ``finish_overlap_v_load``) or pair it
    with the first K of a following list.

    Args:
        block_indices: Block index list for the current tile, indexed by position.
        block_count: Number of entries to process; nothing is issued when it is 0.
        first_block_preloaded: Compile-time flag, read only on the overlap path. When
            true the caller has already issued the K load for the first visited block
            (``block_indices[block_count - 1]``) on ``kv_producer_state``.
        kv_producer_state: Producer pipeline state; advanced in place and returned.
        load_K: Callable invoked as ``load_K(src_idx=..., producer_state=...)``.
        load_V: Callable invoked as ``load_V(src_idx=..., producer_state=...)``.
        pipeline_k: Pipeline whose ``producer_acquire`` is called before each K load.
        pipeline_v: Pipeline whose ``producer_acquire`` is called before each V load.
        intra_wg_overlap: Compile-time flag selecting the overlapped schedule.

    Returns:
        Updated kv_producer_state after processing the block list.
    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Simple schedule: K and V of a block share one pipeline state.
            for offset in cutlass.range(block_count):
                # Walk the list back to front.
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            # Overlapped schedule: issue K of the first visited block, unless the caller
            # already did so on the current state.
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_first, producer_state=kv_producer_state)

            # Each iteration issues K for the next block on the advanced state, then V
            # for the previous block on the state that block's K used.
            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                # Keep a copy of the pre-advance state for the trailing V load.
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
            # On exit the V of block_indices[0] is still outstanding on the current state.

    return kv_producer_state
|
|
|
|
@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Load the final V block after overlapped K/V loads.

    Issues the V load for ``block_indices[0]`` on the current producer state and then
    advances that state. Nothing is issued when ``block_count`` is 0.

    Args:
        block_indices: Block index list whose entry 0 is the block still missing its V.
        block_count: Number of entries in the list; only tested against 0 here.
        load_V: Callable invoked as ``load_V(src_idx=..., producer_state=...)``.
        pipeline_v: Pipeline whose ``producer_acquire`` is called before the load.
        kv_producer_state: Producer pipeline state; advanced in place when a load is
            issued.

    Returns:
        The (possibly advanced) kv_producer_state.
    """
    if block_count > 0:
        # Entry 0 is the last block visited by the reverse iteration in load_block_list.
        n_block_last = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=n_block_last, producer_state=kv_producer_state)
        kv_producer_state.advance()

    return kv_producer_state
|
|
|
|
@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
    q_subtile_factor: cutlass.Constexpr[int],
):
    """Convert a kernel-side m_block index into the block-sparse tensors' m index.

    The index is floor-divided first by ``qhead_per_kvhead`` and then by
    ``q_subtile_factor``. A factor equal to 1 is skipped at compile time, so no division
    is emitted for it.

    Args:
        m_block: m-tile index as used by the kernel.
        qhead_per_kvhead: Compile-time divisor applied first.
        q_subtile_factor: Compile-time divisor applied second.

    Returns:
        The index to use when indexing the block-sparse count/index tensors.
    """
    after_heads = m_block // qhead_per_kvhead if const_expr(qhead_per_kvhead != 1) else m_block
    return after_heads // q_subtile_factor if const_expr(q_subtile_factor != 1) else after_heads
|
|
|
|
@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Iterate over the mask and full block lists for a single tile.

    Q is loaded separately on its own mbarrier before this function is called.

    The masked (partial) list is processed first, then the full list. Each list is
    walked back to front by ``load_block_list``.

    In the intra-wg-overlap path, ``load_block_list`` returns with the V load of the
    list's last visited block still outstanding. That pending V is handled in one of
    three ways:

    * only one list is non-empty: it is drained with ``finish_overlap_v_load``;
    * both lists are non-empty: the pending masked V is issued together with the first
      full K on the next pipeline state, the full list continues with
      ``first_block_preloaded=True``, and the full list's own last V is drained at the
      end.

    Args:
        blocksparse_tensors: Tuple unpacked as ``(mask_block_cnt, mask_block_idx,
            full_block_cnt, full_block_idx)``. ``full_block_cnt`` may be None, in which
            case the full list is treated as empty.
        batch_idx: Batch coordinate used to index the sparse tensors.
        head_idx: Head coordinate used to index the sparse tensors.
        m_block: m-tile index; mapped through ``sparse_tensor_m_block`` before indexing.
        kv_producer_state: Producer pipeline state, advanced as loads are issued.
        load_K: K-load callable forwarded to the helpers.
        load_V: V-load callable forwarded to the helpers.
        pipeline_k: Pipeline acquired before K loads.
        pipeline_v: Pipeline acquired before V loads.
        intra_wg_overlap: Compile-time flag selecting the overlapped K/V schedule.
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: Additional compile-time divisor applied to ``m_block`` for
            sparse tensor indexing.

    Returns:
        The kv_producer_state after all loads for this tile have been issued.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    # Per-tile count and index list for the masked (partial) blocks.
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block list is optional; substitute an empty list when it is absent.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    if mask_empty:
        # No masked blocks: the full list (possibly empty too) is the only work.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            intra_wg_overlap=intra_wg_overlap,
        )

        # Overlap path leaves the last V outstanding; drain it.
        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked list is non-empty and goes first.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            intra_wg_overlap=intra_wg_overlap,
        )

        if full_empty:
            # Only masked blocks: drain the pending masked V on the overlap path.
            if const_expr(intra_wg_overlap):
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the two lists: issue the first full K on the next state, then
                # the pending masked V on the state its K used.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)

                # The first full K was issued above, so tell the helper to skip it.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    intra_wg_overlap=intra_wg_overlap,
                )

                # Drain the full list's last V.
                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path: the full list is simply loaded after the masked one.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    intra_wg_overlap=intra_wg_overlap,
                )

    return kv_producer_state
|
|
|
|
@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
    the same sparse tensor indexing and visits blocks in the same order: masked blocks
    first, then full blocks, each list walked back to front. As in the producer, a
    missing full-block list (``full_block_cnt is None``) is treated as empty.

    Masked blocks are processed with ``mask_mod``; full blocks that follow masked blocks
    are processed with ``mask_mod=None``. The first block of the tile additionally gets
    ``mask_seqlen=True``, as does the first full block that follows masked blocks on
    this path.

    Args:
        blocksparse_tensors: Tuple unpacked as ``(mask_block_cnt, mask_block_idx,
            full_block_cnt, full_block_idx)``; the full pair may be None.
        batch_idx: Batch coordinate used to index the sparse tensors.
        head_idx: Head coordinate used to index the sparse tensors.
        m_block: m-tile index; mapped through ``sparse_tensor_m_block`` before indexing.
        seqlen: Forwarded to the per-block callables on the overlap path.
        kv_consumer_state: Consumer pipeline state threaded through every block call.
        mma_pv_fn: Callable partially applied with ``zero_init`` for each block.
        mma_one_n_block: Callable that processes one block and returns the new state.
        process_first_half_block: Overlap-path callable for the first block of the tile.
        process_last_half_block: Overlap-path callable run once after the last block.
        mask_fn: Masking callable, partially applied per block.
        score_mod_fn: Forwarded to ``process_first_half_block``.
        O_should_accumulate: Whether the output accumulator already holds data; becomes
            True once any block has been accumulated.
        mask_mod: Mask modifier applied to masked blocks.
        fastdiv_mods: Forwarded to ``mask_fn`` for the first masked block when
            ``mask_mod`` is not None.
        intra_wg_overlap: Compile-time flag selecting the overlapped schedule.
        warp_scheduler_barrier_sync: Called before the first block on the non-overlap path.
        warp_scheduler_barrier_arrive: Called after the last block on the non-overlap path.
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: Additional compile-time divisor for sparse tensor indexing.

    Returns:
        ``(kv_consumer_state, O_should_accumulate, processed_any)`` where
        ``processed_any`` is true when the tile had at least one block.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block list is optional. Guard it at compile time exactly like
    # produce_block_sparse_loads does, instead of indexing a possible None.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

    if const_expr(not intra_wg_overlap):
        if curr_mask_block_cnt > 0:
            # First block of the tile: sync with the warp scheduler barrier, then
            # process with seqlen masking enabled.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            # Remaining masked blocks, back to front.
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            # No full blocks follow, so this is the end of the tile.
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # The first full block is also the first block of the tile.
                # NOTE(review): these two calls do not pass mask_mod explicitly, unlike
                # the sibling branch below — presumably mask_fn defaults it; confirm.
                warp_scheduler_barrier_sync()
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                # Full blocks following masked blocks: no mask_mod, not the first block.
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            # End of the tile whenever at least one full block was processed.
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            # Overlap path: the first block only runs its first half here.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # The first full block is the first block of the tile.
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            # Remaining full blocks are handled identically in both cases above.
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True

        # Finish the trailing half of the last block if anything was processed.
        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True

    return kv_consumer_state, O_should_accumulate, processed_any
|
|
|
|
@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).

    Walks ``block_indices`` back to front. For every block it issues K, advances the
    producer state, issues V, and advances again, so K and V occupy consecutive
    pipeline states. When ``load_q_with_first`` is set, Q is loaded before the first K
    (stage 0, plus stage 1 when ``q_stage == 2``). Nothing is issued when
    ``block_count`` is 0.

    No ``producer_acquire`` is called here and ``pipeline_kv`` is not used in this
    body; presumably the load callables handle pipeline barriers themselves — confirm
    against the caller.

    Args:
        block_indices: Block index list for the current tile, indexed by position.
        block_count: Number of entries to process.
        load_q_with_first: Compile-time flag; load Q ahead of the first K/V pair.
        q_stage: Compile-time Q stage count; a value of 2 triggers the second Q load.
        kv_producer_state: Producer pipeline state; advanced in place and returned.
        load_Q: Callable invoked as ``load_Q(block=..., stage=...)``.
        load_K: Callable invoked as ``load_K(block=..., producer_state=..., page_idx=None)``.
        load_V: Callable invoked as ``load_V(block=..., producer_state=..., page_idx=None)``.
        pipeline_kv: Accepted for interface parity; unused here.

    Returns:
        The kv_producer_state after processing the block list.
    """
    if block_count > 0:
        # The first visited block is peeled so Q can be issued ahead of it.
        n_block_first = block_indices[block_count - 1]

        if const_expr(load_q_with_first):
            # Q block 0 always; Q block 1 only in the two-stage configuration.
            load_Q(block=0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=1, stage=1)

        # K and V each take their own pipeline state.
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()

        # Remaining blocks, back to front.
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()

    return kv_producer_state
|
|
|
|
| |
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """SM100 entry point for sparse block iteration.

    Loads the masked list first and the full list second, each through
    ``load_block_list_sm100``. Q is loaded together with the first non-empty list.
    ``q_producer_phase`` is toggled once when at least one list was non-empty, i.e.
    when Q was actually loaded.

    Args:
        blocksparse_tensors: Tuple unpacked as ``(mask_block_cnt, mask_block_idx,
            full_block_cnt, full_block_idx)``. ``full_block_cnt`` may be None, in which
            case the full list is treated as empty.
        batch_idx: Batch coordinate used to index the sparse tensors.
        head_idx: Head coordinate used to index the sparse tensors.
        m_block: which tile of m we are processing
        kv_producer_state: Producer pipeline state, advanced as loads are issued.
        load_Q: Q-load callable forwarded to ``load_block_list_sm100``.
        load_K: K-load callable forwarded to ``load_block_list_sm100``.
        load_V: V-load callable forwarded to ``load_block_list_sm100``.
        pipeline_kv: Forwarded to ``load_block_list_sm100``.
        q_stage: Compile-time Q stage count, forwarded to the helper.
        q_producer_phase: Phase bit for the Q producer; XOR-ed with 1 when Q was loaded.
        qhead_per_kvhead: Constexpr pack factor
        q_subtile_factor: Additional compile-time divisor for sparse tensor indexing.

    Returns:
        ``(kv_producer_state, q_producer_phase)`` after this tile.
    """
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block list is optional; substitute an empty list when it is absent.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    # Tracks whether Q was loaded for this tile.
    q_phase_flipped = False

    if mask_empty:
        # No masked blocks: the full list carries the Q load (if it is non-empty).
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        # The helper loads Q only when its list is non-empty.
        q_phase_flipped = not full_empty
    else:
        # Masked list is non-empty: it carries the Q load.
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True

        if not full_empty:
            # Full blocks follow; Q was already loaded with the masked list.
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )

    if q_phase_flipped:
        q_producer_phase ^= 1

    return kv_producer_state, q_producer_phase
|
|
|
|
@cute.jit
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """Return the number of KV blocks (masked plus full) for one m tile.

    ``m_block`` is mapped through ``sparse_tensor_m_block`` before indexing. When the
    block-sparse tensors carry no full-block count, only the masked count is returned.
    """
    sparse_m = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_cnt, _, full_cnt, _ = blocksparse_tensors
    block_total = mask_cnt[batch_idx, head_idx, sparse_m]
    if const_expr(full_cnt is not None):
        # Full blocks are optional; add them only when the tensor exists.
        block_total = block_total + full_cnt[batch_idx, head_idx, sparse_m]
    return block_total
|
|
|
|
@cute.jit
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtO: cute.Tensor,
    sO: cute.Tensor,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    pipeline_o_epi: cutlass.pipeline.PipelineAsync,
    sm_stats_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle SM100 forward block-sparse tiles with no active KV blocks.

    For each of the ``q_stage`` stages this helper:

    - seeds fully-masked-row stats: ``row_sum = 1`` and, when ``mLSE`` or
      ``learnable_sink`` is present, ``row_max = -inf``; a finite learnable sink
      replaces ``row_max`` (only on split 0 when ``is_split_kv``);
    - writes those stats into ``sScale`` for rows ``tidx < m_block_size`` and records
      ``(row_sum, row_max, acc_flag)`` in ``stats[stage]``, where ``acc_flag`` marks a
      zero or NaN row sum;
    - arrives and waits on ``sm_stats_barrier`` and releases ``pipeline_sm_stats`` for
      the stage;
    - runs ``correction_epilogue`` with a scale of 0 so the output tile is written as
      zeros, bracketed by ``pipeline_o_epi`` acquire/commit when
      ``gmem_tiled_copy_O`` is None.

    After the loop, ``sm_stats_consumer_phase`` and ``corr_epi_producer_phase`` are
    each toggled once; ``o_corr_consumer_phase`` is returned unchanged.

    ``gO`` is optional: when it is None, None is forwarded to ``correction_epilogue``
    instead of a per-stage slice.

    See NOTE [SM100 block-sparse empty tiles: mbarrier contract].

    Returns:
        ``(sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase)``.
    """
    LOG2_E = Float32(math.log2(math.e))
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

    for stage in cutlass.range_constexpr(q_stage):
        # Stats for a row that saw no KV block.
        row_sum_value = Float32(1.0)
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                # With pack_gqa the sink is looked up per row from the packed row index.
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            # With split-KV only split 0 contributes the sink.
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                if row_max_value == -Float32.inf:
                    # The sink is the only contribution: it becomes the row max.
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = Float32(1.0)
                else:
                    row_sum_value = row_sum_value + cute.math.exp2(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
                    )
        if tidx < m_block_size:
            # Row sums occupy the first q_stage * m_block_size entries of sScale; row
            # maxes, when tracked, follow at that offset.
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
        # True when the row sum is zero or NaN (NaN != NaN).
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)

        # One barrier index per (stage, warp); then free the stats slot for this stage.
        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)
        pipeline_sm_stats.consumer_release_w_index(stage)

        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
        # gO defaults to None; slice out this stage's tile only when one was supplied.
        gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None
        correction_epilogue(
            thr_mma_pv,
            tOtO[None, None, None, stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale: the written output does not depend on tOtO
            sO[None, None, stage],
            mO_cur,
            gO_stage,
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_commit_w_index(stage)

    # Flip the phases once per tile, after all stages were handled.
    sm_stats_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1

    return (
        sm_stats_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )
|
|
|
|
@cute.jit
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Run the SM100 softmax steps over the masked and full block lists of one tile.

    Masked blocks are processed first with ``mask_fn``, then full blocks with
    ``mask_fn_none``; each list is walked back to front. The first block of the tile is
    passed ``is_first=True`` and ``mask_seqlen=True``. The three phase values are
    threaded through every ``softmax_step`` call.

    When the tile has no blocks at all, no softmax step runs; the warp only arrives on
    ``sm_stats_barrier`` at index ``stage_idx * 4 + warp_idx``. Presumably this pairs
    with the wait in ``handle_block_sparse_empty_tile_correction_sm100``, which uses
    the same index formula — confirm against the kernel.

    ``pipeline_sm_stats`` and ``q_stage`` are accepted but not used in this body.

    Returns:
        ``(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase,
        is_empty)`` where ``is_empty`` is true when the tile had no blocks.
    """
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    # The full-block list is optional; substitute an empty list when it is absent.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

    if total_block_cnt == 0:
        # Empty tile: signal the stats barrier without doing any softmax work.
        sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
    else:
        if curr_mask_block_cnt > 0:
            # First masked block is the first block of the tile.
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            # Remaining masked blocks, back to front.
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                # No masked blocks ran, so this full block is the first of the tile.
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                # Full block following masked blocks: not first, no seqlen masking.
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
            # Remaining full blocks, back to front.
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )

    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Return how many tile iterations the backward pass runs for one KV tile.

    The result is the number of sparse Q blocks listed for ``n_block`` (partial plus
    full, when a full-block count exists) multiplied by ``subtile_factor``.
    ``m_block_max`` is accepted but not used here.
    """
    partial_cnt, _, full_cnt, _ = blocksparse_tensors
    if const_expr(full_cnt is not None):
        sparse_blocks = (
            partial_cnt[batch_idx, head_idx, n_block] + full_cnt[batch_idx, head_idx, n_block]
        )
    else:
        sparse_blocks = partial_cnt[batch_idx, head_idx, n_block]
    return sparse_blocks * subtile_factor
|
|
|
|
@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Producer pipeline states (Q shares its state with LSE, dO with dPsum)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines acquired/committed around each load
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load callables
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # LSE / dPsum tensors: the g* ones are indexed by m_block, the s* ones by the
    # producer state's stage index
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # Extra transaction byte counts added to the first acquire for K and V
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Compile-time switches for the two load groups
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Sub-tiling configuration
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Iterates over every (sparse block, subtile) pair listed for ``n_block`` and, per
    iteration, loads Q + LSE (when ``should_load_Q``) and dO + dPsum (when
    ``should_load_dO``). First iteration loads K/V alongside Q/dO: K shares the Q
    barrier and V shares the dO barrier, with their byte counts passed as
    ``extra_tx_count``. Subsequent iterations load only Q/dO.

    When ``m_block_max > 0`` the m-block index used for the loads is clamped to
    ``m_block_max - 1``.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        # The full/partial flag is not needed for loading, only the m-block index.
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        # Clamp so the loads never address an m block at or beyond m_block_max.
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)

        if iter_idx == 0:
            # First iteration: K rides on the Q barrier and V on the dO barrier.
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                # Only one elected thread issues the stats copy.
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Later iterations: same sequence without the K/V loads or extra bytes.
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum
|
|
|
|
@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Gather the per-KV-tile block lists and iteration count for the backward pass.

    Reads the partial and (optional) full Q-block counts and index lists for
    ``n_block``. When no full-block tensors exist, the full count is ``Int32(0)`` and
    the full index list is None. ``m_block_max`` is accepted but not used here.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count), where
    ``total_count`` is the combined block count times ``subtile_factor``.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    partial_count = q_cnt[batch_idx, head_idx, n_block]
    partial_list = q_idx[batch_idx, head_idx, n_block, None]

    # A single compile-time branch covers both the lookup and the sum.
    if const_expr(full_cnt is not None):
        full_count = full_cnt[batch_idx, head_idx, n_block]
        full_list = full_idx[batch_idx, head_idx, n_block, None]
        blocks_in_tile = partial_count + full_count
    else:
        full_count = Int32(0)
        full_list = None
        blocks_in_tile = partial_count

    return partial_count, partial_list, full_count, full_list, blocks_in_tile * subtile_factor
|
|
|
|
@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Translate a flat iteration index into an m_block and a full-block flag.

    ``iter_idx // subtile_factor`` selects a position in the concatenation of
    the partial list (``curr_q_idx``, ``curr_q_cnt`` entries) followed by the
    full list (``curr_full_idx``); ``iter_idx % subtile_factor`` selects the
    subtile inside that entry.

    Args:
        iter_idx: Flat iteration index.
        curr_q_cnt: Number of entries taken from ``curr_q_idx``.
        curr_q_idx: Partial-block index list.
        curr_full_cnt: Not read by this function.
        curr_full_idx: Full-block index list, or ``None`` when absent.
        subtile_factor: Compile-time number of subtiles per listed block.
        m_block_max: Not read by this function.

    Returns:
        ``(m_block, is_full_block)`` where ``m_block`` is
        ``entry * subtile_factor + subtile`` and ``is_full_block`` is set only
        when the entry was read from ``curr_full_idx``.
    """
    list_pos = iter_idx // subtile_factor
    subtile = iter_idx % subtile_factor

    # Both results are initialised up front so that they are defined before
    # the run-time branch below assigns them.
    listed_block = Int32(0)
    from_full_list = False
    if const_expr(curr_full_idx is not None):
        if list_pos >= curr_q_cnt:
            # Past the partial entries: index into the full list.
            listed_block = curr_full_idx[list_pos - curr_q_cnt]
            from_full_list = True
        else:
            listed_block = curr_q_idx[list_pos]
    else:
        listed_block = curr_q_idx[list_pos]

    m_block = listed_block * subtile_factor + subtile
    return m_block, from_full_list
|
|
|
|
@cute.jit
def _load_q_do_block_sm90(
    m_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    load_kv: bool,
):
    """Issue the Q/LSE and dO/dPsum loads for a single m_block.

    When ``load_kv`` is set, the K load is issued on the barrier of the
    acquired Q stage and the V load on the barrier of the acquired dO stage.
    In that case each acquire is given the matching byte count
    (``tma_copy_bytes_K`` / ``tma_copy_bytes_V``) as ``extra_tx_count``.

    Args:
        m_block: Block index forwarded to the four per-block loaders.
        producer_state_Q: Pipeline state used for the Q stage.
        producer_state_dO: Pipeline state used for the dO stage unless
            ``Q_stage_eq_dO_stage`` is set.
        pipeline_Q: Pipeline acquired for the Q/LSE loads.
        pipeline_dO: Pipeline acquired for the dO/dPsum loads.
        load_K: Loader called with a ``tma_bar_ptr`` keyword.
        load_V: Loader called with a ``tma_bar_ptr`` keyword.
        load_Q: Loader called as ``fn(m_block, producer_state=...)``.
        load_dO: Loader called as ``fn(m_block, producer_state=...)``.
        load_LSE: Loader called as ``fn(m_block, producer_state=...)``.
        load_dPsum: Loader called as ``fn(m_block, producer_state=...)``.
        tma_copy_bytes_K: Extra transaction bytes for the K load.
        tma_copy_bytes_V: Extra transaction bytes for the V load.
        Q_stage_eq_dO_stage: Compile-time flag; when true the dO loads reuse
            ``producer_state_Q``.
        load_kv: Whether K and V are loaded alongside this block.

    Returns:
        ``(producer_state_Q, producer_state_dO)`` after both were advanced.
    """
    # --- Q stage (optionally carrying K on the same barrier) ---
    if load_kv:
        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
    else:
        pipeline_Q.producer_acquire(producer_state_Q)
    load_Q(m_block, producer_state=producer_state_Q)
    load_LSE(m_block, producer_state=producer_state_Q)

    # --- dO stage (optionally carrying V on the same barrier) ---
    # The state object is picked at compile time.
    dO_stage_state = (
        producer_state_Q if const_expr(Q_stage_eq_dO_stage) else producer_state_dO
    )
    if load_kv:
        pipeline_dO.producer_acquire(dO_stage_state, extra_tx_count=tma_copy_bytes_V)
        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(dO_stage_state))
    else:
        pipeline_dO.producer_acquire(dO_stage_state)
    load_dO(m_block, producer_state=dO_stage_state)
    load_dPsum(m_block, producer_state=dO_stage_state)

    # Both states are advanced regardless of which one served the dO stage.
    producer_state_Q.advance()
    producer_state_dO.advance()
    return producer_state_Q, producer_state_dO
|
|
|
|
@cute.jit
def produce_block_sparse_q_loads_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
):
    """Producer side of the SM90 block-sparse backward pass for one n_block.

    Walks the partial-block list and then, if present, the full-block list
    stored at ``[batch_idx, head_idx, n_block]``. Each listed block is expanded
    into ``subtile_factor`` consecutive m_blocks, and every m_block below
    ``m_block_max`` is handed to ``_load_q_do_block_sm90``. The first m_block
    that passes the bound check is loaded with ``load_kv=True``; all later
    ones are loaded with ``load_kv=False``.

    Args:
        blocksparse_tensors: Unpacked as ``(q_cnt, q_idx, full_cnt, full_idx)``.
            The full-block tensors may be ``None``.
        batch_idx: First coordinate used to index the tensors.
        head_idx: Second coordinate used to index the tensors.
        n_block: Third coordinate used to index the tensors.
        producer_state_Q: Q pipeline state, threaded through every load.
        producer_state_dO: dO pipeline state, threaded through every load.
        subtile_factor: Compile-time number of m_blocks per listed block.
        m_block_max: Exclusive upper bound; m_blocks at or above it are skipped.

    The remaining arguments are forwarded unchanged to
    ``_load_q_do_block_sm90``.

    Returns:
        Updated ``(producer_state_Q, producer_state_dO)``.
    """
    part_cnt_t, part_idx_t, full_cnt_t, full_idx_t = blocksparse_tensors
    curr_q_cnt = part_cnt_t[batch_idx, head_idx, n_block]
    curr_q_idx = part_idx_t[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt_t is not None):
        curr_full_cnt = full_cnt_t[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx_t[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Flips to True once the first in-range block has carried K/V with it.
    kv_issued = False

    # Pass 1: partial blocks.
    for step in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        m_block = curr_q_idx[step // subtile_factor] * subtile_factor + step % subtile_factor

        if m_block < m_block_max:
            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                m_block,
                producer_state_Q,
                producer_state_dO,
                pipeline_Q,
                pipeline_dO,
                load_K,
                load_V,
                load_Q,
                load_dO,
                load_LSE,
                load_dPsum,
                tma_copy_bytes_K,
                tma_copy_bytes_V,
                Q_stage_eq_dO_stage,
                load_kv=not kv_issued,
            )
            kv_issued = True

    # Pass 2: full blocks, only when the full-block tensors exist.
    if const_expr(full_cnt_t is not None):
        for step in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            m_block = (
                curr_full_idx[step // subtile_factor] * subtile_factor + step % subtile_factor
            )

            if m_block < m_block_max:
                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                    m_block,
                    producer_state_Q,
                    producer_state_dO,
                    pipeline_Q,
                    pipeline_dO,
                    load_K,
                    load_V,
                    load_Q,
                    load_dO,
                    load_LSE,
                    load_dPsum,
                    tma_copy_bytes_K,
                    tma_copy_bytes_V,
                    Q_stage_eq_dO_stage,
                    load_kv=not kv_issued,
                )
                kv_issued = True

    return producer_state_Q, producer_state_dO
|
|
|
|
@cute.jit
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    score_mod_fn=None,
    score_mod_bwd_fn=None,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """Consumer side of the SM90 block-sparse backward pass for one n_block.

    Walks the partial-block list and then, if present, the full-block list
    stored at ``[batch_idx, head_idx, n_block]``. Each listed block is expanded
    into ``subtile_factor`` consecutive m_blocks, and ``mma_one_m_block_fn`` is
    called for every m_block below ``m_block_max``.

    Partial blocks receive a mask function bound with ``mask_mod``; full
    blocks receive the same mask function without ``mask_mod``. The first
    processed m_block is called with ``dKV_accumulate=False`` and every later
    one with ``dKV_accumulate=True``.

    Args:
        blocksparse_tensors: Unpacked as ``(q_cnt, q_idx, full_cnt, full_idx)``.
            The full-block tensors may be ``None``.
        batch_idx: Indexes the tensors and is bound into the mask function.
        head_idx: Indexes the tensors and is bound into the mask function.
        n_block: Indexes the tensors and is bound into the mask function.
        consumer_state_Q: Q pipeline state, threaded through every call.
        consumer_state_dO: dO pipeline state, threaded through every call.
        mma_one_m_block_fn: Callable invoked once per processed m_block.
        mask: Object whose ``apply_mask`` is bound into the mask functions.
        mask_mod: Bound only into the partial-block mask function.
        is_causal: Forwarded to ``apply_mask`` as ``mask_causal``.
        is_local: Forwarded to ``apply_mask`` as ``mask_local``.
        thr_mma_SdP: Forwarded to ``apply_mask`` as ``thr_mma``.
        score_mod_fn: Forwarded unchanged to ``mma_one_m_block_fn``.
        score_mod_bwd_fn: Forwarded unchanged to ``mma_one_m_block_fn``.
        subtile_factor: Compile-time number of m_blocks per listed block.
        m_block_max: Exclusive upper bound; m_blocks at or above it are skipped.
        aux_tensors: Forwarded unchanged to ``apply_mask``.
        fastdiv_mods: Forwarded unchanged to ``apply_mask``.

    Returns:
        Updated ``(consumer_state_Q, consumer_state_dO)``.
    """
    part_cnt_t, part_idx_t, full_cnt_t, full_idx_t = blocksparse_tensors
    curr_q_cnt = part_cnt_t[batch_idx, head_idx, n_block]
    curr_q_idx = part_idx_t[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt_t is not None):
        curr_full_cnt = full_cnt_t[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx_t[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # The full-block mask carries every binding except mask_mod; the
    # partial-block mask is the same callable with mask_mod layered on top.
    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )
    mask_fn_partial = partial(mask_fn_full, mask_mod=mask_mod)

    # False until the first m_block has been processed, True afterwards.
    accumulate_dKV = False

    # Pass 1: partial blocks, masked with mask_mod.
    for step in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        m_block = curr_q_idx[step // subtile_factor] * subtile_factor + step % subtile_factor

        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                score_mod_fn=score_mod_fn,
                score_mod_bwd_fn=score_mod_bwd_fn,
                dKV_accumulate=accumulate_dKV,
            )
            accumulate_dKV = True

    # Pass 2: full blocks, masked without mask_mod.
    if const_expr(full_cnt_t is not None):
        for step in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            m_block = (
                curr_full_idx[step // subtile_factor] * subtile_factor + step % subtile_factor
            )

            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    score_mod_fn=score_mod_fn,
                    score_mod_bwd_fn=score_mod_bwd_fn,
                    dKV_accumulate=accumulate_dKV,
                )
                accumulate_dKV = True

    return consumer_state_Q, consumer_state_dO
|
|
|
|
@cute.jit
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_dQ_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Reduce-add every warp group's dQaccum slice into ``gdQaccum`` for one m_block.

    Two compile-time-unrolled passes run over the warp groups:

    1. Wait until at most ``num_dQ_warp_groups - 1 - i`` bulk groups are still
       outstanding (``read=True``), then arrive on warp group ``i``'s
       ``dQEmptyWG0 + i`` named barrier.
    2. Wait on warp group ``i``'s ``dQFullWG0 + i`` named barrier, have one
       elected thread issue the bulk f32 reduce-add from ``sdQaccum[None, i]``
       to ``gdQaccum[(None, i), m_block]``, then commit the bulk group.

    Both named barriers are used with
    ``num_threads_per_warp_group + cute.arch.WARP_SIZE`` participants.

    Args:
        m_block: Block coordinate selecting the destination in ``gdQaccum``.
        sdQaccum: Source tensor, indexed per warp group.
        gdQaccum: Destination tensor, indexed per warp group and m_block.
        num_dQ_warp_groups: Compile-time number of warp groups to iterate.
        num_threads_per_warp_group: Compile-time thread count per warp group.
        tma_copy_bytes_dQ: Byte count passed to each bulk reduce-add.
    """
    # Loop-invariant barrier parameters.
    barrier_threads = num_threads_per_warp_group + cute.arch.WARP_SIZE
    empty_barrier_base = int(NamedBarrierBwd.dQEmptyWG0)
    full_barrier_base = int(NamedBarrierBwd.dQFullWG0)

    for wg in cutlass.range_constexpr(num_dQ_warp_groups):
        cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - wg, read=True)
        cute.arch.barrier_arrive(
            barrier_id=empty_barrier_base + wg,
            number_of_threads=barrier_threads,
        )

    for wg in cutlass.range_constexpr(num_dQ_warp_groups):
        cute.arch.barrier(
            barrier_id=full_barrier_base + wg,
            number_of_threads=barrier_threads,
        )
        with cute.arch.elect_one():
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, wg].iterator,
                gdQaccum[(None, wg), m_block].iterator,
                tma_copy_bytes_dQ,
            )
        cute.arch.cp_async_bulk_commit_group()
|
|
|
|
@cute.jit
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_dQ_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """dQaccum store side of the SM90 block-sparse backward pass for one n_block.

    Walks the partial-block list and then, if present, the full-block list
    stored at ``[batch_idx, head_idx, n_block]``. Each listed block is expanded
    into ``subtile_factor`` consecutive m_blocks, and
    ``_store_one_dQaccum_sm90`` is called for every m_block below
    ``m_block_max``.

    Args:
        blocksparse_tensors: Unpacked as ``(q_cnt, q_idx, full_cnt, full_idx)``.
            The full-block tensors may be ``None``.
        batch_idx: First coordinate used to index the tensors.
        head_idx: Second coordinate used to index the tensors.
        n_block: Third coordinate used to index the tensors.
        sdQaccum: Forwarded unchanged to ``_store_one_dQaccum_sm90``.
        gdQaccum: Forwarded unchanged to ``_store_one_dQaccum_sm90``.
        subtile_factor: Compile-time number of m_blocks per listed block.
        m_block_max: Exclusive upper bound; m_blocks at or above it are skipped.
        num_dQ_warp_groups: Forwarded unchanged to ``_store_one_dQaccum_sm90``.
        num_threads_per_warp_group: Forwarded unchanged to
            ``_store_one_dQaccum_sm90``.
        tma_copy_bytes_dQ: Forwarded unchanged to ``_store_one_dQaccum_sm90``.
    """
    part_cnt_t, part_idx_t, full_cnt_t, full_idx_t = blocksparse_tensors
    curr_q_cnt = part_cnt_t[batch_idx, head_idx, n_block]
    curr_q_idx = part_idx_t[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt_t is not None):
        curr_full_cnt = full_cnt_t[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx_t[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    # Pass 1: partial blocks.
    for step in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        m_block = curr_q_idx[step // subtile_factor] * subtile_factor + step % subtile_factor

        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_dQ_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )

    # Pass 2: full blocks, only when the full-block tensors exist.
    if const_expr(full_cnt_t is not None):
        for step in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            m_block = (
                curr_full_idx[step // subtile_factor] * subtile_factor + step % subtile_factor
            )

            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_dQ_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )
|
|