| #pragma once |
|
|
| #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| namespace vllm::c2x { |
|
|
| using namespace cute; |
|
|
| |
| |
| |
| |
| template <typename ElementD, typename OutputTileThreadMap> |
| struct ScaledEpilogueBase { |
| protected: |
| using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; |
|
|
| template <typename T> |
| using ColOrScalarLoad = |
| cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< |
| OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>; |
|
|
| template <typename T> |
| using RowOrScalarLoad = |
| cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< |
| OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>; |
|
|
| template <typename T> |
| using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< |
| OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>; |
|
|
| template <typename T> |
| using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< |
| OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>; |
|
|
| template <typename T> |
| using RowOrZeroLoad = |
| cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< |
| OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>; |
|
|
| |
| |
| |
| template <typename Descriptor, typename T> |
| static auto args_from_tensor(torch::Tensor const& tensor) { |
| using Arguments = typename Descriptor::Arguments; |
| auto* data_ptr = static_cast<T*>(tensor.data_ptr()); |
| if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || |
| std::is_same_v<Descriptor, RowOrScalarLoad<T>>) { |
| return Arguments{data_ptr, tensor.numel() != 1}; |
| } else { |
| |
| static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>); |
| return Arguments{data_ptr}; |
| } |
| } |
|
|
| |
| |
| template <typename Descriptor, typename T> |
| static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) { |
| static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>); |
| using Arguments = typename Descriptor::Arguments; |
| auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; |
| return Arguments{data_ptr}; |
| } |
| }; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename ElementD, typename OutputTileThreadMap> |
| struct ScaledEpilogue |
| : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> { |
| private: |
| using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>; |
| using Accum = typename SUPER::Accum; |
| using ScaleA = typename SUPER::template ColOrScalarLoad<float>; |
| using ScaleB = typename SUPER::template RowOrScalarLoad<float>; |
|
|
| using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, float, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTCompute0 = |
| cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>; |
|
|
| using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, ElementD, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| public: |
| using EVTCompute = |
| cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>; |
| using ArgumentType = typename EVTCompute::Arguments; |
|
|
| static ArgumentType prepare_args(torch::Tensor const& a_scales, |
| torch::Tensor const& b_scales) { |
| auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); |
| auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); |
|
|
| typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; |
| return ArgumentType{a_args, evt0_args, {}}; |
| } |
| }; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename ElementD, typename OutputTileThreadMap> |
| struct ScaledEpilogueBias |
| : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> { |
| protected: |
| using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>; |
| using Accum = typename SUPER::Accum; |
| using ScaleA = typename SUPER::template ColOrScalarLoad<float>; |
| using ScaleB = typename SUPER::template RowOrScalarLoad<float>; |
| using Bias = typename SUPER::template RowLoad<ElementD>; |
| using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, float, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTCompute0 = |
| cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>; |
|
|
| using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiply_add, ElementD, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| public: |
| using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, |
| EVTCompute0, Bias>; |
| using ArgumentType = typename EVTCompute::Arguments; |
| static ArgumentType prepare_args(torch::Tensor const& a_scales, |
| torch::Tensor const& b_scales, |
| torch::Tensor const& bias) { |
| auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); |
| auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); |
| auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); |
|
|
| typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; |
| return ArgumentType{a_args, evt0_args, bias_args, {}}; |
| } |
| }; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename ElementD, typename OutputTileThreadMap> |
| struct ScaledEpilogueBiasAzp |
| : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> { |
| private: |
| using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>; |
| using Accum = typename SUPER::Accum; |
| using ScaleA = typename SUPER::template ColOrScalarLoad<float>; |
| using ScaleB = typename SUPER::template RowOrScalarLoad<float>; |
| using Bias = typename SUPER::template RowOrZeroLoad<ElementD>; |
|
|
| |
| using AzpWithAdj = typename SUPER::template RowLoad<int32_t>; |
|
|
| |
| using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::minus, float, int32_t, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTComputeAzp = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>; |
|
|
| using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, float, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTComputeScaleB = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB, |
| EVTComputeAzp>; |
|
|
| using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiply_add, ElementD, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| public: |
| using EVTCompute = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA, |
| EVTComputeScaleB, Bias>; |
|
|
| using ArgumentType = typename EVTCompute::Arguments; |
|
|
| static ArgumentType prepare_args(torch::Tensor const& a_scales, |
| torch::Tensor const& b_scales, |
| torch::Tensor const& azp_adj, |
| std::optional<torch::Tensor> const& bias) { |
| auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); |
| auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); |
| auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); |
| auto azp_adj_args = |
| SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj); |
|
|
| typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; |
| typename EVTComputeScaleB::Arguments evt_scale_b_args{ |
| b_args, evt_azp_args, {}}; |
| return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; |
| } |
| }; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| template <typename ElementD, typename OutputTileThreadMap> |
| struct ScaledEpilogueBiasAzpToken |
| : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> { |
| private: |
| using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>; |
| using Accum = typename SUPER::Accum; |
| using ScaleA = typename SUPER::template ColOrScalarLoad<float>; |
| using ScaleB = typename SUPER::template RowOrScalarLoad<float>; |
| using Bias = typename SUPER::template RowOrZeroLoad<ElementD>; |
|
|
| |
| using Azp = typename SUPER::template ColLoad<int32_t>; |
|
|
| |
| using AzpAdj = typename SUPER::template RowLoad<int32_t>; |
|
|
| |
| using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, int32_t, int32_t, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTComputeAzp = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>; |
|
|
| |
| using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::minus, float, int32_t, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTComputeAcc = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>; |
|
|
| using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiplies, float, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| using EVTComputeScaleB = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB, |
| EVTComputeAcc>; |
|
|
| using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< |
| cutlass::multiply_add, ElementD, float, |
| cutlass::FloatRoundStyle::round_to_nearest>; |
|
|
| public: |
| using EVTCompute = |
| cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA, |
| EVTComputeScaleB, Bias>; |
|
|
| using ArgumentType = typename EVTCompute::Arguments; |
|
|
| static ArgumentType prepare_args(torch::Tensor const& a_scales, |
| torch::Tensor const& b_scales, |
| torch::Tensor const& azp_adj, |
| torch::Tensor const& azp, |
| std::optional<torch::Tensor> const& bias) { |
| auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); |
| auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); |
| auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); |
| auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp); |
| auto azp_adj_args = |
| SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj); |
|
|
| typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; |
| typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; |
| typename EVTComputeScaleB::Arguments evt_scale_b_args{ |
| b_args, evt_acc_args, {}}; |
| return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; |
| } |
| }; |
|
|
| }; |
|
|