| | #pragma once |
| | #include <stddef.h> |
| | #include <torch/all.h> |
| |
|
| | #include <ATen/cuda/CUDAContext.h> |
| |
|
| | |
| | |
| | #include "cute/tensor.hpp" |
| | #include "cute/atom/mma_atom.hpp" |
| | #include "cutlass/numeric_types.h" |
| |
|
| | #include "cutlass/cutlass.h" |
| | #include "cutlass/gemm_coord.h" |
| | #include "cutlass/arch/mma_sm75.h" |
| | #include "cutlass/arch/arch.h" |
| | #include "cutlass/arch/mma.h" |
| | #include "cutlass/gemm/device/gemm.h" |
| | #include "cutlass/gemm/device/gemm_universal_adapter.h" |
| |
|
| | #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" |
| | #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" |
| |
|
| | #include "core/math.hpp" |
| | #include "cutlass_extensions/common.hpp" |
| | |
| |
|
| | using namespace cute; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | namespace vllm { |
| |
|
| | |
| | |
| | |
| | |
| | |
// Arch guard: forwards to Kernel::invoke only when compiled for a device
// with 7.5 <= compute capability < 8.0; on any other architecture the body
// compiles to an empty function, so the kernel is a safe no-op there.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
| |
|
// Arch guard: forwards to Kernel::invoke only when compiled for a device
// with 8.0 <= compute capability < 8.9; on any other architecture the body
// compiles to an empty function, so the kernel is a safe no-op there.
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
| |
|
// Arch guard: forwards to Kernel::invoke only when compiled for a device
// with 8.9 <= compute capability < 9.0; on any other architecture the body
// compiles to an empty function, so the kernel is a safe no-op there.
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Ts>
  CUTLASS_DEVICE static void invoke(Ts&&... ts) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Ts>(ts)...);
#endif
  }
};
// Type-level configuration of a CUTLASS 2.x GEMM whose epilogue is an
// epilogue-visitor-tree (EVT) supplied via Epilogue_.
//  - Arch / ArchGuard: target architecture tag plus the matching
//    enable_smXX_to_smYY guard, so the kernel body only executes on devices
//    in the intended compute-capability window.
//  - ElementAB_: element type shared by A and B; ElementD_: output type.
//  - Epilogue_: template taking (ElementD, OutputTileThreadMap) and exposing
//    an EVTCompute subtree plus prepare_args (used by cutlass_gemm_caller).
//  - TileShape / WarpShape / InstructionShape / MainLoopStages: CUTLASS
//    threadblock/warp/MMA tiling and mainloop pipeline-stage parameters.
//  - FP8MathOperator: math operator used for non-int8 element types
//    (defaults to plain multiply-add).
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  // int8 inputs accumulate in int32; all other element types in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  // int8 uses a saturating multiply-add; other types use FP8MathOperator.
  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  // Thread layout used by the epilogue visitors to traverse the output tile
  // (float compute element, access width 4, 1 epilogue stage).
  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  // Terminal visitor node: stores the epilogue result to D, converting to
  // ElementD with round-to-nearest.  Stride is (row-stride, 1, 0), i.e. a
  // row-major matrix with a runtime leading dimension.
  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  // Full visitor tree: the compute subtree feeding the D store node.
  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // A/B use 128-bit vectorized accesses; C/D use 4-element alignment.
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD = 4;

  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  // Kernel: A row-major, B column-major, output row-major, tensor-op MMA,
  // StreamK threadblock swizzle.  Wrapped in ArchGuard so it no-ops on
  // devices outside the target compute-capability window.
  using KernelType =
      ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
          ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
          ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
          float, cutlass::layout::RowMajor, AlignmentCD,
          ElementAcc, float, cutlass::arch::OpClassTensorOp,
          Arch,
          TileShape, WarpShape, InstructionShape,
          EVTD,
          cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
          MainLoopStages, Operator,
          1 /* epilogue stages */
          >::GemmKernel>;

  // Host-launchable device-level wrapper around the kernel.
  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
| |
|
// Builds CUTLASS arguments from the torch tensors and launches one GEMM:
// out = epilogue(a @ b), on a's current CUDA stream.
// Preconditions (not checked here): a is row-major (m x k), b is
// column-major (k x n), out is row-major (m x n), and the tensor dtypes
// match Gemm::ElementAB / Gemm::ElementD -- callers are expected to have
// validated this.
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  // Leading dimensions: row stride for row-major A and C, column stride for
  // column-major B.
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  // Row-major output stride (ldc, 1, 0) matching Gemm::D's Stride type.
  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  // Arguments for the terminal store node of the epilogue visitor tree.
  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  // The epilogue type translates the caller-supplied parameters (e.g. scale
  // tensors) into its visitor-tree argument struct.
  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  // GemmUniversal arguments are positional: mode, problem size, batch count,
  // epilogue, then pointers A/B/C/D, batch strides A/B/C/D, and leading
  // dimensions A/B/C/D.  C and D pointers are null because all output goes
  // through the visitor tree's store node instead.
  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,  // ptr_C (unused)
      nullptr,  // ptr_D (unused)
      0,        // batch strides A/B/C/D (batch count is 1)
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Allocate the split-K workspace through torch so it lives on a's device
  // and goes through torch's allocator.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  // Validate the configuration for this problem size, then launch.
  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
| |
|
// Launches Gemm if its shared-memory footprint fits the device's opt-in
// per-block limit; otherwise falls back to FallbackGemm (which must fit --
// enforced by TORCH_CHECK).
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // Opt-in max dynamic shared memory per block, queried once and cached.
  // Device 0 is queried -- NOTE(review): assumes all visible GPUs share the
  // same limit; confirm for heterogeneous multi-GPU setups.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const primary_smem_bytes =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_smem_bytes =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  // Prefer the primary configuration whenever it fits.
  if (primary_smem_bytes <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  }

  // The fallback configuration must fit, or the problem is unsolvable here.
  TORCH_CHECK(fallback_smem_bytes <= max_shared_mem_per_block_opt_in);
  return cutlass_gemm_caller<FallbackGemm>(
      out, a, b, std::forward<EpilogueArgs>(args)...);
}
| |
|
| | } |
| |
|