| | #pragma once |
| |
|
| | #include "scaled_mm_c2x.cuh" |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | namespace vllm { |
| |
|
| | template <typename InType, typename OutType, |
| | template <typename, typename> typename Epilogue> |
| | struct sm80_config_default { |
| | |
| | |
| | |
| | |
| | static_assert(std::is_same<InType, int8_t>()); |
| | using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; |
| | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; |
| | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; |
| | using Cutlass2xGemm = |
| | cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType, |
| | Epilogue, TileShape, WarpShape, InstructionShape, 5>; |
| | }; |
| |
|
| | template <typename InType, typename OutType, |
| | template <typename, typename> typename Epilogue> |
| | struct sm80_config_M64 { |
| | |
| | |
| | |
| | |
| | static_assert(std::is_same<InType, int8_t>()); |
| | using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; |
| | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; |
| | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; |
| | using Cutlass2xGemm = |
| | cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType, |
| | Epilogue, TileShape, WarpShape, InstructionShape, 5>; |
| | }; |
| |
|
| | template <typename InType, typename OutType, |
| | template <typename, typename> typename Epilogue> |
| | struct sm80_config_M32 { |
| | |
| | |
| | static_assert(std::is_same<InType, int8_t>()); |
| | using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; |
| | using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; |
| | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; |
| | using Cutlass2xGemm = |
| | cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType, |
| | Epilogue, TileShape, WarpShape, InstructionShape, 5>; |
| | }; |
| |
|
| | template <typename InType, typename OutType, |
| | template <typename, typename> typename Epilogue> |
| | struct sm80_config_M16 { |
| | |
| | |
| | static_assert(std::is_same<InType, int8_t>()); |
| | using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; |
| | using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; |
| | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; |
| | using Cutlass2xGemm = |
| | cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType, |
| | Epilogue, TileShape, WarpShape, InstructionShape, 5>; |
| | }; |
| |
|
| | template <typename InType, typename OutType, |
| | template <typename, typename> typename Epilogue, |
| | typename... EpilogueArgs> |
| | inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out, |
| | torch::Tensor const& a, |
| | torch::Tensor const& b, |
| | EpilogueArgs&&... args) { |
| | static_assert(std::is_same<InType, int8_t>()); |
| | TORCH_CHECK(a.dtype() == torch::kInt8); |
| | TORCH_CHECK(b.dtype() == torch::kInt8); |
| |
|
| | using Cutlass2xGemmDefault = |
| | typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm; |
| | using Cutlass2xGemmM128BigN = |
| | typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm; |
| | using Cutlass2xGemmM128SmallN = |
| | typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm; |
| | using Cutlass2xGemmM64 = |
| | typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm; |
| | using Cutlass2xGemmM32 = |
| | typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm; |
| | using Cutlass2xGemmM16 = |
| | typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | using FallbackGemm = |
| | typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm; |
| |
|
| | uint32_t const m = a.size(0); |
| | uint32_t const mp2 = |
| | std::max(static_cast<uint32_t>(16), next_pow_2(m)); |
| | if (mp2 <= 16) { |
| | |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } else if (mp2 <= 32) { |
| | |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } else if (mp2 <= 64) { |
| | |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } else if (mp2 <= 128) { |
| | |
| | uint32_t const n = out.size(1); |
| | bool const small_n = n < 8192; |
| | if (small_n) { |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN, |
| | FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } else { |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } |
| | } else { |
| | |
| | return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>( |
| | out, a, b, std::forward<EpilogueArgs>(args)...); |
| | } |
| | } |
| |
|
| | } |
| |
|