| |
| |
| |
|
|
| #pragma once |
|
|
| #include "namespace_config.h" |
namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-batch-entry sequence bookkeeping for batch index `bidb`, built on the device from a kernel
// Params struct.
//
// Q and K each use one of two addressing modes:
//   * Fixed length (Varlen == false, or the matching cu_seqlens_* pointer is null): the length is
//     params.seqlen_q / params.seqlen_k, sum_s_* is -1, and the offset helpers address the entry
//     as `bidb * batch_stride`.
//   * Cumulative (Varlen == true and cu_seqlens_* non-null): sum_s_* is cu_seqlens_*[bidb], the
//     entry's starting row, and the offset helpers address the entry as `sum_s_* * row_stride`.
// For K, the cumulative mode additionally requires params.is_seqlens_k_cumulative. When that flag
// is false, cu_seqlens_k[bidb] is read as this entry's K length and K uses the batch stride.
// leftpad_k and seqused_k are read whenever their pointers are non-null, independent of Varlen.
template<bool Varlen=true>
struct BlockInfo {

    // params: kernel parameter struct. Any type works as long as it exposes the fields read below
    //         (cu_seqlens_q, cu_seqlens_k, is_seqlens_k_cumulative, seqlen_q, seqlen_k,
    //         leftpad_k, seqused_k, knew_ptr, seqlen_knew).
    // bidb:   batch index to describe. It indexes the per-batch arrays in params.
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // Number of leading K rows to skip for this entry (0 when params.leftpad_k is null).
        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
        // If is_seqlens_k_cumulative, the K length is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
        // Otherwise it is cu_seqlens_k[bidb], i.e. cu_seqlens_k stores per-batch K lengths directly.
        // In every mode the left padding is subtracted.
        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
        // A non-null seqused_k overrides the derived length. Otherwise seqlen_knew is added to
        // the cached length when knew_ptr is set (presumably newly appended keys -- confirm).
        , actual_seqlen_k(params.seqused_k != nullptr ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
        {
        }

    // Offset of this entry's first Q row, in the units of the given strides.
    // Fixed-length mode: bidb * batch_stride. Cumulative mode: sum_s_q * row_stride
    // (sum_s_q is converted to uint32_t before the multiply).
    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    // Offset of this entry's first K row after skipping leftpad_k rows, in the units of the
    // given strides. In fixed-length mode the padding is applied on top of the batch stride.
    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
    }

    // NOTE: members are initialized in declaration order, and the constructor's initializers read
    // earlier members (sum_s_q, sum_s_k, leftpad_k, seqlen_k_cache). Do not reorder these fields.

    // Starting row of this entry in cumulative mode, or -1 in fixed-length mode.
    const int sum_s_q;
    const int sum_s_k;
    // Number of Q rows for this batch entry.
    const int actual_seqlen_q;
    // leftpad_k and seqlen_k_cache must stay declared before actual_seqlen_k, which reads them.
    const int leftpad_k;
    // K length from seqlen_k / cu_seqlens_k, minus leftpad_k.
    const int seqlen_k_cache;
    // Effective K length for this entry (see the constructor for precedence).
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace FLASH_NAMESPACE
|
|