| | #pragma once |
| |
|
| | |
| | |
| |
|
| | #include "cutlass/cutlass.h" |
| |
|
| | #include "cute/tensor.hpp" |
| | #include "cute/atom/mma_atom.hpp" |
| | #include "cutlass/numeric_types.h" |
| |
|
| | #include "cutlass/gemm/device/gemm_universal_adapter.h" |
| | #include "cutlass/gemm/kernel/gemm_universal.hpp" |
| | #include "cutlass/epilogue/collective/collective_builder.hpp" |
| | #include "cutlass/gemm/collective/collective_builder.hpp" |
| |
|
| | #include "core/math.hpp" |
| | #include "cutlass_extensions/common.hpp" |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | using namespace cute; |
| |
|
| | namespace vllm { |
| |
|
| | template <typename ElementAB_, typename ElementD_, |
| | template <typename, typename, typename> typename Epilogue_, |
| | typename TileShape, typename ClusterShape, typename KernelSchedule, |
| | typename EpilogueSchedule> |
| | struct cutlass_3x_gemm { |
| | using ElementAB = ElementAB_; |
| | using ElementD = ElementD_; |
| | using ElementAcc = |
| | typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, |
| | float>::type; |
| |
|
| | using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>; |
| |
|
| | using StrideD = Stride<int64_t, Int<1>, Int<0>>; |
| | using ElementC = void; |
| | using StrideC = StrideD; |
| |
|
| | using EVTCompute = typename Epilogue::EVTCompute; |
| |
|
| | |
| | static constexpr int AlignmentAB = |
| | 128 / cutlass::sizeof_bits<ElementAB>::value; |
| | static constexpr int AlignmentCD = 4; |
| |
|
| | using CollectiveEpilogue = |
| | typename cutlass::epilogue::collective::CollectiveBuilder< |
| | cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, |
| | ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, |
| | ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, |
| | AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; |
| |
|
| | static constexpr size_t CEStorageSize = |
| | sizeof(typename CollectiveEpilogue::SharedStorage); |
| | using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< |
| | static_cast<int>(CEStorageSize)>; |
| |
|
| | |
| | using CollectiveMainloop = |
| | typename cutlass::gemm::collective::CollectiveBuilder< |
| | cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, |
| | ElementAB, cutlass::layout::RowMajor, AlignmentAB, |
| | ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, |
| | ElementAcc, TileShape, ClusterShape, |
| | Stages, |
| | KernelSchedule>::CollectiveOp; |
| | |
| |
|
| | using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< |
| | cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, |
| | cutlass::gemm::PersistentScheduler>>; |
| |
|
| | struct GemmKernel : public KernelType {}; |
| | }; |
| |
|
| | template <typename ElementAB_, typename ElementD_, |
| | template <typename, typename, typename> typename Epilogue_, |
| | typename TileShape, typename ClusterShape, typename KernelSchedule, |
| | typename EpilogueSchedule> |
| | struct cutlass_3x_gemm_sm100 { |
| | using ElementAB = ElementAB_; |
| | using LayoutA = cutlass::layout::RowMajor; |
| | static constexpr int AlignmentA = |
| | 128 / cutlass::sizeof_bits<ElementAB>::value; |
| |
|
| | using LayoutB = cutlass::layout::ColumnMajor; |
| | static constexpr int AlignmentB = |
| | 128 / cutlass::sizeof_bits<ElementAB>::value; |
| |
|
| | using ElementC = void; |
| | using LayoutC = cutlass::layout::RowMajor; |
| | static constexpr int AlignmentC = |
| | 128 / cutlass::sizeof_bits<ElementD_>::value; |
| |
|
| | using ElementD = ElementD_; |
| | using LayoutD = cutlass::layout::RowMajor; |
| | static constexpr int AlignmentD = AlignmentC; |
| |
|
| | using ElementAcc = |
| | typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, |
| | float>::type; |
| | using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>; |
| |
|
| | |
| | using ElementAccumulator = float; |
| |
|
| | |
| | using ElementBias = cutlass::half_t; |
| | using ElementCompute = float; |
| | using ElementAux = ElementD; |
| | using LayoutAux = LayoutD; |
| | using ElementAmax = float; |
| |
|
| | using EVTCompute = typename Epilogue::EVTCompute; |
| |
|
| | using CollectiveEpilogue = |
| | typename cutlass::epilogue::collective::CollectiveBuilder< |
| | cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, |
| | ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, |
| | ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, |
| | ElementD, LayoutD, AlignmentD, EpilogueSchedule, |
| | EVTCompute>::CollectiveOp; |
| |
|
| | using CollectiveMainloop = |
| | typename cutlass::gemm::collective::CollectiveBuilder< |
| | cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, |
| | LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, |
| | ElementAccumulator, TileShape, ClusterShape, |
| | cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( |
| | sizeof(typename CollectiveEpilogue::SharedStorage))>, |
| | KernelSchedule>::CollectiveOp; |
| |
|
| | using GemmKernel = cutlass::gemm::kernel::GemmUniversal< |
| | Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; |
| | }; |
| |
|
| | } |
| |
|