Instructions to use kernels-community/flash-attn4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn4 with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/flash-attn4")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Block-sparse runtime utilities for CUTE DSL kernels. | |
| This module contains runtime execution functions for block-sparse attention kernels. | |
| These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. | |
| """ | |
| from typing import Callable, Optional | |
| from functools import partial | |
| import math | |
| import cutlass | |
| import cutlass.cute as cute | |
| from cutlass import Float32, Int32, const_expr | |
| from .quack import copy_utils | |
| # Import data structures from block_sparsity | |
| from .block_sparsity import BlockSparseTensors | |
| from .named_barrier import NamedBarrierBwd | |
| # NOTE [SM100 block-sparse empty tiles: mbarrier contract] | |
| # | |
| # For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active | |
| # KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so | |
| # the softmax warp-group has no row stats to publish. | |
| # | |
| # The correction warp-group seeds fully-masked-row stats and runs the usual correction | |
| # epilogue so output/LSE have well-defined values. Both warp-groups must still perform | |
| # the softmax<->correction mbarrier handshake so phases advance correctly across | |
| # empty->empty and empty->non-empty tile sequences. | |
| # | |
| # In the no-sink case, this corresponds to the usual fully-masked-row convention: | |
| # output is zero and LSE is -inf. | |
| # | |
| # Barrier contract (each is `mbar_ptr + <offset> + stage`): | |
| # | |
| # Producer/consumer pairs: | |
| # - `mbar_softmax_corr_full` : softmax arrive -> correction wait | |
| # - `mbar_softmax_corr_empty` : correction arrive -> softmax wait | |
| # - `mbar_P_full_O_rescaled` : softmax arrive (+ correction arrive) -> MMA wait | |
| # - `mbar_P_full_2` : softmax arrive -> MMA wait | |
| # - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate) | |
| # | |
| # Empty tile (`total_block_cnt == 0`): | |
| # - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`). | |
| # It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal. | |
| # At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty` | |
| # before each tile (when block-sparse) to drain a prior correction arrival and keep | |
| # phases aligned across non-empty -> empty transitions. | |
| # - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`, | |
| # and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable). | |
| # - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction | |
| # (and correction<->epilogue) handshakes advance phases. | |
| # | |
| # Non-empty tile: | |
| # - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to | |
| # publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases; | |
| # arrives `mbar_P_full_*` when P is stored. | |
| # - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty` | |
| # to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed. | |
| # | |
| # Backward (SM100): | |
| # - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute. | |
| # - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles | |
| # skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward). | |
| # - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros | |
| # even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`). | |
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Issue K/V (and optionally Q) loads for one sparse KV block list.

    Blocks are visited in reverse list order: ``block_indices[block_count - 1]``
    first and ``block_indices[0]`` last.

    Without ``intra_wg_overlap``, each block's K and V are loaded on the same
    producer stage, and the stage is advanced once per block. The list is fully
    drained on return.

    With ``intra_wg_overlap``, the K load of the next block is issued on a new
    stage before the V load of the previous block. As a result, the V load of
    the last visited block (``block_indices[0]``) is still pending on the
    returned state. The caller must issue it, e.g. with
    ``finish_overlap_v_load``, or overlap it with the first K load of another
    list.

    Args:
        block_indices: KV block indices of this list; entries
            ``[0, block_count)`` are read.
        block_count: Number of valid entries. No loads are issued when it is 0.
        load_q_with_first: Together with ``use_tma_q``, load Q on the barrier
            of the first K stage. ``tma_q_bytes`` is then passed as
            ``extra_tx_count`` to that stage's ``producer_acquire``.
        first_block_preloaded: Overlap path only. The caller has already
            acquired the current stage and issued the first block's K load, so
            that acquire/K load (and any Q load) is skipped here.
        kv_producer_state: K/V producer pipeline state. It is advanced in place
            via ``advance()`` and also returned.
        load_Q, load_K, load_V: Load callbacks for Q, K and V.
        pipeline_k, pipeline_v: Producer pipelines for the K and V buffers.
        use_tma_q: If false, this function never calls ``load_Q``.
        tma_q_bytes: Extra transaction byte count charged for the Q load.
        intra_wg_overlap: Selects the overlapped K/V schedule described above.

    Returns:
        Updated kv_producer_state after processing the block list.
    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            # Peel the first iteration: the first block may also load Q, whose
            # byte count is added to the first K stage via extra_tx_count.
            n_block_first = block_indices[block_count - 1]
            extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
            pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
            if const_expr(load_q_with_first and use_tma_q):
                # Q is loaded onto the same barrier as this stage's K.
                load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
            load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            pipeline_v.producer_acquire(kv_producer_state)
            load_V(src_idx=n_block_first, producer_state=kv_producer_state)
            kv_producer_state.advance()
            # Remaining blocks: K and V of one block share a stage, then advance.
            for offset in cutlass.range(1, block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                extra_tx = (
                    tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0
                )
                pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx)
                if const_expr(load_q_with_first and use_tma_q):
                    load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state))
                load_K(src_idx=n_block_first, producer_state=kv_producer_state)
            # Each iteration first issues K of the next block on a new stage,
            # then V of the previous block on the previous stage. The V of the
            # last block (index 0) is intentionally left for the caller.
            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)
    return kv_producer_state
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Drain the V load left pending by an overlapped ``load_block_list`` pass.

    Block lists are walked back to front, so the outstanding V belongs to
    ``block_indices[0]``. It is loaded on the current producer stage, and the
    stage is then advanced. An empty list (``block_count == 0``) issues nothing.

    Returns:
        The kv_producer_state, advanced by one stage when a V load was issued.
    """
    if block_count > 0:
        pipeline_v.producer_acquire(kv_producer_state)
        # Entry 0 is the block visited last in reverse order.
        load_V(src_idx=block_indices[0], producer_state=kv_producer_state)
        kv_producer_state.advance()
    return kv_producer_state
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
    q_subtile_factor: cutlass.Constexpr[int],
):
    """Convert a kernel-space m_block into the row index of the block-sparse tensors.

    The kernel's m_block may be in Pack-GQA space (``qhead_per_kvhead``) and/or
    sub-tile space (``q_subtile_factor``). Both factors are divided out. Both
    factors are compile-time constants, so they are folded into one divisor.
    For a non-negative m_block and positive factors, one integer division by
    the product equals dividing by each factor in turn.
    """
    combined_factor = qhead_per_kvhead * q_subtile_factor
    if const_expr(combined_factor == 1):
        return m_block
    return m_block // combined_factor
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    use_tma_q: cutlass.Constexpr,
    tma_q_bytes: cutlass.Constexpr,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Issue all Q/K/V loads for one Q tile: masked block list first, then full block list.

    Q is loaded together with the first K of whichever list is visited first.
    That is the masked list if it is non-empty, otherwise the full list. If both
    lists are empty, no loads are issued at all, including Q. NOTE(review):
    presumably the caller handles such empty tiles; confirm.

    With intra-warp-group overlap, ``load_block_list`` leaves the V load of the
    last block of a list pending. If a full list follows a non-empty masked
    list, that pending masked V is overlapped with the first full K load, which
    is issued on the next stage. Otherwise the pending V is drained explicitly
    with ``finish_overlap_v_load``. The full list's own trailing V is drained
    the same way.

    Args:
        blocksparse_tensors: ``(mask_block_cnt, mask_block_idx, full_block_cnt,
            full_block_idx)``. When ``full_block_cnt`` is None, the full list
            is treated as empty.
        batch_idx, head_idx: Leading indices into the block-sparse tensors.
        m_block: Q tile index in kernel space; converted with
            ``sparse_tensor_m_block`` before indexing the sparse tensors.
        kv_producer_state: K/V producer pipeline state; threaded through and returned.
        load_Q, load_K, load_V, pipeline_k, pipeline_v, use_tma_q, tma_q_bytes,
        intra_wg_overlap: Forwarded to ``load_block_list``.
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: Additional divisor applied to m_block for sparse indexing.

    Returns:
        The updated kv_producer_state.
    """
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        # No full-block list provided: behave as if it were empty.
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0
    if mask_empty:
        # No masked blocks: the full list owns the initial Q+K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )
        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            # Drain the V of the last full block left pending by the overlap schedule.
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present: load Q together with the first masked K so consumers can
        # start immediately. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            use_tma_q=use_tma_q,
            tma_q_bytes=tma_q_bytes,
            intra_wg_overlap=intra_wg_overlap,
        )
        if full_empty:
            if const_expr(intra_wg_overlap):
                # Nothing to overlap with: drain the last masked V explicitly.
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load. Index 0 is the last masked block visited;
                # index cnt-1 is the first full block visited.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)
                # The first full K is already issued, so tell load_block_list to skip it.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )
                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally (skipping the Q
                # reload because the masked list already issued it).
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_q_with_first=False,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_Q=load_Q,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    use_tma_q=use_tma_q,
                    tma_q_bytes=tma_q_bytes,
                    intra_wg_overlap=intra_wg_overlap,
                )
    return kv_producer_state
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Consume the masked and full KV block lists for one Q tile on the consumer side.

    Mirrors ``produce_block_sparse_loads``, so the consumer pipeline uses the
    same sparse tensor indexing and block order. Both lists are walked in
    reverse (``idx[cnt - 1]`` first), masked blocks before full blocks.

    Masking policy:
      * Masked blocks apply ``mask_mod``. When ``mask_mod`` is not None they
        also receive ``fastdiv_mods``.
      * Full blocks are run with ``mask_mod=None``.
      * The first block of each list uses ``mask_seqlen=True``; the remaining
        blocks use ``mask_seqlen=False``.

    Without ``intra_wg_overlap``, every block is one ``mma_one_n_block`` call.
    There is a single ``warp_scheduler_barrier_sync`` / ``arrive`` pair per
    tile, and none when both lists are empty. With ``intra_wg_overlap``, the
    first block goes through ``process_first_half_block``, and the final P@V
    step is issued by ``process_last_half_block``.

    Args:
        blocksparse_tensors: ``(mask_block_cnt, mask_block_idx, full_block_cnt,
            full_block_idx)``. When ``full_block_cnt`` is None, the full list is
            treated as empty, matching the producer side.
        O_should_accumulate: Whether the first P@V should accumulate into O
            (``zero_init`` is its negation). Set to True after any P@V step.
        mask_mod, fastdiv_mods: Forwarded to ``mask_fn`` for masked blocks only.
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
        q_subtile_factor: Additional divisor applied to m_block for sparse indexing.
        The remaining arguments are callbacks and pipeline state, forwarded as-is.

    Returns:
        ``(kv_consumer_state, O_should_accumulate, processed_any)``.
        ``processed_any`` is true iff at least one block was consumed.
    """
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    # The full-block list is optional. Treat a missing list as empty, exactly like
    # the producer side does, instead of indexing None.
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    # Every masked block evaluates mask_mod, so every masked block gets the same
    # fastdiv helpers (previously only the first masked block received them).
    mask_fastdiv_mods = fastdiv_mods if const_expr(mask_mod is not None) else None
    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0
    if const_expr(not intra_wg_overlap):
        # One sync/arrive pair per tile: sync before the first block of whichever
        # list is non-empty, arrive after the last block overall.
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=mask_fastdiv_mods,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(
                        mask_fn,
                        mask_mod=mask_mod,
                        mask_seqlen=False,
                        fastdiv_mods=mask_fastdiv_mods,
                    ),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                warp_scheduler_barrier_sync()
                # Full blocks never need mask_mod (same as every other full-block path).
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                # Lists are walked separately, so the first full block may still sit on
                # the seqlen_k boundary: keep mask_seqlen=True for it.
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            # First block overall: only the first half (no P@V yet) is processed here.
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=mask_fastdiv_mods,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(
                        mask_fn,
                        mask_mod=mask_mod,
                        mask_seqlen=False,
                        fastdiv_mods=mask_fastdiv_mods,
                    ),
                )
                O_should_accumulate = True
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True
        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            # Issue the trailing P@V of the last block processed above.
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True
    return kv_consumer_state, O_should_accumulate, processed_any
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).

    Blocks are visited in reverse list order (``block_indices[block_count - 1]``
    first). For each block, K and then V are loaded, each on its own producer
    stage, so the state advances twice per block.

    Args:
        block_indices: KV block indices of this list; entries
            ``[0, block_count)`` are read.
        block_count: Number of valid entries. Nothing is loaded (not even Q)
            when it is 0.
        load_q_with_first: If set, call ``load_Q(block=0, stage=0)`` before the
            first K. When ``q_stage == 2``, also call ``load_Q(block=1, stage=1)``.
        q_stage: Number of Q stages loaded with the first block.
        kv_producer_state: Producer state; advanced in place and returned.
        load_Q, load_K, load_V: Load callbacks. K/V are called with
            ``page_idx=None``.
        pipeline_kv: Not referenced in this function. Per the inline comment,
            pipeline barriers are handled inside the load callbacks.

    Returns:
        Updated kv_producer_state.
    """
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]
        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=1, stage=1)
        # SM100 doesn't use producer_acquire for pipeline_kv in load path
        # The pipeline barriers are handled inside load_KV
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        # Remaining blocks: K on one stage, V on the next.
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
    return kv_producer_state
| # SM100-specific tile processor using SM100 helpers | |
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """SM100 entry point: issue Q/K/V loads for the masked, then full, block lists of one Q tile.

    Unlike ``produce_block_sparse_loads``, there is no intra-WG overlap and no
    ``extra_tx_count``. Both lists go through ``load_block_list_sm100``, which
    calls ``load_K`` / ``load_V`` directly and never calls ``producer_acquire``.
    Its comments state that pipeline barriers are handled inside the load
    callbacks. NOTE(review): the original note attributed this to SM100 using
    PipelineTmaUmma without ``extra_tx_count`` support; this is not verifiable
    from this file.

    Q is loaded with the first K of the first non-empty list: the masked list if
    non-empty, otherwise the full list. ``q_producer_phase`` is flipped exactly
    when Q was loaded, i.e. when at least one list is non-empty. For an empty
    tile, nothing is loaded and the phase is returned unchanged.

    Args:
        blocksparse_tensors: ``(mask_block_cnt, mask_block_idx, full_block_cnt,
            full_block_idx)``. When ``full_block_cnt`` is None, the full list
            is treated as empty.
        batch_idx, head_idx: Leading indices into the block-sparse tensors.
        m_block: Which tile of m we are processing (kernel space).
        kv_producer_state: Producer state; threaded through and returned.
        load_Q, load_K, load_V, pipeline_kv, q_stage: Forwarded to
            ``load_block_list_sm100``.
        q_producer_phase: Current Q producer phase bit.
        qhead_per_kvhead: Constexpr Pack-GQA factor used to map m_block to sparse indices.
        q_subtile_factor: Constexpr sub-tile factor used to map m_block to sparse indices.

    Returns:
        ``(kv_producer_state, q_producer_phase)``.
    """
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0
    # Tracks whether Q was loaded for this tile (and so the Q phase must flip).
    q_phase_flipped = False
    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        # Q is only loaded when the full list is non-empty.
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True
        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )
    if q_phase_flipped:
        q_producer_phase ^= 1
    return kv_producer_state, q_producer_phase
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """Return the number of active KV blocks (masked + full) for one Q tile.

    When the block-sparse tensors carry no full-block list (``full_block_cnt``
    is None), only the masked count is returned.
    """
    mask_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
    sparse_row = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    total = mask_block_cnt[batch_idx, head_idx, sparse_row]
    if const_expr(full_block_cnt is not None):
        total = total + full_block_cnt[batch_idx, head_idx, sparse_row]
    return total
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtO: cute.Tensor,
    sO: cute.Tensor,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    pipeline_o_epi: cutlass.pipeline.PipelineAsync,
    sm_stats_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle SM100 forward block-sparse tiles with no active KV blocks (correction side).

    This path is taken when ``total_block_cnt == 0``. For each Q stage, the
    correction warp-group does the following:
      - Seed fully-masked-row stats: ``row_sum = 1`` and, when LSE or a
        learnable sink is tracked, ``row_max = -inf``. With a learnable sink
        (and not split-KV, or split 0), ``row_max`` is instead seeded from the
        sink value. The stats are written to ``sScale`` (rows
        ``tidx < m_block_size``) and ``stats[stage]``.
      - Synchronize with the softmax warp-group's synthetic "no work" arrival
        via ``sm_stats_barrier.arrive_and_wait_w_index(stage * 4 + warp_idx)``,
        then ``pipeline_sm_stats.consumer_release_w_index(stage)``.
      - Run ``correction_epilogue`` with scale 0 so the staged output is written
        as zeros. When there is no gmem tiled copy (separate epilogue), this is
        bracketed by an ``pipeline_o_epi`` acquire/commit.
    After all stages, ``sm_stats_consumer_phase`` and ``corr_epi_producer_phase``
    are flipped once. ``o_corr_consumer_phase`` is returned unchanged, because
    no O is produced by the MMA for an empty tile.

    This helper intentionally does not touch ``mbar_P_full_*`` since no P is produced.
    See NOTE [SM100 block-sparse empty tiles: mbarrier contract].

    ``gO`` may be None (its default), in which case None is forwarded to
    ``correction_epilogue`` instead of indexing it, mirroring how ``mO_cur`` is
    forwarded.

    Returns:
        ``(sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase)``.
    """
    LOG2_E = Float32(math.log2(math.e))
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    for stage in cutlass.range_constexpr(q_stage):
        row_sum_value = Float32(1.0)
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                # Pack-GQA: recover the query head of this packed row.
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                # row_max is stored in units where row_max * softmax_scale_log2 is a
                # log2 exponent; sink_val is a natural-log logit.
                if row_max_value == -Float32.inf:
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = Float32(1.0)
                else:
                    row_sum_value = row_sum_value + cute.math.exp2(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
                    )
        if tidx < m_block_size:
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
        # Flag rows whose sum is zero or NaN.
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)
        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)
        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)
        pipeline_sm_stats.consumer_release_w_index(stage)
        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
        correction_epilogue(
            thr_mma_pv,
            tOtO[None, None, None, stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs
            sO[None, None, stage],
            mO_cur,
            # gO is Optional: forward None instead of indexing it.
            gO[None, None, stage] if const_expr(gO is not None) else None,
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_commit_w_index(stage)
    # Each stage has its own barrier index, so phases flip once per tile.
    sm_stats_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1
    return (
        sm_stats_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Run the SM100 softmax steps over the masked and full KV block lists of one Q tile stage.

    Lists are walked in reverse (``idx[cnt - 1]`` first), masked blocks before
    full blocks. Masked blocks use ``mask_fn`` and full blocks use
    ``mask_fn_none``. Only the very first block overall is run with
    ``is_first=True``.

    The first block of each list is masked with ``mask_seqlen=True``. Because
    the lists are walked separately, the first full block can hold the
    seqlen_k boundary even when masked blocks precede it. This matches
    ``consume_block_sparse_loads``. Seqlen masking only affects columns past
    seqlen_k, so it is a no-op for in-bounds blocks.

    If the tile has no active blocks, ``softmax_step`` is never called.
    Instead, this warp arrives on ``sm_stats_barrier`` at index
    ``stage_idx * 4 + warp_idx``, the synthetic "no work" signal for the
    correction warp-group. See NOTE [SM100 block-sparse empty tiles: mbarrier
    contract] at the top of this module.

    Args:
        blocksparse_tensors: ``(mask_block_cnt, mask_block_idx, full_block_cnt,
            full_block_idx)``. When ``full_block_cnt`` is None, the full list
            is treated as empty.
        softmax_step: Called as ``softmax_step(phases..., n_block, [is_first=],
            mask_fn=)``; returns the three updated phases.
        check_m_boundary: Forwarded to the mask functions as ``check_q_boundary``.
        pipeline_sm_stats, q_stage: Only referenced by a commented-out line.
            NOTE(review): kept for signature compatibility.
        qhead_per_kvhead, q_subtile_factor: Used to map m_block to sparse indices.

    Returns:
        ``(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase,
        is_empty_tile)``.
    """
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None
    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt
    if total_block_cnt == 0:
        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.producer_commit_w_index(stage_idx)
        sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )
        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                # Not the first block overall, but still the first (highest-index)
                # full block, which may straddle seqlen_k: apply the seqlen mask.
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )
| # ============================================================================= | |
| # Backward-specific block-sparse helpers (SM100) | |
| # ============================================================================= | |
| # | |
| # In backward, iteration is transposed compared to forward: | |
| # - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles) | |
| # - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles) | |
| # | |
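# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of partially-masked m_blocks (mask_mod applied)
#   to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of those m_blocks
# - full_block_cnt / full_block_idx (optional; may be None) → same layout for fully-unmasked
#   m_blocks, which skip mask_mod
#
# Partial blocks are visited before full blocks. Counts and indices are in sparse-block units;
# with subtiling, sparse block b expands to m_blocks b * subtile_factor + [0, subtile_factor).
#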
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Return how many Q-tile iterations the backward pass runs for one KV tile.

    The result is ``(partial blocks + full blocks) * subtile_factor`` for the
    ``(batch_idx, head_idx, n_block)`` entry. Each sparse block expands to
    ``subtile_factor`` m_blocks. Full blocks are only counted when the full-block
    tensors are present (not None).

    ``m_block_max`` is accepted for signature parity with the other backward
    helpers and is not used: the count is not clamped to the valid m_block range.
    """
    partial_counts, _, full_counts, _ = blocksparse_tensors
    coord = (batch_idx, head_idx, n_block)
    sparse_blocks = partial_counts[coord]
    if const_expr(full_counts is not None):
        # Full (unmasked) blocks live in their own tensor; add them when present.
        sparse_blocks = sparse_blocks + full_counts[coord]
    return sparse_blocks * subtile_factor
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Producer side of the backward pass for one KV tile ``n_block``. It walks every
    (sparse block, subtile) iteration reported by ``get_block_sparse_iteration_info_bwd``
    (partial blocks first, then full blocks). Each iteration fills one stage of each
    enabled pipeline:

      - ``should_load_Q``: the Q tile into ``pipeline_Q``, then its LSE row into
        ``pipeline_LSE``.
      - ``should_load_dO``: the dO tile into ``pipeline_dO``, then its dPsum row into
        ``pipeline_dPsum``.

    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
    K is issued on the Q stage's barrier and V on the dO stage's barrier. Their TMA
    byte counts are added to those stages through ``extra_tx_count``.

    m_blocks that are ``>= m_block_max`` are clamped to ``m_block_max - 1`` rather
    than skipped. As a result, exactly one stage per enabled pipeline is produced
    for each of the ``loop_count`` iterations. ``m_block_max <= 0`` (the default)
    disables the clamp. When ``loop_count`` is 0, nothing is loaded, not even K/V.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )
    for iter_idx in cutlass.range(loop_count, unroll=1):
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        # Subtiles of a sparse block can land at or beyond m_block_max. Clamp them
        # instead of skipping, so that every iteration still produces a stage.
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)
        if iter_idx == 0:
            # First block: load K/V alongside Q/dO
            if const_expr(should_load_Q):
                # Reserve the Q stage with K's bytes added, so K lands on the same barrier.
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                # The LSE row uses the same stage index as Q, on its own pipeline.
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                # A single elected thread issues the stats copy and signals pipeline_LSE's barrier.
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                # One shared state drives both pipeline_Q and pipeline_LSE.
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                # Reserve the dO stage with V's bytes added, so V lands on the same barrier.
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                # One shared state drives both pipeline_dO and pipeline_dPsum.
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
    return producer_state_Q_LSE, producer_state_dO_dPsum
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Gather the per-KV-tile sparse metadata that the backward loops iterate over.

    Returns ``(curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count)``:
      - ``curr_q_cnt`` / ``curr_q_idx``: count and index row of the partially-masked
        m_blocks for ``(batch_idx, head_idx, n_block)``.
      - ``curr_full_cnt`` / ``curr_full_idx``: the same for full (no mask_mod) blocks.
        These are ``Int32(0)`` / ``None`` when the block-sparse tensors carry no
        full-block data.
      - ``total_count``: ``(curr_q_cnt + curr_full_cnt) * subtile_factor``, the number
        of subtile iterations.

    ``m_block_max`` is unused here (no clamping); it is kept for signature parity.
    """
    partial_cnt_t, partial_idx_t, full_cnt_t, full_idx_t = blocksparse_tensors
    coord = (batch_idx, head_idx, n_block)
    row = (batch_idx, head_idx, n_block, None)
    partial_cnt = partial_cnt_t[coord]
    partial_idx = partial_idx_t[row]
    if const_expr(full_cnt_t is None):
        # No full-block tensors: expose an empty full-block set.
        full_cnt = Int32(0)
        full_idx = None
        sparse_total = partial_cnt
    else:
        full_cnt = full_cnt_t[coord]
        full_idx = full_idx_t[row]
        sparse_total = partial_cnt + full_cnt
    return partial_cnt, partial_idx, full_cnt, full_idx, sparse_total * subtile_factor
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Map a flat backward iteration index to a concrete Q-tile (m_block).

    Iterations enumerate all partial blocks first, then all full blocks. Each sparse
    block expands to ``subtile_factor`` consecutive m_blocks.

    Returns (m_block, is_full_block):
      - m_block: ``sparse_block * subtile_factor + (iter_idx % subtile_factor)``
      - is_full_block: True when the iteration falls in the full-block range, meaning
        no mask_mod is needed.

    ``iter_idx`` is expected to be below
    ``(curr_q_cnt + curr_full_cnt) * subtile_factor``; this is not checked here.
    ``curr_full_cnt`` and ``m_block_max`` are not read.
    """
    block_slot = iter_idx // subtile_factor
    sub_offset = iter_idx % subtile_factor
    # Pre-declare both results so they are defined on every path of the dynamic branch below.
    sparse_m_block = Int32(0)
    is_full_block = False
    if const_expr(curr_full_idx is None):
        sparse_m_block = curr_q_idx[block_slot]
    else:
        if block_slot >= curr_q_cnt:
            # Past the partial blocks: index into the full-block list.
            sparse_m_block = curr_full_idx[block_slot - curr_q_cnt]
            is_full_block = True
        else:
            sparse_m_block = curr_q_idx[block_slot]
    return sparse_m_block * subtile_factor + sub_offset, is_full_block
def _load_q_do_block_sm90(
    m_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    load_kv: bool,
):
    """Load one Q/dO block, optionally loading K/V on first iteration.

    Acquires one ``pipeline_Q`` stage and issues the Q and LSE loads for ``m_block``
    into it. It then acquires one ``pipeline_dO`` stage and issues the dO and dPsum
    loads. When ``load_kv`` is true, K's and V's TMA byte counts are added to the
    Q and dO acquisitions (``extra_tx_count``). K and V are loaded onto those same
    stage barriers.

    When ``Q_stage_eq_dO_stage`` is set, the dO loads use ``producer_state_Q``'s
    stage. ``producer_state_dO`` is still advanced, so both states move in step.

    No ``producer_commit`` is issued here, unlike the SM100 path. Presumably the
    TMA transaction bytes complete each stage's barrier. TODO: confirm against the
    pipeline type used by callers.

    Returns the advanced ``(producer_state_Q, producer_state_dO)``.
    """
    if load_kv:
        # Fold K's bytes into this Q stage's expected transaction count.
        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
    else:
        pipeline_Q.producer_acquire(producer_state_Q)
    load_Q(m_block, producer_state=producer_state_Q)
    load_LSE(m_block, producer_state=producer_state_Q)
    # When Q and dO share stage bookkeeping, dO is written into the Q state's stage.
    producer_state_dO_cur = (
        producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
    )
    if load_kv:
        # Fold V's bytes into this dO stage's expected transaction count.
        pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
    else:
        pipeline_dO.producer_acquire(producer_state_dO_cur)
    load_dO(m_block, producer_state=producer_state_dO_cur)
    load_dPsum(m_block, producer_state=producer_state_dO_cur)
    # Both states advance exactly once per call, even when they share a stage.
    producer_state_Q.advance()
    producer_state_dO.advance()
    return producer_state_Q, producer_state_dO
def produce_block_sparse_q_loads_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
):
    """SM90 backward block sparse loading with separate partial/full loops.

    K/V are loaded with the first valid block. Iterates partial blocks first,
    then full blocks, matching consumer order.

    Each sparse block expands to ``subtile_factor`` m_blocks. Only m_blocks with
    ``m_block < m_block_max`` are loaded; out-of-range subtiles are skipped, whereas
    the SM100 variant clamps them. If no m_block passes that check, K/V are not
    loaded by this function. The same skip filter and order are used by
    ``consume_block_sparse_mma_bwd_sm90`` and ``dQaccum_store_block_sparse_bwd_sm90``.

    Returns updated (producer_state_Q, producer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None
    # Becomes True after the first loaded block; K/V ride along only with that block.
    kv_loaded = False
    # Partial (mask_mod) blocks first.
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
        if m_block < m_block_max:
            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                m_block,
                producer_state_Q,
                producer_state_dO,
                pipeline_Q,
                pipeline_dO,
                load_K,
                load_V,
                load_Q,
                load_dO,
                load_LSE,
                load_dPsum,
                tma_copy_bytes_K,
                tma_copy_bytes_V,
                Q_stage_eq_dO_stage,
                load_kv=not kv_loaded,
            )
            kv_loaded = True
    # Then full blocks, only when full-block tensors exist (compile-time check).
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
            if m_block < m_block_max:
                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                    m_block,
                    producer_state_Q,
                    producer_state_dO,
                    pipeline_Q,
                    pipeline_dO,
                    load_K,
                    load_V,
                    load_Q,
                    load_dO,
                    load_LSE,
                    load_dPsum,
                    tma_copy_bytes_K,
                    tma_copy_bytes_V,
                    Q_stage_eq_dO_stage,
                    load_kv=not kv_loaded,
                )
                kv_loaded = True
    return producer_state_Q, producer_state_dO
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    score_mod_fn=None,
    score_mod_bwd_fn=None,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """SM90 backward block sparse MMA consumption with separate partial/full loops.

    Partial blocks are processed first (with mask_mod applied), then full blocks
    (without mask_mod). This ensures mask_mod is only applied where needed.

    For each m_block with ``m_block < m_block_max``, this calls ``mma_one_m_block_fn``,
    which returns the advanced consumer states (presumably after consuming one Q
    stage and one dO stage). The block order and the ``m_block_max`` filter mirror
    ``produce_block_sparse_q_loads_bwd_sm90``, so consumed stages line up with
    produced ones.

    ``dKV_accumulate`` is False for the first processed m_block and True afterwards.
    This presumably selects overwrite vs. accumulate for dK/dV; confirm in
    ``mma_one_m_block_fn``.

    Note: with the default ``m_block_max=0``, every block fails the range check and
    nothing is consumed, so callers must pass the real bound.

    Returns updated (consumer_state_Q, consumer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None
    # Flips to True once the first in-range m_block has been processed.
    dKV_accumulate = False
    # Masking for partial blocks: seqlen/causal/local masking plus the user mask_mod.
    mask_fn_partial = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        mask_mod=mask_mod,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )
    # Same as mask_fn_partial but without mask_mod: full blocks need no mask_mod.
    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                score_mod_fn=score_mod_fn,
                score_mod_bwd_fn=score_mod_bwd_fn,
                dKV_accumulate=dKV_accumulate,
            )
            dKV_accumulate = True
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    score_mod_fn=score_mod_fn,
                    score_mod_bwd_fn=score_mod_bwd_fn,
                    dKV_accumulate=dKV_accumulate,
                )
                dKV_accumulate = True
    return consumer_state_Q, consumer_state_dO
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Store dQaccum for a single m_block.

    Phase 1: for each MMA warp group ``wg``, wait until the bulk reduce that
    previously read ``sdQaccum[:, wg]`` has finished reading its source. Then arrive
    on that warp group's ``dQEmpty`` named barrier.

    Phase 2: for each ``wg``, wait on its ``dQFull`` named barrier. One elected
    thread then issues a bulk fp32 reduce-add of ``sdQaccum[:, wg]`` into
    ``gdQaccum[:, wg, m_block]``, and a bulk async group is committed.

    The named barriers are sized for ``num_threads_per_warp_group + WARP_SIZE``
    threads. This presumably means one MMA warp group plus the single warp running
    this function; confirm at the call site.
    """
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        # One bulk group is committed per warp group, in warp-group order (phase 2).
        # Allowing at most (N - 1 - wg) pending groups therefore means the group for
        # this wg has finished reading its source buffer.
        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        # A single thread issues the bulk reduce-add for this warp group's slice.
        with cute.arch.elect_one():
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, warp_group_idx].iterator,
                gdQaccum[None, warp_group_idx, m_block].iterator,
                tma_copy_bytes_dQ,
            )
        cute.arch.cp_async_bulk_commit_group()
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """SM90 backward block sparse dQaccum store with separate partial/full loops.

    Iterates partial blocks first, then full blocks, matching producer/consumer order.

    Each sparse block expands to ``subtile_factor`` m_blocks. Only m_blocks with
    ``m_block < m_block_max`` are stored, which is the same filter used by
    ``produce_block_sparse_q_loads_bwd_sm90`` and ``consume_block_sparse_mma_bwd_sm90``.
    Each stored m_block is handed to ``_store_one_dQaccum_sm90``.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]
    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None
    # Partial blocks first (same order as the MMA consumer).
    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_mma_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )
    # Then full blocks, only when full-block tensors exist (compile-time check).
    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_mma_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )