| #pragma once |
|
|
| |
| |
|
|
| #include "cutlass/cutlass.h" |
|
|
| #include "cute/tensor.hpp" |
| #include "cute/atom/mma_atom.hpp" |
| #include "cutlass/numeric_types.h" |
|
|
| #include "cutlass/gemm/device/gemm_universal_adapter.h" |
| #include "cutlass/gemm/kernel/gemm_universal.hpp" |
| #include "cutlass/epilogue/collective/collective_builder.hpp" |
| #include "cutlass/gemm/collective/collective_builder.hpp" |
|
|
| #include "core/math.hpp" |
| #include "cutlass_extensions/common.hpp" |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| using namespace cute; |
|
|
| namespace vllm { |
|
|
| template <typename ElementAB_, typename ElementD_, |
| template <typename, typename, typename> typename Epilogue_, |
| typename TileShape, typename ClusterShape, typename KernelSchedule, |
| typename EpilogueSchedule> |
| struct cutlass_3x_gemm { |
| using ElementAB = ElementAB_; |
| using ElementD = ElementD_; |
| using ElementAcc = |
| typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, |
| float>::type; |
|
|
| using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>; |
|
|
| using StrideD = Stride<int64_t, Int<1>, Int<0>>; |
| using ElementC = void; |
| using StrideC = StrideD; |
|
|
| using EVTCompute = typename Epilogue::EVTCompute; |
|
|
| |
| static constexpr int AlignmentAB = |
| 128 / cutlass::sizeof_bits<ElementAB>::value; |
| static constexpr int AlignmentCD = 4; |
|
|
| using CollectiveEpilogue = |
| typename cutlass::epilogue::collective::CollectiveBuilder< |
| cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, |
| ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, |
| ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, |
| AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; |
|
|
| static constexpr size_t CEStorageSize = |
| sizeof(typename CollectiveEpilogue::SharedStorage); |
| using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< |
| static_cast<int>(CEStorageSize)>; |
|
|
| |
| using CollectiveMainloop = |
| typename cutlass::gemm::collective::CollectiveBuilder< |
| cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, |
| ElementAB, cutlass::layout::RowMajor, AlignmentAB, |
| ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, |
| ElementAcc, TileShape, ClusterShape, |
| Stages, |
| KernelSchedule>::CollectiveOp; |
| |
|
|
| using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< |
| cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, |
| cutlass::gemm::PersistentScheduler>>; |
|
|
| struct GemmKernel : public KernelType {}; |
| }; |
|
|
| template <typename ElementAB_, typename ElementD_, |
| template <typename, typename, typename> typename Epilogue_, |
| typename TileShape, typename ClusterShape, typename KernelSchedule, |
| typename EpilogueSchedule> |
| struct cutlass_3x_gemm_sm100 { |
| using ElementAB = ElementAB_; |
| using LayoutA = cutlass::layout::RowMajor; |
| static constexpr int AlignmentA = |
| 128 / cutlass::sizeof_bits<ElementAB>::value; |
|
|
| using LayoutB = cutlass::layout::ColumnMajor; |
| static constexpr int AlignmentB = |
| 128 / cutlass::sizeof_bits<ElementAB>::value; |
|
|
| using ElementC = void; |
| using LayoutC = cutlass::layout::RowMajor; |
| static constexpr int AlignmentC = |
| 128 / cutlass::sizeof_bits<ElementD_>::value; |
|
|
| using ElementD = ElementD_; |
| using LayoutD = cutlass::layout::RowMajor; |
| static constexpr int AlignmentD = AlignmentC; |
|
|
| using ElementAcc = |
| typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, |
| float>::type; |
| using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>; |
|
|
| |
| using ElementAccumulator = float; |
|
|
| |
| using ElementBias = cutlass::half_t; |
| using ElementCompute = float; |
| using ElementAux = ElementD; |
| using LayoutAux = LayoutD; |
| using ElementAmax = float; |
|
|
| using EVTCompute = typename Epilogue::EVTCompute; |
|
|
| using CollectiveEpilogue = |
| typename cutlass::epilogue::collective::CollectiveBuilder< |
| cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, |
| ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, |
| ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, |
| ElementD, LayoutD, AlignmentD, EpilogueSchedule, |
| EVTCompute>::CollectiveOp; |
|
|
| using CollectiveMainloop = |
| typename cutlass::gemm::collective::CollectiveBuilder< |
| cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, |
| LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, |
| ElementAccumulator, TileShape, ClusterShape, |
| cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( |
| sizeof(typename CollectiveEpilogue::SharedStorage))>, |
| KernelSchedule>::CollectiveOp; |
|
|
| using GemmKernel = cutlass::gemm::kernel::GemmUniversal< |
| Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; |
| }; |
|
|
| } |
|
|