/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

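// For each row m of the threads' (dO, O) register fragments, compute
//     dP_sum(m) = scale * sum_k dO(m, k) * O(m, k),
// all-reducing across the THREADS_PER_ROW threads that share a row; only the first
// thread of each row group writes the result to gmem. This is the softmax_d term of
// the backward pass, which enters dS = P * (dP - dP_sum) row-wise.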
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64).
    // The last coordinate is the "page".
    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
                                                             make_layout(get<0>(do_.layout()),
                                                                         get<2>(do_.layout()))));
    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
    Tensor do_fp32 = FLASH_NAMESPACE::convert_type<float>(do_reshaped);
    Tensor o_fp32 = FLASH_NAMESPACE::convert_type<float>(o_reshaped);
    #pragma unroll
    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
        }
        FLASH_NAMESPACE::SumOp<float> sum_op;
        dP_sum_cur = FLASH_NAMESPACE::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
        if (threadIdx.x % THREADS_PER_ROW == 0) {
            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
    // Regarding the 128 * params.b term, see the comment in mha_varlen_bwd about padding of dq_accum and softmax_d.
    const index_t row_offset_dpsum = (params.unpadded_lse
        ? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb
        : (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Allocate predicate tensors for k.
    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
    // Set predicates for k bounds.
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d; }

    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
    // so that (dP - dP_sum) is on the same scale.
    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
    if (Clear_dQaccum) {
        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}
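
// A minimal launch sketch (hypothetical wrapper; the actual launcher lives elsewhere in
// this repo and the names below are illustrative): one CTA per kBlockM-row tile, with
// gridDim = (num_m_blocks, batch, heads) to match the blockIdx usage above.
//
//   template<typename Kernel_traits, typename Params>
//   __global__ void dot_do_o_kernel(const Params params) {
//       FLASH_NAMESPACE::compute_dot_do_o</*Clear_dQaccum=*/true, Kernel_traits>(params);
//   }
//
//   dim3 grid(cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.b, params.h);
//   dot_do_o_kernel<Kernel_traits><<<grid, Kernel_traits::kNThreads, 0, stream>>>(params);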

////////////////////////////////////////////////////////////////////////////////////////////////////

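// Zero the fp32 dK/dV accumulators for one kBlockN tile. This is used when the backward
// pass is parallelized across seqlen_q, where multiple CTAs accumulate partial dK/dV into
// these buffers with atomicAdds and therefore need them cleared first.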
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});

    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));

    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});
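    // acc_dq is produced in MMA-fragment layout; staging it through sdQ lets the final
    // gmem store use the coalesced GmemTiledCopydQ layout instead of scattered fragments.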

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);

    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));

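    // dQaccum may hold one partial result per seqlen_k split; sum all nsplits slices,
    // spaced dq_accum_split_stride apart, before rescaling and converting.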
    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
    clear(acc_dq);
    for (int s = 0; s < nsplits; ++s) {
        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16/bf16.
    Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);
    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    __syncthreads();
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem.
    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
        + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{});
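    // sdV is placed immediately after sdK, so the dynamic smem allocation must cover two
    // SmemLayoutdKV tiles (the launcher is expected to size it accordingly).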

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);
    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);

    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
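    // dK absorbs the softmax scale together with the 1/(1 - p_dropout) correction
    // (scale_softmax_rp_dropout), while dV only needs the dropout correction
    // (rp_dropout), since dV = P^T dO carries no softmax scaling.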
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
    // Convert acc_dk and acc_dv from fp32 to fp16/bf16.
    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
    Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);
    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem.
    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

} // namespace FLASH_NAMESPACE