File size: 5,770 Bytes
0da5bf5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | #pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using namespace cute;
namespace vllm {
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
} // namespace vllm
|