#pragma once

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

// clang-format off
template <typename OutType, int ScaleGranularityM, int ScaleGranularityN,
          int ScaleGranularityK, typename MmaTileShape, typename ClusterShape,
          typename EpilogueScheduler, typename MainloopScheduler,
          bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
  static constexpr bool swap_ab = swap_ab_;
  using ElementAB = cutlass::float_e4m3_t;

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ElementC = void;  // TODO: support bias
  using LayoutC = LayoutD;
  using LayoutC_Transpose = LayoutD_Transpose;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ScaleConfig = conditional_t<swap_ab,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;

  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
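  // When swap_ab is set, the operands (and their scales) trade places at the
  // argument level in cutlass_gemm_caller_blockwise below; the deduced
  // scale-factor layouts themselves stay fixed.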
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<
      ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      MmaTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
      AlignmentC,
      ElementD,
      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  using CollectiveMainloop = conditional_t<swap_ab,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementB,
          cute::tuple<LayoutB_Transpose, LayoutSFA>,
          AlignmentB,
          ElementA,
          cute::tuple<LayoutA_Transpose, LayoutSFB>,
          AlignmentA,
          ElementAccumulator,
          MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler
      >::CollectiveOp,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementA,
          cute::tuple<LayoutA, LayoutSFA>,
          AlignmentA,
          ElementB,
          cute::tuple<LayoutB, LayoutSFB>,
          AlignmentB,
          ElementAccumulator,
          MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler
      >::CollectiveOp>;

  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  static constexpr bool swap_ab = Gemm::swap_ab;
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);

  StrideA a_stride;
  StrideB b_stride;
  StrideC c_stride;
  a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  c_stride = cutlass::make_cute_packed_stride(
      StrideC{},
      swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));

  LayoutSFA layout_SFA = swap_ab
      ? ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1))
      : ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = swap_ab
      ? ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1))
      : ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  auto mainloop_args = [&]() {
    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
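    // With swap_ab, B and its scales occupy the "A" operand slots (and vice
    // versa), so the kernel effectively runs the transposed (n, m, k) problem.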
    if (swap_ab) {
      return typename GemmKernel::MainloopArguments{
          b_ptr, b_stride, a_ptr, a_stride,
          b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB};
    } else {
      return typename GemmKernel::MainloopArguments{
          a_ptr, a_stride, b_ptr, b_stride,
          a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
    }
  }();
  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1)
                            : cute::make_shape(m, n, k, 1);

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};
  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

  constexpr int TILE_K = 128;
  // TODO: better heuristics
  bool swap_ab = (m < 16) || (m % 4 != 0);
  // The TMA epilogue is only used when m * n is divisible by 4.
  bool use_tma_epilogue = (m * n) % 4 == 0;
  if (!swap_ab) {
    constexpr int TILE_N = 128;
    // Pick the smallest M tile whose resulting grid still fills all SMs.
    int tile_m = 256;
    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
      tile_m = 64;
    } else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <=
               sms) {
      tile_m = 128;
    }
    if (tile_m == 64) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else if (tile_m == 128) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else {  // tile_m == 256
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    }
  } else {
    // TODO: Test more tile N configs
    constexpr int TILE_M = 128;
    constexpr int TILE_N = 16;
    // TMA epilogue isn't compatible with Swap A/B
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
        OutType, TILE_M, 1, TILE_K,
        Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
        cutlass::epilogue::NoSmemWarpSpecialized1Sm,
        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
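// Example usage (an illustrative sketch, not part of this header: the output
// dtype, tensor construction, and shapes below are assumptions):
//
//   // a: [m, k] fp8 e4m3 (row-major); b: [k, n] fp8 e4m3 (column-major);
//   // a_scales / b_scales: fp32 blockwise scale tensors matching the
//   // kernel's scale granularity.
//   torch::Tensor out = torch::empty(
//       {m, n}, torch::dtype(torch::kBFloat16).device(a.device()));
//   vllm::cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
//       out, a, b, a_scales, b_scales);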