| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | #pragma once |
| |
|
| | |
| | |
| |
|
| | #include "cutlass/cutlass.h" |
| | #include "cutlass/arch/barrier.h" |
| |
|
| | #include "cute/tensor.hpp" |
| | #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" |
| |
|
| | namespace cutlass::epilogue::fusion { |
| |
|
| | using namespace cute; |
| | using namespace detail; |
| |
|
| | |
// Broadcast of either a per-group row vector or a per-group scalar across each
// (M,N) output tile, for pointer-array (grouped) GEMM epilogue visitor trees.
//
// For a tile with coordinate (m, n, k, l), Arguments::ptr_row_array[l] selects
// the source data:
//   * row_broadcast == true : the pointer addresses a row vector laid out with
//     dRow. The CTA's N-slice is staged global -> shared (elements past the
//     problem residue are zero-filled), then shared -> registers, so every row
//     of the tile sees the same values.
//   * row_broadcast == false: only *ptr_row_array[l] is read and that single
//     value is written to every register element.
//
// Template parameters:
//   Stages          - must be 0 (see static_assert).
//   CtaTileShapeMNK - CTA tile shape; its N extent sizes the shared buffer.
//   Element         - element type of the broadcast operand.
//   StrideMNL       - operand stride; its (M,N) part must be the static (_0,_1).
//   Alignment       - not referenced by this visitor.
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  // Only the (M,N) strides are required to be compile-time constants.
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>);
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  // Shared-memory staging buffer: one element per column of the CTA tile.
  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // ptr_row_array : array of per-group pointers, indexed by the tile's L
  //                 coordinate; both the array and its pointees are read in
  //                 device code.
  // row_broadcast : true  -> ptr_row_array[l] is a row vector (stride dRow);
  //                 false -> *ptr_row_array[l] is a scalar broadcast to the tile.
  // dRow          : stride of the row operand.
  struct Arguments {
    const Element* const* ptr_row_array = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  // Arguments are used directly as kernel params; no host-side transformation.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  // No global-memory workspace is needed.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray() { }

  // The staging buffer is written in begin(), hence the const_cast away from
  // the const SharedStorage reference.
  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params)
      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  // This visitor issues no producer-side load.
  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  // Whether this operand is known to be identically zero.
  // The candidate scalar lives in ptr_row_array[l], but the group index l is
  // only known per tile (see get_consumer_store_callbacks); no group index is in
  // scope here, and an unqualified `group` would name cute::group instead.
  // Conservatively report "not known to be zero".
  CUTLASS_DEVICE bool
  is_zero() const {
    return false;
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  // Per-thread consumer-side state.
  //   GS_* / Tiled_G2S : this thread's partition of the global -> shared copy
  //                      (gmem source, smem destination, tile coordinates).
  //   SR_*             : this thread's partition of the shared -> register copy.
  //   CTensor          : epilogue coordinate tensor (stored, not read here).
  //   ThrResidue       : residue bounds; elements beyond them are zero-filled.
  //   ThrNum           : thread count passed to the epilogue named barrier.
  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
        int group, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , thr_num(thr_num_)  // forwarded so a non-static ThrNum is never left uninitialized
      , group(group)
      , params(params_) {}

    GS_GTensor tGS_gRow;         // gmem source of the g2s copy
    GS_STensor tGS_sRow;         // smem destination of the g2s copy
    GS_CTensor tGS_cRow;         // tile coordinates aligned with tGS_gRow
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;         // smem source of the s2r copy
    SR_RTensor tSR_rRow;         // register copy of the broadcast values

    CTensor tCcRow;
    ThrResidue residue_tCcRow;
    ThrNum thr_num;
    int group;                   // L tile coordinate selecting ptr_row_array[group]
    Params const& params;        // refers to the owning visitor's params

    // Stage this group's row slice (or scalar) before the subtile loop.
    CUTLASS_DEVICE void
    begin() {
      // Scalar mode: no shared-memory staging; every register gets the scalar.
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row_array[group]));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      // Collapse stride-0 modes so each distinct element is copied once.
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      // Coordinates viewed with the filtered shape so they line up element-wise.
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        // Skip elements whose column lies outside the CTA tile.
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue;
        }
        // Load in-bounds elements; zero-fill those past the residue.
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0);
        }
      }
      // All epilogue threads must see the smem writes before any s2r read.
      synchronize();
    }

    // Refresh the register copy from shared memory for subtile column epi_n.
    // Loads only when epi_m == 0. NOTE(review): this assumes the caller visits
    // every epi_m of one epi_n before moving to the next, so the values remain
    // valid for the rest of that column -- confirm against the collective's loop order.
    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) {
        if (!params.row_broadcast) return;  // registers already hold the scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    // Return the FragmentSize broadcast values for fragment index epi_v; the
    // accumulator values themselves are not read.
    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  // Build the per-thread callbacks for the tile at args.tile_coord_mnkl.
  template <
    bool ReferenceSrc,
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    // ptr_row_array[l] already selects this group's data, so the tensor has a
    // single batch and is always sliced at batch 0. Slicing at l would add
    // l * stride_L on top of the per-group pointer whenever dRow's L stride is non-zero.
    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,0), take<0,2>(args.tile_shape_mnk), make_coord(m, n));
    // (CTA_M, CTA_N) view of the smem buffer with stride (0,1): every row
    // aliases the same N elements.
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_stride(_0{}, _1{}));

    // Global -> shared copy: threads laid out along N, one element each.
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
        Layout< Shape<_1, ThreadCount>,
                Stride<_0, _1>>{},
        Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    // Tile-local coordinates, partitioned like the gmem source, for bounds checks.
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    // Shared -> register partition matching the epilogue's tiled copy.
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      l,
      params);
  }
};
| |
|
| | |
| |
|
| | |
// Broadcast of either a per-group column vector or a per-group scalar across
// each (M,N) output tile, for pointer-array (grouped) GEMM epilogue visitor trees.
//
// For a tile with coordinate (m, n, k, l), Arguments::ptr_col_array[l] selects
// the source data:
//   * col_broadcast == true : the pointer addresses a column vector laid out
//     with dCol. Each thread loads its elements straight from global memory
//     into registers; rows at or beyond M are not loaded.
//   * col_broadcast == false: only *ptr_col_array[l] is read and that single
//     value is written to every register element.
//
// Template parameters:
//   Stages          - must be 0 (see static_assert).
//   CtaTileShapeMNK - CTA tile shape (not referenced in this struct's body).
//   Element         - element type of the broadcast operand.
//   StrideMNL       - Stride<_1,_0,_0> or Stride<_1,_0,int>.
//   Alignment       - must be a multiple of 16 bytes (see static_assert).
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) ||
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));

  // Loads go global -> registers directly, so no shared memory is needed.
  struct SharedStorage { };

  // ptr_col_array : array of per-group pointers, indexed by the tile's L
  //                 coordinate; both the array and its pointees are read in
  //                 device code.
  // col_broadcast : true  -> ptr_col_array[l] is a column vector (stride dCol);
  //                 false -> *ptr_col_array[l] is a scalar broadcast to the tile.
  // dCol          : stride of the column operand.
  struct Arguments {
    const Element* const* ptr_col_array = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  // Arguments are used directly as kernel params; no host-side transformation.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  // No global-memory workspace is needed.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  // This visitor issues no producer-side load.
  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  // Whether this operand is known to be identically zero.
  // The candidate scalar lives in ptr_col_array[l], but the group index l is
  // only known per tile (see get_consumer_store_callbacks); no group index is in
  // scope here, and an unqualified `group` would name cute::group instead.
  // Conservatively report "not known to be zero".
  CUTLASS_DEVICE bool
  is_zero() const {
    return false;
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  // Per-thread consumer-side state: this thread's gmem partition of the column,
  // its register copy, the matching coordinates, the problem's M extent for
  // predication, and the group index selecting ptr_col_array[group].
  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
      GTensor&& tCgCol,
      RTensor&& tCrCol,
      CTensor&& tCcCol,
      ProblemShape problem_shape,
      int group,
      Params const& params
    ):
      // Initializers listed in declaration order (avoids -Wreorder).
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      params(params),
      m(get<0>(problem_shape)),
      group(group) {}

    GTensor tCgCol;          // gmem source elements owned by this thread
    RTensor tCrCol;          // register copy of the broadcast values
    CTensor tCcCol;          // (m, n, l) coordinates aligned with tCgCol
    Params const& params;    // refers to the owning visitor's params
    int m;                   // problem M extent; rows >= m are not loaded
    int group;               // L tile coordinate selecting ptr_col_array[group]

    // Fill the register copy once, before the subtile loop.
    CUTLASS_DEVICE void
    begin() {
      // Scalar mode: every register gets the scalar; no predicate is needed.
      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col_array[group]));
        return;
      }

      // Predicate: load only elements whose row lies inside the problem.
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      // filter() drops stride-0 modes so broadcast duplicates are not reloaded.
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    // Return the FragmentSize broadcast values of subtile (epi_m, epi_n) for
    // fragment index epi_v; the accumulator values themselves are not read.
    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  // Build the per-thread callbacks for the tile at args.tile_coord_mnkl.
  template <
    bool ReferenceSrc,
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;

    // ptr_col_array[l] already selects this group's data, so the tensor has a
    // single batch and is partitioned at batch 0. Using l as the batch
    // coordinate would add l * stride_L on top of the per-group pointer, which is
    // wrong for the permitted Stride<_1,_0,int> whenever its L stride is non-zero.
    auto tile_coord_mnk0 = make_coord(m, n, k, 0);

    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(
      mCol, args.tile_shape_mnk, tile_coord_mnk0, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);

    // Identity tensor over the same shape, partitioned identically, so each
    // register element knows its row for the M-bound predicate in begin().
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(
      cCol, args.tile_shape_mnk, tile_coord_mnk0, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      l,
      params
    );
  }
};
| |
|
| | } |
| |
|