| #pragma once |
|
|
| #include "cutlass/cutlass.h" |
| #include "cutlass/numeric_types.h" |
|
|
| #include "cute/tensor.hpp" |
| #include "cutlass/tensor_ref.h" |
| #include "cutlass/gemm/dispatch_policy.hpp" |
| #include "cutlass/gemm/collective/collective_builder.hpp" |
| #include "cutlass/gemm/device/gemm_universal_adapter.h" |
| #include "cutlass/gemm/kernel/gemm_universal.hpp" |
| #include "cutlass/gemm/kernel/tile_scheduler_params.h" |
| #include "cutlass/epilogue/dispatch_policy.hpp" |
| #include "cutlass/epilogue/collective/collective_builder.hpp" |
|
|
| #include "cutlass_extensions/gemm/dispatch_policy.hpp" |
| #include "cutlass_extensions/gemm/collective/collective_builder.hpp" |
|
|
| #include "cutlass_gemm_caller.cuh" |
|
|
| namespace vllm { |
|
|
| using namespace cute; |
|
|
// Compile-time configuration for a CUTLASS 3.x FP8 (e4m3) GEMM with
// blockwise/groupwise scaling on SM90 (Hopper).
//
// Template parameters:
//   SchedulerType  - CUTLASS tile scheduler tag (e.g. PersistentScheduler or
//                    StreamKScheduler); forwarded to GemmUniversal.
//   OutType        - element type of the output matrix D.
//   GroupSizeM_/GroupSizeN_/GroupSizeK_ - scale-group extents: one scale
//                    value covers a GroupSizeM x GroupSizeK block of A and a
//                    GroupSizeK x GroupSizeN block of B.
//   TileSizeM_     - CTA tile extent along M (default 128); must be a
//                    multiple of GroupSizeM_ (static_assert below).
//   ClusterShape   - thread-block cluster shape.
template <typename SchedulerType, typename OutType, int GroupSizeM_,
          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
          class ClusterShape = Shape<_1, _2, _1>>
struct cutlass_3x_gemm_fp8_blockwise {
  using GroupSizeM = Int<GroupSizeM_>;
  using GroupSizeN = Int<GroupSizeN_>;
  using GroupSizeK = Int<GroupSizeK_>;
  using TileSizeM = Int<TileSizeM_>;

  // The mainloop accumulates per-sub-group along M inside one CTA tile, so
  // the tile must decompose evenly into M scale groups.
  static_assert(TileSizeM_ % GroupSizeM_ == 0,
                "TileSizeM must be a multiple of GroupSizeM");

  // Both operands are FP8 e4m3.
  using ElementAB = cutlass::float_e4m3_t;

  // A: row-major; alignment in elements for a 128-bit (16-byte) access.
  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B: column-major; same 128-bit alignment rule.
  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // D: row-major output (runtime leading dim, unit stride along N, zero
  // batch stride).
  using ElementD = OutType;
  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  // No source accumulator C: void element disables the C operand in the
  // epilogue; stride/alignment mirror D for the builder's signature.
  using ElementC = void;
  using StrideC = StrideD;
  static constexpr int AlignmentC = AlignmentD;

  // Accumulate and apply scales in fp32 on SM90 tensor cores.
  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;

  // TMA warp-specialized cooperative mainloop with FP8 block-scaling;
  // accumulation is sub-grouped along M with extent GroupSizeM_
  // (vllm cutlass_extensions schedule).
  using KernelSchedule = cutlass::gemm::
      KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
          GroupSizeM_>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Epilogue fusion: just fetch the accumulator and store it (no extra
  // elementwise op) — scaling already happened in the mainloop.
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
      cutlass::epilogue::fusion::Sm90AccFetch>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
          ElementD, StrideD, AlignmentD, EpilogueSchedule,
          StoreEpilogueCompute>::CollectiveOp;

  // Mainloop stage count is chosen automatically after carving out the
  // epilogue's shared-memory footprint.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
          LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  // enable_sm90_or_later (from cutlass_gemm_caller.cuh) gates compilation to
  // SM90+ targets.
  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      SchedulerType>>;

  struct GemmKernel : public KernelType {};

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
};
|
|
// Validates the scale tensors against the problem shape, packs CUTLASS
// kernel arguments, and launches the blockwise-scaled GEMM via
// c3x::cutlass_gemm_caller.
//
// Expectations (checked below): a is row-major FP8, b is column-major FP8,
// a_scales is [ceil(m/GroupSizeM), ceil(k/GroupSizeK)] and M-major,
// b_scales is [ceil(k/GroupSizeK), ceil(n/GroupSizeN)] and K-major; all
// scales are fp32.
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  auto prob_shape = c3x::get_problem_shape(a, b);
  int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
          k = get<2>(prob_shape);

  // Leading dimensions (A row-major, B column-major, D row-major).
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  // A scale tensor counts as a densely packed vector when it is contiguous
  // and either 1-D or 2-D with a degenerate (size-1) dimension.
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 ||
            (t.dim() == 2 &&
             *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  // Expected scale counts per dimension. Use ceiling division so problem
  // extents that are not exact multiples of the group size are accepted:
  // a trailing partial group still carries one scale. (Truncating division
  // would wrongly reject valid ceil-sized scale tensors, e.g. k = 129 with
  // GroupSizeK = 128.)
  auto ceil_div = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
  int32_t const scale_groups_m = ceil_div(m, Gemm::GroupSizeM::value);
  int32_t const scale_groups_n = ceil_div(n, Gemm::GroupSizeN::value);
  int32_t const scale_groups_k = ceil_div(k, Gemm::GroupSizeK::value);

  TORCH_CHECK(a_scales.size(0) == scale_groups_m, "a_scales has ",
              a_scales.size(0), " rows, expected ", scale_groups_m);
  TORCH_CHECK(a_scales.size(1) == scale_groups_k, "a_scales has ",
              a_scales.size(1), " cols, expected ", scale_groups_k);
  TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
              "a_scales must be M major");
  TORCH_CHECK(b_scales.size(0) == scale_groups_k, "b_scales has ",
              b_scales.size(0), " rows, expected ", scale_groups_k);
  TORCH_CHECK(b_scales.size(1) == scale_groups_n, "b_scales has ",
              b_scales.size(1), " cols, expected ", scale_groups_n);
  TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
              "b_scales must be K major");

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};

  // ElementC is void, so the epilogue ignores the C operand; D pointer and
  // stride are passed for both slots.
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::TileSchedulerArguments scheduler;

  static constexpr bool UsesStreamKScheduler =
      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
                      cutlass::gemm::StreamKScheduler>;

  if constexpr (UsesStreamKScheduler) {
    using DecompositionMode = typename cutlass::gemm::kernel::detail::
        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
    using ReductionMode = typename cutlass::gemm::kernel::detail::
        PersistentTileSchedulerSm90StreamKParams::ReductionMode;

    // Force stream-K decomposition; partial-tile reduction uses the
    // non-deterministic (atomic-based) path.
    scheduler.decomposition_mode = DecompositionMode::StreamK;
    scheduler.reduction_mode = ReductionMode::Nondeterministic;
  }

  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args, scheduler);
}
|
|
// Chooses a tile scheduler for the SM90 FP8 blockwise GEMM based on the
// problem shape and forwards to cutlass_gemm_caller_blockwise.
//
// Heuristic: when K dominates N (k > 3n) there are relatively few output
// tiles, so the stream-K scheduler is used; otherwise the persistent
// scheduler is used. Both configurations use group sizes (1, 128, 128).
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                              torch::Tensor const& a,
                                              torch::Tensor const& b,
                                              torch::Tensor const& a_scales,
                                              torch::Tensor const& b_scales) {
  using StreamKGemm = cutlass_3x_gemm_fp8_blockwise<
      cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>;
  using PersistentGemm = cutlass_3x_gemm_fp8_blockwise<
      cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>;

  auto const k = a.size(1);
  auto const n = b.size(1);

  if (k > 3 * n) {
    cutlass_gemm_caller_blockwise<StreamKGemm>(out, a, b, a_scales, b_scales);
  } else {
    cutlass_gemm_caller_blockwise<PersistentGemm>(out, a, b, a_scales,
                                                  b_scales);
  }
}
|
|
| } |