#pragma once


#include <cuda_runtime.h>  // cudaStream_t, __align__

#include <type_traits>  // std::conditional_t, std::bool_constant

#include "cutlass/bfloat16.h"
|
|
// Selects which model family the sparse-attention kernels are specialized for
// (consumed via SparseAttnDecodeParams::model_type). The concrete semantics of
// each value are defined by the kernels that read it, not by this header.
enum class ModelType {
    V32,
    MODEL1
};
|
|
// One unit of work produced by the decode tile scheduler and consumed by the
// decoding/combine kernels. Forced to 32-byte (4*8) alignment; with the
// explicit _pad below the struct is exactly 8 ints = 32 bytes, so arrays of
// these can be accessed with aligned wide loads.
// NOTE(review): field semantics below are inferred from the names — the
// authoritative meaning is whatever the scheduler writes; confirm against the
// producer (GetDecodeSchedMetaParams consumer).
struct __align__(4*8) DecodingSchedMeta {
    int begin_req_idx, end_req_idx;            // request-index range assigned to this part
    int begin_block_idx, end_block_idx;        // block-index range within the boundary requests, per naming
    int begin_split_idx;                       // split index at which the first request starts
    int is_first_req_splitted, is_last_req_splitted;  // flags: boundary requests are shared with neighboring parts
    int _pad[1];                               // padding to reach the full 32-byte size
};
// Byte size of one scheduler-metadata entry (32, given the alignment/padding above).
static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta);
|
|
// Host-side argument pack for the dense paged-KV attention decode kernel.
// All pointers are presumably device pointers — TODO confirm against the
// launch site. Tensor element type is erased to void*; the kernel fixes it.
struct DenseAttnDecodeParams {
    using index_t = int64_t;  // wide stride type so large tensors cannot overflow int


    // Problem dimensions.
    int b;                 // batch size
    int s_q;               // query sequence length
    int q_seq_per_hk;      // query rows handled per KV head, per naming — confirm
    int d, d_v;            // head dim of Q/K and head dim of V
    int h_q, h_k;          // number of query heads / KV heads
    int num_blocks;        // total number of KV pages
    int q_head_per_hk;     // query heads sharing one KV head (GQA ratio, per naming)
    bool is_causal;        // whether causal masking is applied
    float scale_softmax, scale_softmax_log2;  // softmax scale; second is presumably the log2-domain variant — confirm

    // Input/output tensors.
    void *__restrict__ q_ptr;                 // queries
    void *__restrict__ k_ptr;                 // paged keys (values too, per kernel convention — confirm)
    void *__restrict__ o_ptr;                 // attention output
    float *__restrict__ softmax_lse_ptr;      // per-row softmax log-sum-exp output, per naming


    // Strides along the batch / row (sequence) / head axes of Q, K and O.
    // NOTE(review): units (elements vs. bytes) are fixed by the kernel — confirm.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t o_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t o_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t o_head_stride;


    // Paged-KV lookup: block_table[b][i] maps logical block i of batch b to a
    // physical page index, per naming.
    int *__restrict__ block_table;
    index_t block_table_batch_stride;         // stride between consecutive batches' rows of block_table
    int page_block_size;                      // number of tokens per KV page
    int *__restrict__ seqlens_k_ptr;          // per-batch KV sequence lengths


    // Split-KV scheduling metadata (see DecodingSchedMeta).
    DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
    int num_sm_parts;                         // number of scheduler partitions / metadata entries
    int *__restrict__ num_splits_ptr;


    // Split-KV partial-result buffers, combined into o_ptr afterwards (per naming).
    int total_num_splits;
    float *__restrict__ softmax_lseaccum_ptr; // per-split LSE accumulator
    float *__restrict__ oaccum_ptr;           // per-split output accumulator


    cudaStream_t stream;                      // CUDA stream the kernel is launched on
};
|
|
// Host-side argument pack for the sparse (top-k indexed) attention decode
// kernel, with an optional second ("extra") KV source and split-KV buffers.
// Pointers are presumably device pointers; strides are in the units the
// kernel expects (elements vs. bytes not determinable here — confirm).
struct SparseAttnDecodeParams {
    // Problem dimensions.
    int b, s_q;                               // batch size, query sequence length
    int h_q, h_kv;                            // query heads, KV heads
    int d_qk, d_v;                            // head dim of Q/K and of V
    float sm_scale, sm_scale_div_log2;        // softmax scale; per naming, scale divided by log(2) — confirm
    int num_blocks, page_block_size, topk;    // KV page count, tokens per page, top-k selection width
    ModelType model_type;                     // model family the kernel specializes for


    // Inputs.
    cutlass::bfloat16_t* __restrict__ q;      // queries (bf16)
    cutlass::bfloat16_t* __restrict__ kv;     // paged KV storage (bf16)
    int* __restrict__ indices;                // per-query top-k KV indices, per naming
    int* __restrict__ topk_length;            // per-row count of valid entries in indices, per naming
    float* __restrict__ attn_sink;            // attention-sink term, per naming — confirm semantics


    // Outputs.
    float* __restrict__ lse;                  // softmax log-sum-exp, per naming
    cutlass::bfloat16_t* __restrict__ out;    // attention output (bf16)

    // Optional second KV source ("extra"), mirroring kv/indices/topk_length above.
    int extra_num_blocks, extra_page_block_size, extra_topk;
    cutlass::bfloat16_t* __restrict__ extra_kv;
    int* __restrict__ extra_indices;
    int* __restrict__ extra_topk_length;

    // Strides, named stride_<tensor>_<axis>.
    int stride_q_b, stride_q_s_q, stride_q_h_q;
    int stride_kv_block, stride_kv_row;
    int stride_indices_b, stride_indices_s_q;
    int stride_lse_b, stride_lse_s_q;
    int stride_o_b, stride_o_s_q, stride_o_h_q;
    int stride_extra_kv_block, stride_extra_kv_row;
    int stride_extra_indices_b, stride_extra_indices_s_q;

    cudaStream_t stream;                      // CUDA stream the kernel is launched on


    // Split-KV machinery: per-split accumulators plus scheduler metadata
    // (same scheme as DenseAttnDecodeParams / CombineParams).
    float* __restrict__ lse_accum;            // per-split LSE accumulator
    float* __restrict__ o_accum;              // per-split output accumulator
    int stride_lse_accum_split, stride_lse_accum_s_q;
    int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
    DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr;
    int* __restrict__ num_splits_ptr;
    int num_sm_parts;                         // number of scheduler partitions
};
|
|
// Host-side argument pack for the split-KV combine step: per naming, it folds
// the per-split partial results (lse_accum / o_accum) into the final output
// tensors (lse / out), driven by the same scheduler metadata the decode
// kernels used. Pointers are presumably device pointers — confirm.
struct CombineParams {
    int b, s_q, h_q, d_v;                     // batch, query length, query heads, value head dim


    // Final outputs. out is type-erased; the combine kernel fixes the dtype.
    float* __restrict__ lse;
    void* __restrict__ out;
    int stride_lse_b, stride_lse_s_q;
    int stride_o_b, stride_o_s_q, stride_o_h_q;


    // Per-split partial inputs produced by the decode pass.
    float* __restrict__ lse_accum;
    float* __restrict__ o_accum;
    int stride_lse_accum_split, stride_lse_accum_s_q;
    int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;


    // Scheduler metadata describing how requests were split (see DecodingSchedMeta).
    DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr;
    int* __restrict__ num_splits_ptr;
    int num_sm_parts;


    // Attention-sink term, per naming — confirm semantics against the kernel.
    float* attn_sink;


    cudaStream_t stream;                      // CUDA stream the kernel is launched on
};
|
|
// Host-side argument pack for the scheduling pass that fills
// tile_scheduler_metadata_ptr / num_splits_ptr (see DecodingSchedMeta),
// partitioning the decode workload across num_sm_parts SM groups.
struct GetDecodeSchedMetaParams {
    int b;                                    // batch size
    int s_q;                                  // query sequence length
    int block_size_n;                         // KV-block granularity used for splitting, per naming
    int fixed_overhead_num_blocks;            // per-request cost offset added when balancing, per naming — confirm


    // Top-k widths and per-row valid lengths for the primary and "extra" KV
    // sources (mirrors SparseAttnDecodeParams).
    int topk, extra_topk;
    int *__restrict__ topk_length, *__restrict__ extra_topk_length;


    // Per-batch KV sequence lengths (dense path).
    int *__restrict__ seqlens_k_ptr;


    // Outputs: scheduler metadata entries and split counts.
    DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
    int *__restrict__ num_splits_ptr;
    int num_sm_parts;                         // number of partitions to produce


    cudaStream_t stream;                      // CUDA stream the kernel is launched on
};
|
|
// Host-side argument pack for the sparse attention forward (prefill) kernel.
// Unlike the decode variant there is no batch axis or split-KV machinery
// visible here. Pointers are presumably device pointers — confirm.
struct SparseAttnFwdParams {
    int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;  // query/KV lengths, head counts, head dims, top-k width
    float sm_scale, sm_scale_div_log2;          // softmax scale; per naming, scale divided by log(2) — confirm


    // Inputs.
    cutlass::bfloat16_t* __restrict__ q;        // queries (bf16)
    cutlass::bfloat16_t* __restrict__ kv;       // keys/values (bf16)
    int* __restrict__ indices;                  // per-query selected KV indices, per naming
    float* __restrict__ attn_sink;              // attention-sink term, per naming — confirm semantics
    int* __restrict__ topk_length;              // per-row count of valid entries in indices, per naming


    // Strides, named stride_<tensor>_<axis>.
    int stride_q_s_q; int stride_q_h_q;
    int stride_kv_s_kv; int stride_kv_h_kv;
    int stride_indices_s_q; int stride_indices_h_kv;


    // Outputs.
    cutlass::bfloat16_t* __restrict__ out;      // attention output (bf16)
    float* __restrict__ max_logits;             // per-row max logit, per naming
    float* __restrict__ lse;                    // softmax log-sum-exp, per naming


    int num_sm;                                 // SM count used to size the launch, per naming
    cudaStream_t stream;                        // CUDA stream the kernel is launched on
};
|
|
| |
// Compile-time selector for the sparse attention forward variant; also picks
// the matching argument struct via SparseFwdArgT below.
enum class SparseAttnFwdMode {
    Prefill,            // full forward pass (SparseAttnFwdParams)
    DecodeWithSplitKV,  // decode pass with split-KV (SparseAttnDecodeParams)
};
|
|
// True iff FWD_MODE selects the split-KV decode path.
// A plain comparison already yields a constexpr bool; the previous
// std::bool_constant<...>::value spelling was redundant and silently relied
// on <type_traits> being pulled in transitively.
template<SparseAttnFwdMode FWD_MODE>
inline constexpr bool is_decode_v = FWD_MODE == SparseAttnFwdMode::DecodeWithSplitKV;
|
|
// Maps a forward mode to its argument struct: DecodeWithSplitKV uses
// SparseAttnDecodeParams, Prefill uses SparseAttnFwdParams.
template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
|
|