| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #pragma once |
|
|
| |
| |
|
|
| #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" |
| #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" |
| #include "cute/tensor.hpp" |
|
|
| namespace cutlass::epilogue::threadblock { |
|
|
| using namespace cute; |
| using namespace detail; |
|
|
// Threadblock-scope epilogue visitor that supplies either a broadcast row
// vector (1 x N, laid out by StrideMNL) or a single replicated scalar to the
// epilogue, selected at runtime by Arguments::row_broadcast.
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  struct Arguments {
    // If row_broadcast: base pointer of the row tensor in global memory.
    // Otherwise: pointer to a single scalar that is replicated everywhere.
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  // Arguments need no host-side transformation; Params aliases Arguments.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor requires no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Width of one global-memory access: one ThreadMap access worth of
  // elements, capped at 128 bits.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;          // per-thread partition of the global row
    RTensor tC_rRow;          // register staging buffer, filled in begin_epilogue()
    CTensor tC_cRow;          // logical coordinates used for bounds predication
    Params const* params_ptr;
    int n;                    // problem extent N (number of valid columns)

    // Stage the row (or the replicated scalar) into registers once, before
    // the epilogue iterations start.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero first so out-of-bounds slots hold a defined value.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // Vectorized, predicated loads of the row; the guard masks columns
        // at or beyond N.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // Scalar path: dereference ptr_row once per lane, replicate it into
        // a full vector, then store that vector to every in-bounds slot.
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Return the pre-staged fragment for this column; the accumulator
    // fragment is ignored by this visitor (it is a pure producer node).
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Construct the per-threadblock Callbacks: partition the global row with
  // ThreadMap, allocate a like-shaped register tensor, and build an identity
  // tensor whose coordinates drive the bounds predicates.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // Keep only the first two partition modes; the remaining modes are
    // sliced to 0 (assumed row-invariant for this ThreadMap — confirm
    // against ThreadMap::partition's mode ordering).
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Identity tensor partitioned the same way yields per-access logical
    // coordinates; outer_partition groups them per VecLength-wide access.
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};
|
|
| |
|
|
| |
// Variant of the row broadcast visitor that substitutes zeros when no row
// pointer is provided: if Arguments::ptr_row is non-null the row is loaded
// from global memory, otherwise every in-bounds slot is filled with
// Element{0}. The mode is selected by the pointer itself (no boolean flag).
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  struct Arguments {
    // Null pointer selects the zero-fill path.
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  // Arguments need no host-side transformation; Params aliases Arguments.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor requires no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Width of one global-memory access: one ThreadMap access worth of
  // elements, capped at 128 bits.
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;          // per-thread partition of the global row
    RTensor tC_rRow;          // register staging buffer, filled in begin_epilogue()
    CTensor tC_cRow;          // logical coordinates used for bounds predication
    Params const* params_ptr;
    int n;                    // problem extent N (number of valid columns)

    // Stage the row (or zeros) into registers once, before the epilogue
    // iterations start.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero first so out-of-bounds slots hold a defined value.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // Vectorized, predicated loads of the row; the guard masks columns
        // at or beyond N.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // Null pointer: build an all-zero vector and store it to every
        // in-bounds slot (redundant with clear(), but kept explicit).
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Return the pre-staged fragment for this column; the accumulator
    // fragment is ignored by this visitor (it is a pure producer node).
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Construct the per-threadblock Callbacks: partition the global row with
  // ThreadMap, allocate a like-shaped register tensor, and build an identity
  // tensor whose coordinates drive the bounds predicates.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // Keep only the first two partition modes; the remaining modes are
    // sliced to 0 (assumed row-invariant for this ThreadMap — confirm
    // against ThreadMap::partition's mode ordering).
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Identity tensor partitioned the same way yields per-access logical
    // coordinates; outer_partition groups them per VecLength-wide access.
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};
|
|
|
|
| |
|
|
| |
// Threadblock-scope epilogue visitor that supplies either a broadcast column
// vector (M x 1, default stride (1,0,0)) or a single replicated scalar,
// selected at runtime by Arguments::col_broadcast. Each visited fragment is
// filled with one value per output row.
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  struct Arguments {
    // If col_broadcast: base pointer of the column tensor in global memory.
    // Otherwise: pointer to a single scalar that is replicated everywhere.
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  // Arguments need no host-side transformation; Params aliases Arguments.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor requires no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;          // per-thread partition of the global column
    RTensor tC_rCol;          // register staging buffer, filled in begin_epilogue()
    CTensor tC_cCol;          // logical coordinates used for bounds predication
    Params const* params_ptr;
    int m;                    // problem extent M (number of valid rows)

    // Stage the column (or the replicated scalar) into registers once,
    // before the epilogue iterations start.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero first so out-of-bounds slots hold a defined value.
      clear(tC_rCol);

      // Per-element row-bounds predicate over the thread's partition.
      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // Predicated element-wise copy of the column into registers.
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // Scalar path: write *ptr_col into every in-bounds slot.
        // NOTE(review): pred is sized/indexed by the unfiltered tC_gCol
        // shape while dst_v is filtered; this assumes filter() does not
        // drop elements for this layout — confirm.
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    // Fill the whole fragment with this row's staged value (the column value
    // is constant across a row); the accumulator fragment is ignored.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  // Construct the per-threadblock Callbacks: partition the global column
  // with ThreadMap, allocate a like-shaped register tensor, and build an
  // identity tensor whose coordinates drive the bounds predicates.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // Slice the first two modes to 0 and group the remaining row/iteration
    // modes into one mode (assumed column-invariant for this ThreadMap —
    // confirm against ThreadMap::partition's mode ordering).
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Identity tensor partitioned identically yields per-element logical
    // coordinates for the row-bounds predicates.
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};
|
|
| } |
|
|